mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-05-18 21:21:22 +00:00
Fix type casting for extra regressors, and a shape issue
This commit is contained in:
parent
91917df8f0
commit
13d96cff8f
2 changed files with 17 additions and 22 deletions
|
|
@ -357,6 +357,10 @@ setup_dataframe <- function(m, df, initialize_scales = FALSE) {
|
|||
if (!(name %in% colnames(df))) {
|
||||
stop('Regressor "', name, '" missing from dataframe')
|
||||
}
|
||||
df[[name]] <- as.numeric(df[[name]])
|
||||
if (anyNA(df[[name]])) {
|
||||
stop('Found NaN in column ', name)
|
||||
}
|
||||
}
|
||||
|
||||
df <- df %>%
|
||||
|
|
@ -386,12 +390,8 @@ setup_dataframe <- function(m, df, initialize_scales = FALSE) {
|
|||
}
|
||||
|
||||
for (name in names(m$extra_regressors)) {
|
||||
df[[name]] <- as.numeric(df[[name]])
|
||||
props <- m$extra_regressors[[name]]
|
||||
df[[name]] <- (df[[name]] - props$mu) / props$std
|
||||
if (anyNA(df[[name]])) {
|
||||
stop('Found NaN in column ', name)
|
||||
}
|
||||
}
|
||||
return(list("m" = m, "df" = df))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from datetime import timedelta
|
|||
import logging
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pystan # noqa F401
|
||||
|
||||
|
||||
from fbprophet.diagnostics import prophet_copy
|
||||
|
|
@ -33,11 +34,6 @@ from fbprophet.plot import (
|
|||
logging.basicConfig()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import pystan # noqa F401
|
||||
except ImportError:
|
||||
logger.exception('You cannot run fbprophet without pystan installed')
|
||||
|
||||
|
||||
class Prophet(object):
|
||||
"""Prophet forecaster.
|
||||
|
|
@ -160,13 +156,13 @@ class Prophet(object):
|
|||
raise ValueError("Parameter 'changepoint_range' must be in [0, 1]")
|
||||
if self.holidays is not None:
|
||||
if not (
|
||||
isinstance(holidays, pd.DataFrame)
|
||||
and 'ds' in holidays # noqa W503
|
||||
and 'holiday' in holidays # noqa W503
|
||||
isinstance(self.holidays, pd.DataFrame)
|
||||
and 'ds' in self.holidays # noqa W503
|
||||
and 'holiday' in self.holidays # noqa W503
|
||||
):
|
||||
raise ValueError("holidays must be a DataFrame with 'ds' and "
|
||||
"'holiday' columns.")
|
||||
holidays['ds'] = pd.to_datetime(holidays['ds'])
|
||||
self.holidays['ds'] = pd.to_datetime(self.holidays['ds'])
|
||||
has_lower = 'lower_window' in self.holidays
|
||||
has_upper = 'upper_window' in self.holidays
|
||||
if has_lower + has_upper == 1:
|
||||
|
|
@ -253,6 +249,9 @@ class Prophet(object):
|
|||
if name not in df:
|
||||
raise ValueError(
|
||||
'Regressor "{}" missing from dataframe'.format(name))
|
||||
df[name] = pd.to_numeric(df[name])
|
||||
if df[name].isnull().any():
|
||||
raise ValueError('Found NaN in column ' + name)
|
||||
|
||||
df = df.sort_values('ds')
|
||||
df.reset_index(inplace=True, drop=True)
|
||||
|
|
@ -277,10 +276,7 @@ class Prophet(object):
|
|||
df['y_scaled'] = (df['y'] - df['floor']) / self.y_scale
|
||||
|
||||
for name, props in self.extra_regressors.items():
|
||||
df[name] = pd.to_numeric(df[name])
|
||||
df[name] = ((df[name] - props['mu']) / props['std'])
|
||||
if df[name].isnull().any():
|
||||
raise ValueError('Found NaN in column ' + name)
|
||||
return df
|
||||
|
||||
def initialize_scales(self, initialize_scales, df):
|
||||
|
|
@ -1092,6 +1088,9 @@ class Prophet(object):
|
|||
)
|
||||
for par in stan_fit.model_pars:
|
||||
self.params[par] = stan_fit[par]
|
||||
# Shape vector parameters
|
||||
if par in ['delta', 'beta'] and len(self.params[par].shape) < 2:
|
||||
self.params[par] = self.params[par].reshape((-1, 1))
|
||||
else:
|
||||
try:
|
||||
params = model.optimizing(
|
||||
|
|
@ -1102,12 +1101,8 @@ class Prophet(object):
|
|||
dat, init=stan_init, iter=1e4, algorithm='Newton',
|
||||
**kwargs
|
||||
)
|
||||
self.params = params
|
||||
|
||||
# Reshape all parameters to matrix shape
|
||||
for p, v in self.params.items():
|
||||
if len(v.shape) == 1:
|
||||
self.params[p] = v.reshape((-1, 1))
|
||||
for par in params:
|
||||
self.params[par] = params[par].reshape((1, -1))
|
||||
|
||||
# If no changepoints were requested, replace delta with 0s
|
||||
if len(self.changepoints) == 0:
|
||||
|
|
|
|||
Loading…
Reference in a new issue