mgplot.postcovid_plot
Plot the linear pre-COVID trajectory against the current data.
"""Plot the linear pre-COVID trajectory against the current data."""

from typing import Literal, NotRequired, Unpack, cast

from matplotlib.axes import Axes
from numpy import array, polyfit
from pandas import DataFrame, Period, PeriodIndex, Series, period_range

from mgplot.keyword_checking import (
    report_kwargs,
    validate_kwargs,
)
from mgplot.line_plot import LineKwargs, line_plot
from mgplot.settings import DataT, get_setting
from mgplot.utilities import check_clean_timeseries

# --- constants
ME = "postcovid_plot"
MIN_REGRESSION_POINTS = 10  # minimum number of points for a useful linear regression

# Default regression periods by frequency
DEFAULT_PERIODS = {
    "Q": {"start": "2014Q4", "end": "2019Q4"},
    "M": {"start": "2015-01", "end": "2020-01"},
    "D": {"start": "2015-01-01", "end": "2020-01-01"},
}


class PostcovidKwargs(LineKwargs):
    """Keyword arguments for the post-COVID plot."""

    start_r: NotRequired[Period]  # start of regression period
    end_r: NotRequired[Period]  # end of regression period


# --- functions
def get_projection(source: Series, to_period: Period) -> Series:
    """Create a linear projection based on pre-COVID data.

    A straight line is fitted (least squares, via numpy.polyfit) to the
    non-missing observations of ``source`` up to and including ``to_period``.
    The fitted line is then evaluated at every period from the first to the
    last period of ``source``, so the result has no gaps even where
    ``source`` does.

    Args:
        source: Series - the original series with a PeriodIndex.
            The index is assumed to be unique and monotonic increasing.
            The series may have gaps (missing periods or NaN values), and
            it is assumed to start where the regression should start.
        to_period: Period - the last period included in the regression fit
            (the end of the regression window, not the end of the projection).

    Returns:
        Series: linear projection values, indexed by a complete PeriodIndex
        running from the first to the last period of ``source``.
        Returns an empty Series if a projection cannot be created.

    Raises:
        TypeError: If the source index is not a PeriodIndex.

    """
    # --- initial validation
    if not isinstance(source.index, PeriodIndex):
        raise TypeError("Source index must be a PeriodIndex")
    if source.empty or not source.index.is_monotonic_increasing or not source.index.is_unique:
        print("Source series must be non-empty, uniquely indexed, and a monotonic increasing index.")
        return Series(dtype=float)  # return empty series if validation fails

    # --- drop missing data, then keep only the regression window (<= to_period)
    source_no_nan = source.dropna()
    input_series = source_no_nan[source_no_nan.index <= to_period]

    # --- further validation
    if input_series.empty or len(input_series) < MIN_REGRESSION_POINTS:
        print("Insufficient data points for regression.")
        return Series(dtype=float)  # return empty series if no data for regression

    # --- simple linear regression of the values on the period ordinals
    # (input_series is already restricted to the regression window)
    x_cause = array([p.ordinal for p in input_series.index])
    y_effect = input_series.to_numpy()
    slope, intercept = polyfit(x_cause, y_effect, 1)

    # --- evaluate the fitted line over a gap-free index spanning the whole source;
    # this directly covers any missing periods in the source's PeriodIndex
    full_index = period_range(start=source.index[0], end=source.index[-1])
    x_complete = array([p.ordinal for p in full_index])
    return Series((x_complete * slope) + intercept, index=full_index)


def regression_period(data: Series, **kwargs: Unpack[PostcovidKwargs]) -> tuple[Period, Period, bool]:
    """Establish the regression period.

    The default period depends on the index frequency (see DEFAULT_PERIODS).
    User-supplied ``start_r`` / ``end_r`` values override the defaults and are
    converted to Periods at the series frequency.

    Args:
        data: Series - the original time series data.
        **kwargs: Additional keyword arguments; only ``start_r`` and ``end_r``
            are read.

    Returns:
        A tuple of (start period, end period, robust), where robust is False
        when the start period is not before the end period.

    Raises:
        TypeError: If the series index is not a PeriodIndex.
        ValueError: If the series index does not have a D, M, or Q frequency.

    """
    # --- check that the series index is a PeriodIndex with a valid frequency
    if not isinstance(data.index, PeriodIndex):
        raise TypeError("The series index must be a PeriodIndex")
    freq_str = data.index.freqstr
    # guard against a missing/empty frequency string *before* indexing into it
    freq_key = freq_str[0] if freq_str else ""
    if freq_key not in ("Q", "M", "D"):
        raise ValueError("The series index must have a D, M or Q frequency")

    # --- set the default regression period, use user provided periods if specified
    default_periods = DEFAULT_PERIODS[freq_key]
    start_regression = Period(default_periods["start"], freq=freq_str)
    end_regression = Period(default_periods["end"], freq=freq_str)

    user_start = kwargs.pop("start_r", None)
    user_end = kwargs.pop("end_r", None)
    start_r = Period(user_start, freq=freq_str) if user_start else start_regression
    end_r = Period(user_end, freq=freq_str) if user_end else end_regression

    # --- Validate the regression period
    robust = True
    if start_r >= end_r:
        print(f"Invalid regression period: {start_r=}, {end_r=}")
        robust = False

    return start_r, end_r, robust


def postcovid_plot(data: DataT, **kwargs: Unpack[PostcovidKwargs]) -> Axes:
    """Plot a series with a PeriodIndex, including a post-COVID projection.

    A linear trend is fitted over the regression period (``start_r`` to
    ``end_r``; by default chosen from DEFAULT_PERIODS according to the index
    frequency). It is plotted (dashed by default) alongside the data from
    ``start_r`` onwards. If the regression period is invalid or no projection
    can be made, only the raw data are plotted.

    Args:
        data: Series - the series to be plotted.
        kwargs: PostcovidKwargs - plotting arguments. ``start_r`` and ``end_r``
            set the regression period; ``plot_from`` is ignored; the remainder
            are passed on to line_plot().

    Returns:
        Axes: the axes returned by line_plot().

    Raises:
        TypeError if series is not a pandas Series
        TypeError if series does not have a PeriodIndex
        ValueError if series does not have a D, M or Q frequency

    Note:
        An invalid regression period (start not before end) does not raise;
        it is reported and only the raw data are plotted.

    """

    # --- fallback: plot the raw data without a projection
    def failure() -> Axes:
        print("postcovid_plot(): plotting the raw data only.")
        remove: list[Literal["plot_from", "start_r", "end_r"]] = ["plot_from", "start_r", "end_r"]
        for key in remove:
            kwargs.pop(key, None)
        return line_plot(
            data,
            **cast("LineKwargs", kwargs),
        )

    # --- check the kwargs
    report_kwargs(caller=ME, **kwargs)
    validate_kwargs(schema=PostcovidKwargs, caller=ME, **kwargs)

    # --- check the data
    data = check_clean_timeseries(data, ME)
    if not isinstance(data, Series):
        raise TypeError("The series argument must be a pandas Series")

    # --- rely on line_plot() to validate kwargs, but remove any that are not relevant
    if "plot_from" in kwargs:
        print("Warning: the 'plot_from' argument is ignored in postcovid_plot().")
        kwargs.pop("plot_from", None)

    # --- set the regression period
    start_r, end_r, robust = regression_period(data, **kwargs)
    kwargs.pop("start_r", None)  # remove from kwargs to avoid confusion
    kwargs.pop("end_r", None)  # remove from kwargs to avoid confusion
    if not robust:
        return failure()

    # --- combine data and projection
    if start_r < data.dropna().index.min():
        print(f"Caution: Regression start period pre-dates the series index: {start_r=}")
    recent_data = data[data.index >= start_r].copy()
    recent_data.name = "Series"
    projection_data = get_projection(recent_data, end_r)
    if projection_data.empty:
        return failure()
    projection_data.name = "Pre-COVID projection"

    # --- Create DataFrame with proper column alignment
    combined_data = DataFrame(
        {
            projection_data.name: projection_data,
            recent_data.name: recent_data,
        }
    )

    # --- activate plot settings (each can be overridden by the caller)
    kwargs["width"] = kwargs.pop(
        "width",
        (get_setting("line_normal"), get_setting("line_wide")),
    )  # series line is thicker than projection
    kwargs["style"] = kwargs.pop("style", ("--", "-"))  # dashed regression line
    kwargs["label_series"] = kwargs.pop("label_series", True)
    kwargs["annotate"] = kwargs.pop("annotate", (False, True))  # annotate series only
    kwargs["color"] = kwargs.pop("color", ("darkblue", "#dd0000"))
    kwargs["dropna"] = kwargs.pop("dropna", False)  # by default, do not ask line_plot() to drop NaNs

    return line_plot(
        combined_data,
        **cast("LineKwargs", kwargs),
    )


if __name__ == "__main__":

    def test_make_projection() -> None:
        """Test the get_projection function."""
        n = 30
        periods = period_range(start="2015-Q1", periods=n, freq="Q")
        series = Series(
            [i + (i % 3) for i in range(n)],  # simple increasing series with some noise
            index=periods,
        )
        proj = get_projection(series, Period("2019-Q4", freq="Q"))
        print(
            DataFrame(
                {
                    "Input": series,
                    "Projection": proj,
                }
            )
        )

    test_make_projection()
class PostcovidKwargs(LineKwargs):
    """Keyword arguments for the post-COVID plot.

    Extends the line_plot() keyword arguments with the regression window.
    This class is also passed as the schema to validate_kwargs(), so its
    annotations are runtime behaviour, not just type hints.
    """

    start_r: NotRequired[Period]  # start of regression period (inclusive)
    end_r: NotRequired[Period]  # end of regression period (inclusive; last period used in the fit)
Keyword arguments for the post-COVID plot.
def get_projection(source: Series, to_period: Period) -> Series:
    """Create a linear projection based on pre-COVID data.

    A straight line is fitted (least squares, via numpy.polyfit) to the
    non-missing observations of ``source`` up to and including ``to_period``.
    The fitted line is then evaluated at every period from the first to the
    last period of ``source``, so the result has no gaps even where
    ``source`` does.

    Args:
        source: Series - the original series with a PeriodIndex.
            The index is assumed to be unique and monotonic increasing.
            The series may have gaps (missing periods or NaN values), and
            it is assumed to start where the regression should start.
        to_period: Period - the last period included in the regression fit
            (the end of the regression window, not the end of the projection).

    Returns:
        Series: linear projection values, indexed by a complete PeriodIndex
        running from the first to the last period of ``source``.
        Returns an empty Series if a projection cannot be created.

    Raises:
        TypeError: If the source index is not a PeriodIndex.

    """
    # --- initial validation
    if not isinstance(source.index, PeriodIndex):
        raise TypeError("Source index must be a PeriodIndex")
    if source.empty or not source.index.is_monotonic_increasing or not source.index.is_unique:
        print("Source series must be non-empty, uniquely indexed, and a monotonic increasing index.")
        return Series(dtype=float)  # return empty series if validation fails

    # --- drop missing data, then keep only the regression window (<= to_period)
    source_no_nan = source.dropna()
    input_series = source_no_nan[source_no_nan.index <= to_period]

    # --- further validation
    if input_series.empty or len(input_series) < MIN_REGRESSION_POINTS:
        print("Insufficient data points for regression.")
        return Series(dtype=float)  # return empty series if no data for regression

    # --- simple linear regression of the values on the period ordinals
    # (input_series is already restricted to the regression window)
    x_cause = array([p.ordinal for p in input_series.index])
    y_effect = input_series.to_numpy()
    slope, intercept = polyfit(x_cause, y_effect, 1)

    # --- evaluate the fitted line over a gap-free index spanning the whole source;
    # this directly covers any missing periods in the source's PeriodIndex
    full_index = period_range(start=source.index[0], end=source.index[-1])
    x_complete = array([p.ordinal for p in full_index])
    return Series((x_complete * slope) + intercept, index=full_index)
Create a linear projection based on pre-COVID data.
Args: source: Series - the original series with a PeriodIndex. The index is assumed to be unique and monotonic increasing. The series may have gaps (missing periods or NaN values), and it is assumed to start where the regression should start. to_period: Period - the last period included in the regression fit (the end of the regression window, not the end of the projection).
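Returns: Series: A pandas Series of linear projection values, indexed by a complete PeriodIndex running from the first to the last period of the source series (gaps in the source index are filled). Returns an empty Series if a projection cannot be created.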
Raises: TypeError: If the source index is not a PeriodIndex.
def regression_period(data: Series, **kwargs: Unpack[PostcovidKwargs]) -> tuple[Period, Period, bool]:
    """Establish the regression period.

    The default period depends on the index frequency (see DEFAULT_PERIODS).
    User-supplied ``start_r`` / ``end_r`` values override the defaults and are
    converted to Periods at the series frequency.

    Args:
        data: Series - the original time series data.
        **kwargs: Additional keyword arguments; only ``start_r`` and ``end_r``
            are read.

    Returns:
        A tuple of (start period, end period, robust), where robust is False
        when the start period is not before the end period.

    Raises:
        TypeError: If the series index is not a PeriodIndex.
        ValueError: If the series index does not have a D, M, or Q frequency.

    """
    # --- check that the series index is a PeriodIndex with a valid frequency
    if not isinstance(data.index, PeriodIndex):
        raise TypeError("The series index must be a PeriodIndex")
    freq_str = data.index.freqstr
    # guard against a missing/empty frequency string *before* indexing into it
    freq_key = freq_str[0] if freq_str else ""
    if freq_key not in ("Q", "M", "D"):
        raise ValueError("The series index must have a D, M or Q frequency")

    # --- set the default regression period, use user provided periods if specified
    default_periods = DEFAULT_PERIODS[freq_key]
    start_regression = Period(default_periods["start"], freq=freq_str)
    end_regression = Period(default_periods["end"], freq=freq_str)

    user_start = kwargs.pop("start_r", None)
    user_end = kwargs.pop("end_r", None)
    start_r = Period(user_start, freq=freq_str) if user_start else start_regression
    end_r = Period(user_end, freq=freq_str) if user_end else end_regression

    # --- Validate the regression period
    robust = True
    if start_r >= end_r:
        print(f"Invalid regression period: {start_r=}, {end_r=}")
        robust = False

    return start_r, end_r, robust
Establish the regression period.
Args: data: Series - the original time series data. **kwargs: Additional keyword arguments; only start_r and end_r are read.
Returns: A tuple containing the start and end periods for regression, and a boolean indicating if the period is robust (False when the start period is not before the end period).
Raises: TypeError: If the series index is not a PeriodIndex. ValueError: If the series index does not have a D, M, or Q frequency
def postcovid_plot(data: DataT, **kwargs: Unpack[PostcovidKwargs]) -> Axes:
    """Plot a series with a PeriodIndex, including a post-COVID projection.

    A linear trend is fitted over the regression period (``start_r`` to
    ``end_r``; by default chosen from DEFAULT_PERIODS according to the index
    frequency). It is plotted (dashed by default) alongside the data from
    ``start_r`` onwards. If the regression period is invalid or no projection
    can be made, only the raw data are plotted.

    Args:
        data: Series - the series to be plotted.
        kwargs: PostcovidKwargs - plotting arguments. ``start_r`` and ``end_r``
            set the regression period; ``plot_from`` is ignored; the remainder
            are passed on to line_plot().

    Returns:
        Axes: the axes returned by line_plot().

    Raises:
        TypeError if series is not a pandas Series
        TypeError if series does not have a PeriodIndex
        ValueError if series does not have a D, M or Q frequency

    Note:
        An invalid regression period (start not before end) does not raise;
        it is reported and only the raw data are plotted.

    """

    # --- fallback: plot the raw data without a projection
    def failure() -> Axes:
        print("postcovid_plot(): plotting the raw data only.")
        remove: list[Literal["plot_from", "start_r", "end_r"]] = ["plot_from", "start_r", "end_r"]
        for key in remove:
            kwargs.pop(key, None)
        return line_plot(
            data,
            **cast("LineKwargs", kwargs),
        )

    # --- check the kwargs
    report_kwargs(caller=ME, **kwargs)
    validate_kwargs(schema=PostcovidKwargs, caller=ME, **kwargs)

    # --- check the data
    data = check_clean_timeseries(data, ME)
    if not isinstance(data, Series):
        raise TypeError("The series argument must be a pandas Series")

    # --- rely on line_plot() to validate kwargs, but remove any that are not relevant
    if "plot_from" in kwargs:
        print("Warning: the 'plot_from' argument is ignored in postcovid_plot().")
        kwargs.pop("plot_from", None)

    # --- set the regression period
    start_r, end_r, robust = regression_period(data, **kwargs)
    kwargs.pop("start_r", None)  # remove from kwargs to avoid confusion
    kwargs.pop("end_r", None)  # remove from kwargs to avoid confusion
    if not robust:
        return failure()

    # --- combine data and projection
    if start_r < data.dropna().index.min():
        print(f"Caution: Regression start period pre-dates the series index: {start_r=}")
    recent_data = data[data.index >= start_r].copy()
    recent_data.name = "Series"
    projection_data = get_projection(recent_data, end_r)
    if projection_data.empty:
        return failure()
    projection_data.name = "Pre-COVID projection"

    # --- Create DataFrame with proper column alignment
    combined_data = DataFrame(
        {
            projection_data.name: projection_data,
            recent_data.name: recent_data,
        }
    )

    # --- activate plot settings (each can be overridden by the caller)
    kwargs["width"] = kwargs.pop(
        "width",
        (get_setting("line_normal"), get_setting("line_wide")),
    )  # series line is thicker than projection
    kwargs["style"] = kwargs.pop("style", ("--", "-"))  # dashed regression line
    kwargs["label_series"] = kwargs.pop("label_series", True)
    kwargs["annotate"] = kwargs.pop("annotate", (False, True))  # annotate series only
    kwargs["color"] = kwargs.pop("color", ("darkblue", "#dd0000"))
    kwargs["dropna"] = kwargs.pop("dropna", False)  # by default, do not ask line_plot() to drop NaNs

    return line_plot(
        combined_data,
        **cast("LineKwargs", kwargs),
    )
Plot a series with a PeriodIndex, including a post-COVID projection.
Args: data: Series - the series to be plotted. kwargs: PostcovidKwargs - plotting arguments.
Raises: TypeError if series is not a pandas Series TypeError if series does not have a PeriodIndex ValueError if series does not have a D, M or Q frequency. An invalid regression period (start not before end) does not raise: it is reported and only the raw data are plotted.