From 7b062be8fabf3210368e2de73b513ab09cef5bf6 Mon Sep 17 00:00:00 2001 From: Cuong Duong Date: Tue, 10 Oct 2023 06:49:11 +1100 Subject: [PATCH] Separate preprocessing step from fit() method for easier debugging (#2505) --- python/prophet/forecaster.py | 132 ++++++++++++++++++++--------------- python/prophet/models.py | 33 ++++++++- 2 files changed, 109 insertions(+), 56 deletions(-) diff --git a/python/prophet/forecaster.py b/python/prophet/forecaster.py index 4f1434d..47e368f 100644 --- a/python/prophet/forecaster.py +++ b/python/prophet/forecaster.py @@ -6,6 +6,7 @@ from __future__ import absolute_import, division, print_function +import dataclasses import logging from collections import OrderedDict, defaultdict from copy import deepcopy @@ -17,7 +18,7 @@ import pandas as pd from numpy.typing import NDArray from prophet.make_holidays import get_holiday_names, make_holidays_df -from prophet.models import StanBackendEnum +from prophet.models import StanBackendEnum, ModelInputData, ModelParams, TrendIndicator from prophet.plot import (plot, plot_components) logger = logging.getLogger('prophet') @@ -1114,6 +1115,76 @@ class Prophet(object): m = df['y_scaled'].mean() return k, m + def preprocess(self, df: pd.DataFrame, **kwargs) -> ModelInputData: + """ + Reformats historical data, standardizes y and extra regressors, sets seasonalities and changepoints. + + Saves the preprocessed data to the instantiated object, and also returns the relevant components + as a ModelInputData object. + """ + if ('ds' not in df) or ('y' not in df): + raise ValueError( + 'Dataframe must have columns "ds" and "y" with the dates and ' + 'values respectively.' + ) + history = df[df['y'].notnull()].copy() + if history.shape[0] < 2: + raise ValueError('Dataframe has less than 2 non-NaN rows.') + self.history_dates = pd.to_datetime(pd.Series(history['ds'].unique(), name='ds')).sort_values() + + self.history = self.setup_dataframe(history, initialize_scales=True) + self.set_auto_seasonalities() + seasonal_features, prior_scales, component_cols, modes = ( + self.make_all_seasonality_features(self.history)) + self.train_component_cols = component_cols + self.component_modes = modes + self.fit_kwargs = deepcopy(kwargs) + + self.set_changepoints() + + if self.growth in ['linear', 'flat']: + cap = np.zeros(self.history.shape[0]) + else: + cap = self.history['cap_scaled'] + + return ModelInputData( + T=self.history.shape[0], + S=len(self.changepoints_t), + K=seasonal_features.shape[1], + tau=self.changepoint_prior_scale, + trend_indicator=TrendIndicator[self.growth.upper()].value, + y=self.history['y_scaled'], + t=self.history['t'], + t_change=self.changepoints_t, + X=seasonal_features, + sigmas=prior_scales, + s_a=component_cols['additive_terms'], + s_m=component_cols['multiplicative_terms'], + cap=cap, + ) + + def calculate_initial_params(self, num_total_regressors: int) -> ModelParams: + """ + Calculates initial parameters for the model based on the preprocessed history. + + Parameters + ---------- + num_total_regressors: the count of seasonality fourier components plus holidays plus extra regressors. + """ + if self.growth == 'linear': + k, m = self.linear_growth_init(self.history) + elif self.growth == 'flat': + k, m = self.flat_growth_init(self.history) + elif self.growth == 'logistic': + k, m = self.logistic_growth_init(self.history) + return ModelParams( + k=k, + m=m, + delta=np.zeros_like(self.changepoints_t), + beta=np.zeros(num_total_regressors), + sigma_obs=1.0, + ) + def fit(self, df, **kwargs): """Fit the Prophet model. @@ -1142,63 +1213,14 @@ class Prophet(object): if self.history is not None: raise Exception('Prophet object can only be fit once. ' 'Instantiate a new object.') - if ('ds' not in df) or ('y' not in df): - raise ValueError( - 'Dataframe must have columns "ds" and "y" with the dates and ' - 'values respectively.' - ) - history = df[df['y'].notnull()].copy() - if history.shape[0] < 2: - raise ValueError('Dataframe has less than 2 non-NaN rows.') - self.history_dates = pd.to_datetime(pd.Series(df['ds'].unique(), name='ds')).sort_values() - history = self.setup_dataframe(history, initialize_scales=True) - self.history = history - self.set_auto_seasonalities() - seasonal_features, prior_scales, component_cols, modes = ( - self.make_all_seasonality_features(history)) - self.train_component_cols = component_cols - self.component_modes = modes - self.fit_kwargs = deepcopy(kwargs) + model_inputs = self.preprocess(df, **kwargs) + initial_params = self.calculate_initial_params(model_inputs.K) - self.set_changepoints() + dat = dataclasses.asdict(model_inputs) + stan_init = dataclasses.asdict(initial_params) - trend_indicator = {'linear': 0, 'logistic': 1, 'flat': 2} - - dat = { - 'T': history.shape[0], - 'K': seasonal_features.shape[1], - 'S': len(self.changepoints_t), - 'y': history['y_scaled'], - 't': history['t'], - 't_change': self.changepoints_t, - 'X': seasonal_features, - 'sigmas': prior_scales, - 'tau': self.changepoint_prior_scale, - 'trend_indicator': trend_indicator[self.growth], - 's_a': component_cols['additive_terms'], - 's_m': component_cols['multiplicative_terms'], - } - - if self.growth == 'linear': - dat['cap'] = np.zeros(self.history.shape[0]) - kinit = self.linear_growth_init(history) - elif self.growth == 'flat': - dat['cap'] = np.zeros(self.history.shape[0]) - kinit = self.flat_growth_init(history) - else: - dat['cap'] = history['cap_scaled'] - kinit = self.logistic_growth_init(history) - - stan_init = { - 'k': kinit[0], - 'm': kinit[1], - 'delta': np.zeros(len(self.changepoints_t)), - 'beta': np.zeros(seasonal_features.shape[1]), - 'sigma_obs': 1, - } - - if history['y'].min() == history['y'].max() and \ + if self.history['y'].min() == self.history['y'].max() and \ (self.growth == 'linear' or self.growth == 'flat'): self.params = stan_init self.params['sigma_obs'] = 1e-9 diff --git a/python/prophet/models.py b/python/prophet/models.py index 012f90a..30293cb 100644 --- a/python/prophet/models.py +++ b/python/prophet/models.py @@ -6,7 +6,8 @@ from __future__ import absolute_import, division, print_function from abc import abstractmethod, ABC -from typing import Tuple +from dataclasses import dataclass +from typing import Sequence, Tuple from collections import OrderedDict from enum import Enum import importlib_resources @@ -17,6 +18,36 @@ logger = logging.getLogger('prophet.models') PLATFORM = "win" if platform.platform().startswith("Win") else "unix" +class TrendIndicator(Enum): + LINEAR = 0 + LOGISTIC = 1 + FLAT = 2 + +@dataclass +class ModelInputData: + T: int + S: int + K: int + tau: float + trend_indicator: int + y: Sequence[float] # length T + t: Sequence[float] # length T + cap: Sequence[float] # length T + t_change: Sequence[float] # length S + s_a: Sequence[int] # length K + s_m: Sequence[int] # length K + X: Sequence[Sequence[float]] # shape (T, K) + sigmas: Sequence[float] # length K + +@dataclass +class ModelParams: + k: float + m: float + delta: Sequence[float] # length S + beta: Sequence[float] # length K + sigma_obs: float + + class IStanBackend(ABC): def __init__(self): self.model = self.load_model()