diff --git a/python/fbprophet/forecaster.py b/python/fbprophet/forecaster.py index abf9a37..aecf016 100644 --- a/python/fbprophet/forecaster.py +++ b/python/fbprophet/forecaster.py @@ -156,6 +156,7 @@ class Prophet(object): self.history = None self.history_dates = None self.train_component_cols = None + self.component_modes = None self.validate_inputs() def validate_inputs(self): @@ -196,8 +197,8 @@ class Prophet(object): raise ValueError('Name cannot contain "_delim_"') reserved_names = [ 'trend', 'additive_terms', 'daily', 'weekly', 'yearly', - 'holidays', 'zeros', 'extra_regressors_additive', - 'extra_regressors_multiplicative', 'yhat', + 'holidays', 'zeros', 'extra_regressors_additive','yhat', + 'extra_regressors_multiplicative', 'multiplicative_terms', ] rn_l = [n + '_lower' for n in reserved_names] rn_u = [n + '_upper' for n in reserved_names] @@ -686,6 +687,8 @@ class Prophet(object): # Add combination components to modes modes[mode].append(mode + '_terms') modes[mode].append('extra_regressors_' + mode) + # After all of the additive/multiplicative groups have been added, + modes[self.seasonality_mode].append('holidays') # Convert to a binary matrix component_cols = pd.crosstab( components['col'], components['component'], @@ -724,8 +727,10 @@ class Prophet(object): Dataframe with components. """ new_comp = components[components['component'].isin(set(group))].copy() - new_comp['component'] = name - components = components.append(new_comp) + group_cols = new_comp['col'].unique() + if len(group_cols) > 0: + new_comp = pd.DataFrame({'component': name, 'col': group_cols}) + components = components.append(new_comp) return components def parse_seasonality_args(self, name, arg, auto_disable, default_order): @@ -920,9 +925,10 @@ class Prophet(object): history = self.setup_dataframe(history, initialize_scales=True) self.history = history self.set_auto_seasonalities() - seasonal_features, prior_scales, component_cols, _ = ( + seasonal_features, prior_scales, component_cols, modes = ( self.make_all_seasonality_features(history)) self.train_component_cols = component_cols + self.component_modes = modes self.set_changepoints() @@ -1131,7 +1137,7 @@ class Prophet(object): ------- Dataframe with seasonal components. """ - seasonal_features, _, component_cols, modes = ( + seasonal_features, _, component_cols, _ = ( self.make_all_seasonality_features(df) ) lower_p = 100 * (1.0 - self.interval_width) / 2 @@ -1143,7 +1149,7 @@ class Prophet(object): beta_c = self.params['beta'] * component_cols[component].values comp = np.matmul(X, beta_c.transpose()) - if component in modes['additive']: + if component in self.component_modes['additive']: comp *= self.y_scale data[component] = np.nanmean(comp, axis=1) data[component + '_lower'] = np.nanpercentile( diff --git a/python/fbprophet/plot.py b/python/fbprophet/plot.py index e98fc3d..2b55723 100644 --- a/python/fbprophet/plot.py +++ b/python/fbprophet/plot.py @@ -186,6 +186,8 @@ def plot_forecast_component( ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2) ax.set_xlabel('ds') ax.set_ylabel(name) + if name in m.component_modes['multiplicative']: + ax = set_y_as_percent(ax) return artists @@ -246,7 +248,9 @@ def plot_weekly(m, ax=None, uncertainty=True, weekly_start=0): ax.set_xticks(range(len(days))) ax.set_xticklabels(days) ax.set_xlabel('Day of week') - ax.set_ylabel('weekly ({})'.format(m.seasonalities['weekly']['mode'])) + ax.set_ylabel('weekly') + if m.seasonalities['weekly']['mode'] == 'multiplicative': + ax = set_y_as_percent(ax) return artists @@ -288,7 +292,9 @@ def plot_yearly(m, ax=None, uncertainty=True, yearly_start=0): lambda x, pos=None: '{dt:%B} {dt.day}'.format(dt=num2date(x)))) ax.xaxis.set_major_locator(months) ax.set_xlabel('Day of year') - ax.set_ylabel('yearly ({})'.format(m.seasonalities['yearly']['mode'])) + ax.set_ylabel('yearly') + if m.seasonalities['yearly']['mode'] == 'multiplicative': + ax = set_y_as_percent(ax) return artists @@ -338,10 +344,19 @@ def plot_seasonality(m, name, ax=None, uncertainty=True): ax.xaxis.set_major_formatter(FuncFormatter( lambda x, pos=None: fmt_str.format(dt=num2date(x)))) ax.set_xlabel('ds') - ax.set_ylabel('{} ({})'.format(name, m.seasonalities[name]['mode'])) + ax.set_ylabel('{}'.format(name)) + if m.seasonalities[name]['mode'] == 'multiplicative': + ax = set_y_as_percent(ax) return artists +def set_y_as_percent(ax): + yticks = 100 * ax.get_yticks() + yticklabels = ['{0:.4g}%'.format(y) for y in yticks] + ax.set_yticklabels(yticklabels) + return ax + + def add_changepoints_to_plot( ax, m, fcst, threshold=0.01, cp_color='r', cp_linestyle='--', trend=True, ): diff --git a/python/fbprophet/tests/test_prophet.py b/python/fbprophet/tests/test_prophet.py index 993d6d5..579634f 100644 --- a/python/fbprophet/tests/test_prophet.py +++ b/python/fbprophet/tests/test_prophet.py @@ -678,5 +678,7 @@ class TestProphet(TestCase): self.assertEqual( set(modes['multiplicative']), {'weekly', 'yearly', 'xmas', 'numeric_feature', - 'multiplicative_terms', 'extra_regressors_multiplicative'}, + 'multiplicative_terms', 'extra_regressors_multiplicative', + 'holidays', + }, )