Show multiplicative seasonality as percent in plots (Py)

This commit is contained in:
Ben Letham 2018-05-11 17:54:29 -07:00
parent 8d8c5b41ce
commit f1e24d3c2c
3 changed files with 34 additions and 11 deletions

View file

@ -156,6 +156,7 @@ class Prophet(object):
self.history = None
self.history_dates = None
self.train_component_cols = None
self.component_modes = None
self.validate_inputs()
def validate_inputs(self):
@ -196,8 +197,8 @@ class Prophet(object):
raise ValueError('Name cannot contain "_delim_"')
reserved_names = [
'trend', 'additive_terms', 'daily', 'weekly', 'yearly',
'holidays', 'zeros', 'extra_regressors_additive',
'extra_regressors_multiplicative', 'yhat',
'holidays', 'zeros', 'extra_regressors_additive','yhat',
'extra_regressors_multiplicative', 'multiplicative_terms',
]
rn_l = [n + '_lower' for n in reserved_names]
rn_u = [n + '_upper' for n in reserved_names]
@ -686,6 +687,8 @@ class Prophet(object):
# Add combination components to modes
modes[mode].append(mode + '_terms')
modes[mode].append('extra_regressors_' + mode)
# After all of the additive/multiplicative groups have been added,
modes[self.seasonality_mode].append('holidays')
# Convert to a binary matrix
component_cols = pd.crosstab(
components['col'], components['component'],
@ -724,8 +727,10 @@ class Prophet(object):
Dataframe with components.
"""
new_comp = components[components['component'].isin(set(group))].copy()
new_comp['component'] = name
components = components.append(new_comp)
group_cols = new_comp['col'].unique()
if len(group_cols) > 0:
new_comp = pd.DataFrame({'component': name, 'col': group_cols})
components = components.append(new_comp)
return components
def parse_seasonality_args(self, name, arg, auto_disable, default_order):
@ -920,9 +925,10 @@ class Prophet(object):
history = self.setup_dataframe(history, initialize_scales=True)
self.history = history
self.set_auto_seasonalities()
seasonal_features, prior_scales, component_cols, _ = (
seasonal_features, prior_scales, component_cols, modes = (
self.make_all_seasonality_features(history))
self.train_component_cols = component_cols
self.component_modes = modes
self.set_changepoints()
@ -1131,7 +1137,7 @@ class Prophet(object):
-------
Dataframe with seasonal components.
"""
seasonal_features, _, component_cols, modes = (
seasonal_features, _, component_cols, _ = (
self.make_all_seasonality_features(df)
)
lower_p = 100 * (1.0 - self.interval_width) / 2
@ -1143,7 +1149,7 @@ class Prophet(object):
beta_c = self.params['beta'] * component_cols[component].values
comp = np.matmul(X, beta_c.transpose())
if component in modes['additive']:
if component in self.component_modes['additive']:
comp *= self.y_scale
data[component] = np.nanmean(comp, axis=1)
data[component + '_lower'] = np.nanpercentile(

View file

@ -186,6 +186,8 @@ def plot_forecast_component(
ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
ax.set_xlabel('ds')
ax.set_ylabel(name)
if name in m.component_modes['multiplicative']:
ax = set_y_as_percent(ax)
return artists
@ -246,7 +248,9 @@ def plot_weekly(m, ax=None, uncertainty=True, weekly_start=0):
ax.set_xticks(range(len(days)))
ax.set_xticklabels(days)
ax.set_xlabel('Day of week')
ax.set_ylabel('weekly ({})'.format(m.seasonalities['weekly']['mode']))
ax.set_ylabel('weekly')
if m.seasonalities['weekly']['mode'] == 'multiplicative':
ax = set_y_as_percent(ax)
return artists
@ -288,7 +292,9 @@ def plot_yearly(m, ax=None, uncertainty=True, yearly_start=0):
lambda x, pos=None: '{dt:%B} {dt.day}'.format(dt=num2date(x))))
ax.xaxis.set_major_locator(months)
ax.set_xlabel('Day of year')
ax.set_ylabel('yearly ({})'.format(m.seasonalities['yearly']['mode']))
ax.set_ylabel('yearly')
if m.seasonalities['yearly']['mode'] == 'multiplicative':
ax = set_y_as_percent(ax)
return artists
@ -338,10 +344,19 @@ def plot_seasonality(m, name, ax=None, uncertainty=True):
ax.xaxis.set_major_formatter(FuncFormatter(
lambda x, pos=None: fmt_str.format(dt=num2date(x))))
ax.set_xlabel('ds')
ax.set_ylabel('{} ({})'.format(name, m.seasonalities[name]['mode']))
ax.set_ylabel('{}'.format(name))
if m.seasonalities[name]['mode'] == 'multiplicative':
ax = set_y_as_percent(ax)
return artists
def set_y_as_percent(ax):
yticks = 100 * ax.get_yticks()
yticklabels = ['{0:.4g}%'.format(y) for y in yticks]
ax.set_yticklabels(yticklabels)
return ax
def add_changepoints_to_plot(
ax, m, fcst, threshold=0.01, cp_color='r', cp_linestyle='--', trend=True,
):

View file

@ -678,5 +678,7 @@ class TestProphet(TestCase):
self.assertEqual(
set(modes['multiplicative']),
{'weekly', 'yearly', 'xmas', 'numeric_feature',
'multiplicative_terms', 'extra_regressors_multiplicative'},
'multiplicative_terms', 'extra_regressors_multiplicative',
'holidays',
},
)