mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-05-18 21:21:22 +00:00
Lint fixes
This commit is contained in:
parent
cc3238acb7
commit
4523315ffc
2 changed files with 41 additions and 38 deletions
|
|
@ -66,7 +66,7 @@ class Prophet(object):
|
|||
fluctuations, smaller values dampen the seasonality. Can be specified
|
||||
for individual seasonalities using add_seasonality.
|
||||
holidays_prior_scale: Parameter modulating the strength of the holiday
|
||||
components model, unless overriden in the holidays input.
|
||||
components model, unless overridden in the holidays input.
|
||||
changepoint_prior_scale: Parameter modulating the flexibility of the
|
||||
automatic changepoint selection. Large values will allow many
|
||||
changepoints, small values will allow few changepoints.
|
||||
|
|
@ -115,8 +115,8 @@ class Prophet(object):
|
|||
if holidays is not None:
|
||||
if not (
|
||||
isinstance(holidays, pd.DataFrame)
|
||||
and 'ds' in holidays
|
||||
and 'holiday' in holidays
|
||||
and 'ds' in holidays # noqa W503
|
||||
and 'holiday' in holidays # noqa W503
|
||||
):
|
||||
raise ValueError("holidays must be a DataFrame with 'ds' and "
|
||||
"'holiday' columns.")
|
||||
|
|
@ -232,32 +232,7 @@ class Prophet(object):
|
|||
df = df.sort_values('ds')
|
||||
df.reset_index(inplace=True, drop=True)
|
||||
|
||||
if initialize_scales:
|
||||
if self.growth == 'logistic' and 'floor' in df:
|
||||
self.logistic_floor = True
|
||||
floor = df['floor']
|
||||
else:
|
||||
floor = 0.
|
||||
self.y_scale = (df['y'] - floor).abs().max()
|
||||
if self.y_scale == 0:
|
||||
self.y_scale = 1
|
||||
self.start = df['ds'].min()
|
||||
self.t_scale = df['ds'].max() - self.start
|
||||
for name, props in self.extra_regressors.items():
|
||||
standardize = props['standardize']
|
||||
if standardize == 'auto':
|
||||
if set(df[name].unique()) == set([1, 0]):
|
||||
# Don't standardize binary variables.
|
||||
standardize = False
|
||||
else:
|
||||
standardize = True
|
||||
if standardize:
|
||||
mu = df[name].mean()
|
||||
std = df[name].std()
|
||||
if std == 0:
|
||||
std = mu
|
||||
self.extra_regressors[name]['mu'] = mu
|
||||
self.extra_regressors[name]['std'] = std
|
||||
self.initialize_scales(initialize_scales, df)
|
||||
|
||||
if self.logistic_floor:
|
||||
if 'floor' not in df:
|
||||
|
|
@ -279,6 +254,35 @@ class Prophet(object):
|
|||
raise ValueError('Found NaN in column ' + name)
|
||||
return df
|
||||
|
||||
def initialize_scales(self, initialize_scales, df):
|
||||
if not initialize_scales:
|
||||
return
|
||||
if self.growth == 'logistic' and 'floor' in df:
|
||||
self.logistic_floor = True
|
||||
floor = df['floor']
|
||||
else:
|
||||
floor = 0.
|
||||
self.y_scale = (df['y'] - floor).abs().max()
|
||||
if self.y_scale == 0:
|
||||
self.y_scale = 1
|
||||
self.start = df['ds'].min()
|
||||
self.t_scale = df['ds'].max() - self.start
|
||||
for name, props in self.extra_regressors.items():
|
||||
standardize = props['standardize']
|
||||
if standardize == 'auto':
|
||||
if set(df[name].unique()) == set([1, 0]):
|
||||
# Don't standardize binary variables.
|
||||
standardize = False
|
||||
else:
|
||||
standardize = True
|
||||
if standardize:
|
||||
mu = df[name].mean()
|
||||
std = df[name].std()
|
||||
if std == 0:
|
||||
std = mu
|
||||
self.extra_regressors[name]['mu'] = mu
|
||||
self.extra_regressors[name]['std'] = std
|
||||
|
||||
def set_changepoints(self):
|
||||
"""Set changepoints
|
||||
|
||||
|
|
@ -422,7 +426,7 @@ class Prophet(object):
|
|||
if ps <= 0:
|
||||
raise ValueError('Prior scale must be > 0')
|
||||
prior_scales[row.holiday] = ps
|
||||
|
||||
|
||||
for offset in range(lw, uw + 1):
|
||||
occurrence = dt + timedelta(days=offset)
|
||||
try:
|
||||
|
|
@ -918,7 +922,7 @@ class Prophet(object):
|
|||
for i, t_s in enumerate(changepoint_ts):
|
||||
gammas[i] = (
|
||||
(t_s - m - np.sum(gammas))
|
||||
* (1 - k_cum[i] / k_cum[i + 1])
|
||||
* (1 - k_cum[i] / k_cum[i + 1]) # noqa W503
|
||||
)
|
||||
# Get cumulative rate and offset at each t
|
||||
k_t = k * np.ones_like(t)
|
||||
|
|
@ -997,7 +1001,7 @@ class Prophet(object):
|
|||
comp_features = X[:, cols]
|
||||
comp = (
|
||||
np.matmul(comp_features, comp_beta.transpose())
|
||||
* self.y_scale
|
||||
* self.y_scale # noqa W503
|
||||
)
|
||||
data[component] = np.nanmean(comp, axis=1)
|
||||
data[component + '_lower'] = np.nanpercentile(comp, lower_p,
|
||||
|
|
@ -1025,7 +1029,6 @@ class Prophet(object):
|
|||
components = components.append(new_comp)
|
||||
return components
|
||||
|
||||
|
||||
def sample_posterior_predictive(self, df):
|
||||
"""Prophet posterior predictive samples.
|
||||
|
||||
|
|
@ -1237,7 +1240,7 @@ class Prophet(object):
|
|||
ax.plot(fcst['ds'].values, fcst['yhat'], ls='-', c='#0072B2')
|
||||
if 'cap' in fcst and plot_cap:
|
||||
ax.plot(fcst['ds'].values, fcst['cap'], ls='--', c='k')
|
||||
if self.logistic_floor and 'floor' in fcst and plot_cap :
|
||||
if self.logistic_floor and 'floor' in fcst and plot_cap:
|
||||
ax.plot(fcst['ds'].values, fcst['floor'], ls='--', c='k')
|
||||
if uncertainty:
|
||||
ax.fill_between(fcst['ds'].values, fcst['yhat_lower'],
|
||||
|
|
@ -1333,7 +1336,7 @@ class Prophet(object):
|
|||
artists += ax.plot(fcst['ds'].values, fcst[name], ls='-', c='#0072B2')
|
||||
if 'cap' in fcst and plot_cap:
|
||||
artists += ax.plot(fcst['ds'].values, fcst['cap'], ls='--', c='k')
|
||||
if self.logistic_floor and 'floor' in fcst and plot_cap :
|
||||
if self.logistic_floor and 'floor' in fcst and plot_cap:
|
||||
ax.plot(fcst['ds'].values, fcst['floor'], ls='--', c='k')
|
||||
if uncertainty:
|
||||
artists += [ax.fill_between(
|
||||
|
|
|
|||
|
|
@ -521,15 +521,15 @@ class TestProphet(TestCase):
|
|||
fcst['extra_regressors'][0],
|
||||
fcst['numeric_feature'][0] + fcst['binary_feature2'][0],
|
||||
)
|
||||
self.assertEqual(
|
||||
self.assertAlmostEqual(
|
||||
fcst['seasonalities'][0],
|
||||
fcst['yearly'][0] + fcst['weekly'][0],
|
||||
)
|
||||
self.assertEqual(
|
||||
self.assertAlmostEqual(
|
||||
fcst['seasonal'][0],
|
||||
fcst['seasonalities'][0] + fcst['extra_regressors'][0],
|
||||
)
|
||||
self.assertEqual(
|
||||
self.assertAlmostEqual(
|
||||
fcst['yhat'][0],
|
||||
fcst['trend'][0] + fcst['seasonal'][0],
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue