Separate preprocessing step from fit() method for easier debugging (#2505)

This commit is contained in:
Cuong Duong 2023-10-10 06:49:11 +11:00 committed by GitHub
parent 279d8d6a30
commit 7b062be8fa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 109 additions and 56 deletions

View file

@@ -6,6 +6,7 @@
from __future__ import absolute_import, division, print_function
import dataclasses
import logging
from collections import OrderedDict, defaultdict
from copy import deepcopy
@@ -17,7 +18,7 @@ import pandas as pd
from numpy.typing import NDArray
from prophet.make_holidays import get_holiday_names, make_holidays_df
from prophet.models import StanBackendEnum
from prophet.models import StanBackendEnum, ModelInputData, ModelParams, TrendIndicator
from prophet.plot import (plot, plot_components)
logger = logging.getLogger('prophet')
@@ -1114,6 +1115,76 @@ class Prophet(object):
m = df['y_scaled'].mean()
return k, m
def preprocess(self, df: pd.DataFrame, **kwargs) -> ModelInputData:
    """
    Reformats historical data, standardizes y and extra regressors, sets seasonalities and changepoints.

    Saves the preprocessed data to the instantiated object (history, history_dates,
    train_component_cols, component_modes, fit_kwargs), and also returns the relevant
    components as a ModelInputData object.

    Parameters
    ----------
    df: pd.DataFrame containing the history. Must have columns ds (dates) and y (values).
    **kwargs: Additional arguments; a deep copy is stored on the object as fit_kwargs.

    Returns
    -------
    ModelInputData built from the preprocessed history.

    Raises
    ------
    ValueError: if df lacks a "ds" or "y" column, or has fewer than 2 rows with a non-null y.
    """
    if ('ds' not in df) or ('y' not in df):
        raise ValueError(
            'Dataframe must have columns "ds" and "y" with the dates and '
            'values respectively.'
        )
    # Only rows with an observed y take part in fitting.
    history = df[df['y'].notnull()].copy()
    if history.shape[0] < 2:
        raise ValueError('Dataframe has less than 2 non-NaN rows.')
    # Take the dates from the full input rather than the filtered history, so
    # that dates whose y is missing still count as observed history dates.
    self.history_dates = pd.to_datetime(pd.Series(df['ds'].unique(), name='ds')).sort_values()

    self.history = self.setup_dataframe(history, initialize_scales=True)
    self.set_auto_seasonalities()
    seasonal_features, prior_scales, component_cols, modes = (
        self.make_all_seasonality_features(self.history))
    self.train_component_cols = component_cols
    self.component_modes = modes
    # Deep copy so later mutation of the caller's kwargs cannot alter the stored record.
    self.fit_kwargs = deepcopy(kwargs)

    self.set_changepoints()

    # Only non-linear/non-flat growth reads a capacity column; the other
    # modes still supply a length-T placeholder of zeros.
    if self.growth in ['linear', 'flat']:
        cap = np.zeros(self.history.shape[0])
    else:
        cap = self.history['cap_scaled']

    return ModelInputData(
        T=self.history.shape[0],
        S=len(self.changepoints_t),
        K=seasonal_features.shape[1],
        tau=self.changepoint_prior_scale,
        trend_indicator=TrendIndicator[self.growth.upper()].value,
        y=self.history['y_scaled'],
        t=self.history['t'],
        t_change=self.changepoints_t,
        X=seasonal_features,
        sigmas=prior_scales,
        s_a=component_cols['additive_terms'],
        s_m=component_cols['multiplicative_terms'],
        cap=cap,
    )
def calculate_initial_params(self, num_total_regressors: int) -> ModelParams:
    """
    Calculates initial parameters for the model based on the preprocessed history.

    Parameters
    ----------
    num_total_regressors: the count of seasonality fourier components plus holidays plus extra regressors.

    Returns
    -------
    ModelParams with k and m taken from the growth-specific initializer, zero-valued
    delta (one entry per changepoint) and beta (one entry per regressor), and sigma_obs of 1.0.

    Raises
    ------
    ValueError: if self.growth is not one of 'linear', 'flat' or 'logistic'.
    """
    if self.growth == 'linear':
        k, m = self.linear_growth_init(self.history)
    elif self.growth == 'flat':
        k, m = self.flat_growth_init(self.history)
    elif self.growth == 'logistic':
        k, m = self.logistic_growth_init(self.history)
    else:
        # Without this branch an unexpected growth value would surface as an
        # UnboundLocalError on k below, which hides the real cause.
        raise ValueError(
            'Unsupported growth "{}"; expected "linear", "flat" or "logistic".'.format(self.growth)
        )
    return ModelParams(
        k=k,
        m=m,
        # Always a float vector, independent of the dtype of changepoints_t.
        delta=np.zeros(len(self.changepoints_t)),
        beta=np.zeros(num_total_regressors),
        sigma_obs=1.0,
    )
def fit(self, df, **kwargs):
"""Fit the Prophet model.
@@ -1142,63 +1213,14 @@ class Prophet(object):
if self.history is not None:
raise Exception('Prophet object can only be fit once. '
'Instantiate a new object.')
if ('ds' not in df) or ('y' not in df):
raise ValueError(
'Dataframe must have columns "ds" and "y" with the dates and '
'values respectively.'
)
history = df[df['y'].notnull()].copy()
if history.shape[0] < 2:
raise ValueError('Dataframe has less than 2 non-NaN rows.')
self.history_dates = pd.to_datetime(pd.Series(df['ds'].unique(), name='ds')).sort_values()
history = self.setup_dataframe(history, initialize_scales=True)
self.history = history
self.set_auto_seasonalities()
seasonal_features, prior_scales, component_cols, modes = (
self.make_all_seasonality_features(history))
self.train_component_cols = component_cols
self.component_modes = modes
self.fit_kwargs = deepcopy(kwargs)
model_inputs = self.preprocess(df, **kwargs)
initial_params = self.calculate_initial_params(model_inputs.K)
self.set_changepoints()
dat = dataclasses.asdict(model_inputs)
stan_init = dataclasses.asdict(initial_params)
trend_indicator = {'linear': 0, 'logistic': 1, 'flat': 2}
dat = {
'T': history.shape[0],
'K': seasonal_features.shape[1],
'S': len(self.changepoints_t),
'y': history['y_scaled'],
't': history['t'],
't_change': self.changepoints_t,
'X': seasonal_features,
'sigmas': prior_scales,
'tau': self.changepoint_prior_scale,
'trend_indicator': trend_indicator[self.growth],
's_a': component_cols['additive_terms'],
's_m': component_cols['multiplicative_terms'],
}
if self.growth == 'linear':
dat['cap'] = np.zeros(self.history.shape[0])
kinit = self.linear_growth_init(history)
elif self.growth == 'flat':
dat['cap'] = np.zeros(self.history.shape[0])
kinit = self.flat_growth_init(history)
else:
dat['cap'] = history['cap_scaled']
kinit = self.logistic_growth_init(history)
stan_init = {
'k': kinit[0],
'm': kinit[1],
'delta': np.zeros(len(self.changepoints_t)),
'beta': np.zeros(seasonal_features.shape[1]),
'sigma_obs': 1,
}
if history['y'].min() == history['y'].max() and \
if self.history['y'].min() == self.history['y'].max() and \
(self.growth == 'linear' or self.growth == 'flat'):
self.params = stan_init
self.params['sigma_obs'] = 1e-9

View file

@@ -6,7 +6,8 @@
from __future__ import absolute_import, division, print_function
from abc import abstractmethod, ABC
from typing import Tuple
from dataclasses import dataclass
from typing import Sequence, Tuple
from collections import OrderedDict
from enum import Enum
import importlib_resources
@@ -17,6 +18,36 @@ logger = logging.getLogger('prophet.models')
PLATFORM = "win" if platform.platform().startswith("Win") else "unix"
class TrendIndicator(Enum):
    """Integer code for each supported growth setting.

    Members are looked up by the upper-cased growth name, and the member's
    ``value`` is the integer handed on as the model's ``trend_indicator``.
    """

    LINEAR = 0    # growth == 'linear'
    LOGISTIC = 1  # growth == 'logistic'
    FLAT = 2      # growth == 'flat'
@dataclass
class ModelInputData:
    """Container for the data inputs handed to the model when fitting.

    Field order is part of the contract: it fixes both the positional
    constructor signature and the key order of ``dataclasses.asdict``.
    """

    # --- scalars ---
    T: int                # number of rows in the history
    S: int                # number of changepoints
    K: int                # number of feature columns in X
    tau: float            # changepoint prior scale
    trend_indicator: int  # integer code of the growth setting

    # --- per-observation vectors, each of length T ---
    y: Sequence[float]
    t: Sequence[float]
    cap: Sequence[float]

    # --- changepoint positions, length S ---
    t_change: Sequence[float]

    # --- per-feature vectors, each of length K ---
    s_a: Sequence[int]  # additive_terms component column
    s_m: Sequence[int]  # multiplicative_terms component column

    # --- feature matrix of shape (T, K) and its length-K prior scales ---
    X: Sequence[Sequence[float]]
    sigmas: Sequence[float]
@dataclass
class ModelParams:
    """Container for a set of model parameter values (e.g. the initial values for fitting).

    Field order is part of the contract: it fixes both the positional
    constructor signature and the key order of ``dataclasses.asdict``.
    """

    # Scalar trend parameters.
    k: float
    m: float

    # One entry per changepoint (length S).
    delta: Sequence[float]

    # One entry per regressor column (length K).
    beta: Sequence[float]

    # Observation noise scale.
    sigma_obs: float
class IStanBackend(ABC):
def __init__(self):
self.model = self.load_model()