mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-05-14 20:48:08 +00:00
Separate preprocessing step from fit() method for easier debugging (#2505)
This commit is contained in:
parent
279d8d6a30
commit
7b062be8fa
2 changed files with 109 additions and 56 deletions
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
from collections import OrderedDict, defaultdict
|
||||
from copy import deepcopy
|
||||
|
|
@ -17,7 +18,7 @@ import pandas as pd
|
|||
from numpy.typing import NDArray
|
||||
|
||||
from prophet.make_holidays import get_holiday_names, make_holidays_df
|
||||
from prophet.models import StanBackendEnum
|
||||
from prophet.models import StanBackendEnum, ModelInputData, ModelParams, TrendIndicator
|
||||
from prophet.plot import (plot, plot_components)
|
||||
|
||||
logger = logging.getLogger('prophet')
|
||||
|
|
@ -1114,6 +1115,76 @@ class Prophet(object):
|
|||
m = df['y_scaled'].mean()
|
||||
return k, m
|
||||
|
||||
def preprocess(self, df: pd.DataFrame, **kwargs) -> ModelInputData:
    """
    Reformats historical data, standardizes y and extra regressors, sets seasonalities and changepoints.

    Saves the preprocessed data to the instantiated object, and also returns the relevant components
    as a ModelInputData object.

    Parameters
    ----------
    df: pd.DataFrame containing at least the columns 'ds' (dates) and 'y' (values).
        Rows whose 'y' is null are left out of the fitted history.
    **kwargs: Additional fitting arguments; a deep copy is stored on self.fit_kwargs.

    Returns
    -------
    ModelInputData built from the prepared history, the seasonality features and
    the changepoints.

    Raises
    ------
    ValueError: if 'ds' or 'y' is missing from df, or if fewer than 2 rows have a
        non-null 'y'.

    Side effects
    ------------
    Sets self.history_dates, self.history, self.train_component_cols,
    self.component_modes and self.fit_kwargs, and calls
    self.set_auto_seasonalities() and self.set_changepoints().
    """
    if ('ds' not in df) or ('y' not in df):
        raise ValueError(
            'Dataframe must have columns "ds" and "y" with the dates and '
            'values respectively.'
        )
    history = df[df['y'].notnull()].copy()
    if history.shape[0] < 2:
        raise ValueError('Dataframe has less than 2 non-NaN rows.')
    # Use the dates of the *unfiltered* input frame, exactly as fit() did before
    # preprocessing was split out: dates whose 'y' is null are still part of the
    # observed date range even though they are excluded from the fitted history.
    self.history_dates = pd.to_datetime(pd.Series(df['ds'].unique(), name='ds')).sort_values()

    self.history = self.setup_dataframe(history, initialize_scales=True)
    self.set_auto_seasonalities()
    seasonal_features, prior_scales, component_cols, modes = (
        self.make_all_seasonality_features(self.history))
    self.train_component_cols = component_cols
    self.component_modes = modes
    # Deep copy so later mutation of the caller's kwargs cannot alter the stored record.
    self.fit_kwargs = deepcopy(kwargs)

    self.set_changepoints()

    if self.growth in ['linear', 'flat']:
        # These growth types carry no capacity column; pass a zero placeholder of length T.
        cap = np.zeros(self.history.shape[0])
    else:
        cap = self.history['cap_scaled']

    return ModelInputData(
        T=self.history.shape[0],
        S=len(self.changepoints_t),
        K=seasonal_features.shape[1],
        tau=self.changepoint_prior_scale,
        # Enum members are named after the upper-cased growth string.
        trend_indicator=TrendIndicator[self.growth.upper()].value,
        y=self.history['y_scaled'],
        t=self.history['t'],
        t_change=self.changepoints_t,
        X=seasonal_features,
        sigmas=prior_scales,
        s_a=component_cols['additive_terms'],
        s_m=component_cols['multiplicative_terms'],
        cap=cap,
    )
|
||||
|
||||
def calculate_initial_params(self, num_total_regressors: int) -> ModelParams:
    """
    Calculates initial parameters for the model based on the preprocessed history.

    Reads self.growth, self.history and self.changepoints_t, so preprocessing must
    already have populated them.

    Parameters
    ----------
    num_total_regressors: the count of seasonality fourier components plus holidays plus extra regressors.

    Returns
    -------
    ModelParams with k and m taken from the growth-specific initialiser, zero-filled
    delta (one entry per changepoint) and beta (one entry per regressor), and
    sigma_obs set to 1.0.

    Raises
    ------
    ValueError: if self.growth is not one of 'linear', 'flat' or 'logistic'.
    """
    if self.growth == 'linear':
        k, m = self.linear_growth_init(self.history)
    elif self.growth == 'flat':
        k, m = self.flat_growth_init(self.history)
    elif self.growth == 'logistic':
        k, m = self.logistic_growth_init(self.history)
    else:
        # Without this branch an unknown growth would surface later as an
        # UnboundLocalError on `k`, which hides the real cause.
        raise ValueError(
            'Cannot compute initial parameters for growth "{}".'.format(self.growth)
        )
    return ModelParams(
        k=k,
        m=m,
        # Force float zeros: zeros_like would otherwise inherit the dtype of
        # changepoints_t, while delta is declared as a sequence of floats.
        delta=np.zeros_like(self.changepoints_t, dtype=float),
        beta=np.zeros(num_total_regressors),
        sigma_obs=1.0,
    )
|
||||
|
||||
def fit(self, df, **kwargs):
|
||||
"""Fit the Prophet model.
|
||||
|
||||
|
|
@ -1142,63 +1213,14 @@ class Prophet(object):
|
|||
if self.history is not None:
|
||||
raise Exception('Prophet object can only be fit once. '
|
||||
'Instantiate a new object.')
|
||||
if ('ds' not in df) or ('y' not in df):
|
||||
raise ValueError(
|
||||
'Dataframe must have columns "ds" and "y" with the dates and '
|
||||
'values respectively.'
|
||||
)
|
||||
history = df[df['y'].notnull()].copy()
|
||||
if history.shape[0] < 2:
|
||||
raise ValueError('Dataframe has less than 2 non-NaN rows.')
|
||||
self.history_dates = pd.to_datetime(pd.Series(df['ds'].unique(), name='ds')).sort_values()
|
||||
|
||||
history = self.setup_dataframe(history, initialize_scales=True)
|
||||
self.history = history
|
||||
self.set_auto_seasonalities()
|
||||
seasonal_features, prior_scales, component_cols, modes = (
|
||||
self.make_all_seasonality_features(history))
|
||||
self.train_component_cols = component_cols
|
||||
self.component_modes = modes
|
||||
self.fit_kwargs = deepcopy(kwargs)
|
||||
model_inputs = self.preprocess(df, **kwargs)
|
||||
initial_params = self.calculate_initial_params(model_inputs.K)
|
||||
|
||||
self.set_changepoints()
|
||||
dat = dataclasses.asdict(model_inputs)
|
||||
stan_init = dataclasses.asdict(initial_params)
|
||||
|
||||
trend_indicator = {'linear': 0, 'logistic': 1, 'flat': 2}
|
||||
|
||||
dat = {
|
||||
'T': history.shape[0],
|
||||
'K': seasonal_features.shape[1],
|
||||
'S': len(self.changepoints_t),
|
||||
'y': history['y_scaled'],
|
||||
't': history['t'],
|
||||
't_change': self.changepoints_t,
|
||||
'X': seasonal_features,
|
||||
'sigmas': prior_scales,
|
||||
'tau': self.changepoint_prior_scale,
|
||||
'trend_indicator': trend_indicator[self.growth],
|
||||
's_a': component_cols['additive_terms'],
|
||||
's_m': component_cols['multiplicative_terms'],
|
||||
}
|
||||
|
||||
if self.growth == 'linear':
|
||||
dat['cap'] = np.zeros(self.history.shape[0])
|
||||
kinit = self.linear_growth_init(history)
|
||||
elif self.growth == 'flat':
|
||||
dat['cap'] = np.zeros(self.history.shape[0])
|
||||
kinit = self.flat_growth_init(history)
|
||||
else:
|
||||
dat['cap'] = history['cap_scaled']
|
||||
kinit = self.logistic_growth_init(history)
|
||||
|
||||
stan_init = {
|
||||
'k': kinit[0],
|
||||
'm': kinit[1],
|
||||
'delta': np.zeros(len(self.changepoints_t)),
|
||||
'beta': np.zeros(seasonal_features.shape[1]),
|
||||
'sigma_obs': 1,
|
||||
}
|
||||
|
||||
if history['y'].min() == history['y'].max() and \
|
||||
if self.history['y'].min() == self.history['y'].max() and \
|
||||
(self.growth == 'linear' or self.growth == 'flat'):
|
||||
self.params = stan_init
|
||||
self.params['sigma_obs'] = 1e-9
|
||||
|
|
|
|||
|
|
@ -6,7 +6,8 @@
|
|||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
from abc import abstractmethod, ABC
|
||||
from typing import Tuple
|
||||
from dataclasses import dataclass
|
||||
from typing import Sequence, Tuple
|
||||
from collections import OrderedDict
|
||||
from enum import Enum
|
||||
import importlib_resources
|
||||
|
|
@ -17,6 +18,36 @@ logger = logging.getLogger('prophet.models')
|
|||
|
||||
PLATFORM = "win" if platform.platform().startswith("Win") else "unix"
|
||||
|
||||
class TrendIndicator(Enum):
    """Integer codes identifying the trend (growth) type.

    Prophet.preprocess looks a member up by the upper-cased growth string
    (e.g. 'linear' -> LINEAR) and stores its ``.value`` in
    ModelInputData.trend_indicator.
    """
    LINEAR = 0
    LOGISTIC = 1
    FLAT = 2
|
||||
|
||||
@dataclass
class ModelInputData:
    """Data inputs for the model, as assembled by Prophet.preprocess.

    Prophet.fit converts an instance to a plain dict with dataclasses.asdict.
    The per-field notes below describe how Prophet.preprocess fills each field.
    """
    T: int  # number of rows in the preprocessed history
    S: int  # number of changepoints (len of t_change)
    K: int  # number of columns in the seasonal feature matrix X
    tau: float  # the model's changepoint_prior_scale
    trend_indicator: int  # a TrendIndicator value
    y: Sequence[float]  # length T; the history's 'y_scaled' column
    t: Sequence[float]  # length T; the history's 't' column
    cap: Sequence[float]  # length T; zeros for linear/flat growth, else 'cap_scaled'
    t_change: Sequence[float]  # length S; the model's changepoints_t
    s_a: Sequence[int]  # length K; the 'additive_terms' component column
    s_m: Sequence[int]  # length K; the 'multiplicative_terms' component column
    X: Sequence[Sequence[float]]  # shape (T, K); seasonal feature matrix
    sigmas: Sequence[float]  # length K; prior scales returned with the seasonal features
|
||||
|
||||
@dataclass
class ModelParams:
    """Initial model parameter values, as produced by Prophet.calculate_initial_params.

    Prophet.fit converts an instance to a plain dict with dataclasses.asdict.
    """
    k: float  # first value returned by the growth-specific init function
    m: float  # second value returned by the growth-specific init function
    delta: Sequence[float]  # length S; initialised to zeros
    beta: Sequence[float]  # length K; initialised to zeros
    sigma_obs: float  # initialised to 1.0
|
||||
|
||||
|
||||
class IStanBackend(ABC):
|
||||
def __init__(self):
|
||||
self.model = self.load_model()
|
||||
|
|
|
|||
Loading…
Reference in a new issue