Fix copy with extra seasonalities / regressors Py

This commit is contained in:
bl 2017-11-04 18:07:24 -07:00
parent 9ffa0d8790
commit 5dbffbaa18
2 changed files with 35 additions and 10 deletions

View file

@ -11,6 +11,7 @@ from __future__ import print_function
from __future__ import unicode_literals
from collections import defaultdict
from copy import deepcopy
from datetime import timedelta
import logging
@ -1515,6 +1516,9 @@ class Prophet(object):
-------
Prophet class object with the same parameter with model variable
"""
if self.history is None:
raise Exception('This is for copying a fitted Prophet object.')
if self.specified_changepoints:
changepoints = self.changepoints
if cutoff is not None:
@ -1523,18 +1527,23 @@ class Prophet(object):
else:
changepoints = None
return Prophet(
# Auto seasonalities are set to False because they are already set in
# self.seasonalities.
m = Prophet(
growth=self.growth,
n_changepoints=self.n_changepoints,
changepoints=changepoints,
yearly_seasonality=self.yearly_seasonality,
weekly_seasonality=self.weekly_seasonality,
daily_seasonality=self.daily_seasonality,
yearly_seasonality=False,
weekly_seasonality=False,
daily_seasonality=False,
holidays=self.holidays,
seasonality_prior_scale=self.seasonality_prior_scale,
changepoint_prior_scale=self.changepoint_prior_scale,
holidays_prior_scale=self.holidays_prior_scale,
mcmc_samples=self.mcmc_samples,
interval_width=self.interval_width,
uncertainty_samples=self.uncertainty_samples
uncertainty_samples=self.uncertainty_samples,
)
m.extra_regressors = deepcopy(self.extra_regressors)
m.seasonalities = deepcopy(self.seasonalities)
return m

View file

@ -555,6 +555,9 @@ class TestProphet(TestCase):
m.fit(df.copy())
def test_copy(self):
df = DATA.copy()
df['cap'] = 200.
df['binary_feature'] = [0] * 255 + [1] * 255
# These values are created except for its default values
holiday = pd.DataFrame(
{'ds': pd.to_datetime(['2016-12-25']), 'holiday': ['x']})
@ -576,13 +579,22 @@ class TestProphet(TestCase):
# Values should be copied correctly
for product in products:
m1 = Prophet(*product)
m1.history = m1.setup_dataframe(
df.copy(), initialize_scales=True)
m1.set_auto_seasonalities()
m2 = m1.copy()
self.assertEqual(m1.growth, m2.growth)
self.assertEqual(m1.n_changepoints, m2.n_changepoints)
self.assertEqual(m1.changepoints, m2.changepoints)
self.assertEqual(m1.yearly_seasonality, m2.yearly_seasonality)
self.assertEqual(m1.weekly_seasonality, m2.weekly_seasonality)
self.assertEqual(m1.daily_seasonality, m2.daily_seasonality)
self.assertEqual(False, m2.yearly_seasonality)
self.assertEqual(False, m2.weekly_seasonality)
self.assertEqual(False, m2.daily_seasonality)
self.assertEqual(
m1.yearly_seasonality, 'yearly' in m2.seasonalities)
self.assertEqual(
m1.weekly_seasonality, 'weekly' in m2.seasonalities)
self.assertEqual(
m1.daily_seasonality, 'daily' in m2.seasonalities)
if m1.holidays is None:
self.assertEqual(m1.holidays, m2.holidays)
else:
@ -594,11 +606,15 @@ class TestProphet(TestCase):
self.assertEqual(m1.interval_width, m2.interval_width)
self.assertEqual(m1.uncertainty_samples, m2.uncertainty_samples)
# Check for cutoff
# Check for cutoff and custom seasonality and extra regressors
changepoints = pd.date_range('2012-06-15', '2012-09-15')
cutoff = pd.Timestamp('2012-07-25')
m1 = Prophet(changepoints=changepoints)
m1.fit(DATA)
m1.add_seasonality('custom', 10, 5)
m1.add_regressor('binary_feature')
m1.fit(df)
m2 = m1.copy(cutoff=cutoff)
changepoints = changepoints[changepoints <= cutoff]
self.assertTrue((changepoints == m2.changepoints).all())
self.assertTrue('custom' in m2.seasonalities)
self.assertTrue('binary_feature' in m2.extra_regressors)