mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-05-19 21:32:29 +00:00
Fix copy with extra seasonalities / regressors Py
This commit is contained in:
parent
9ffa0d8790
commit
5dbffbaa18
2 changed files with 35 additions and 10 deletions
|
|
@ -11,6 +11,7 @@ from __future__ import print_function
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
|
||||
|
|
@ -1515,6 +1516,9 @@ class Prophet(object):
|
|||
-------
|
||||
Prophet class object with the same parameter with model variable
|
||||
"""
|
||||
if self.history is None:
|
||||
raise Exception('This is for copying a fitted Prophet object.')
|
||||
|
||||
if self.specified_changepoints:
|
||||
changepoints = self.changepoints
|
||||
if cutoff is not None:
|
||||
|
|
@ -1523,18 +1527,23 @@ class Prophet(object):
|
|||
else:
|
||||
changepoints = None
|
||||
|
||||
return Prophet(
|
||||
# Auto seasonalities are set to False because they are already set in
|
||||
# self.seasonalities.
|
||||
m = Prophet(
|
||||
growth=self.growth,
|
||||
n_changepoints=self.n_changepoints,
|
||||
changepoints=changepoints,
|
||||
yearly_seasonality=self.yearly_seasonality,
|
||||
weekly_seasonality=self.weekly_seasonality,
|
||||
daily_seasonality=self.daily_seasonality,
|
||||
yearly_seasonality=False,
|
||||
weekly_seasonality=False,
|
||||
daily_seasonality=False,
|
||||
holidays=self.holidays,
|
||||
seasonality_prior_scale=self.seasonality_prior_scale,
|
||||
changepoint_prior_scale=self.changepoint_prior_scale,
|
||||
holidays_prior_scale=self.holidays_prior_scale,
|
||||
mcmc_samples=self.mcmc_samples,
|
||||
interval_width=self.interval_width,
|
||||
uncertainty_samples=self.uncertainty_samples
|
||||
uncertainty_samples=self.uncertainty_samples,
|
||||
)
|
||||
m.extra_regressors = deepcopy(self.extra_regressors)
|
||||
m.seasonalities = deepcopy(self.seasonalities)
|
||||
return m
|
||||
|
|
|
|||
|
|
@ -555,6 +555,9 @@ class TestProphet(TestCase):
|
|||
m.fit(df.copy())
|
||||
|
||||
def test_copy(self):
|
||||
df = DATA.copy()
|
||||
df['cap'] = 200.
|
||||
df['binary_feature'] = [0] * 255 + [1] * 255
|
||||
# These values are created except for its default values
|
||||
holiday = pd.DataFrame(
|
||||
{'ds': pd.to_datetime(['2016-12-25']), 'holiday': ['x']})
|
||||
|
|
@ -576,13 +579,22 @@ class TestProphet(TestCase):
|
|||
# Values should be copied correctly
|
||||
for product in products:
|
||||
m1 = Prophet(*product)
|
||||
m1.history = m1.setup_dataframe(
|
||||
df.copy(), initialize_scales=True)
|
||||
m1.set_auto_seasonalities()
|
||||
m2 = m1.copy()
|
||||
self.assertEqual(m1.growth, m2.growth)
|
||||
self.assertEqual(m1.n_changepoints, m2.n_changepoints)
|
||||
self.assertEqual(m1.changepoints, m2.changepoints)
|
||||
self.assertEqual(m1.yearly_seasonality, m2.yearly_seasonality)
|
||||
self.assertEqual(m1.weekly_seasonality, m2.weekly_seasonality)
|
||||
self.assertEqual(m1.daily_seasonality, m2.daily_seasonality)
|
||||
self.assertEqual(False, m2.yearly_seasonality)
|
||||
self.assertEqual(False, m2.weekly_seasonality)
|
||||
self.assertEqual(False, m2.daily_seasonality)
|
||||
self.assertEqual(
|
||||
m1.yearly_seasonality, 'yearly' in m2.seasonalities)
|
||||
self.assertEqual(
|
||||
m1.weekly_seasonality, 'weekly' in m2.seasonalities)
|
||||
self.assertEqual(
|
||||
m1.daily_seasonality, 'daily' in m2.seasonalities)
|
||||
if m1.holidays is None:
|
||||
self.assertEqual(m1.holidays, m2.holidays)
|
||||
else:
|
||||
|
|
@ -594,11 +606,15 @@ class TestProphet(TestCase):
|
|||
self.assertEqual(m1.interval_width, m2.interval_width)
|
||||
self.assertEqual(m1.uncertainty_samples, m2.uncertainty_samples)
|
||||
|
||||
# Check for cutoff
|
||||
# Check for cutoff and custom seasonality and extra regressors
|
||||
changepoints = pd.date_range('2012-06-15', '2012-09-15')
|
||||
cutoff = pd.Timestamp('2012-07-25')
|
||||
m1 = Prophet(changepoints=changepoints)
|
||||
m1.fit(DATA)
|
||||
m1.add_seasonality('custom', 10, 5)
|
||||
m1.add_regressor('binary_feature')
|
||||
m1.fit(df)
|
||||
m2 = m1.copy(cutoff=cutoff)
|
||||
changepoints = changepoints[changepoints <= cutoff]
|
||||
self.assertTrue((changepoints == m2.changepoints).all())
|
||||
self.assertTrue('custom' in m2.seasonalities)
|
||||
self.assertTrue('binary_feature' in m2.extra_regressors)
|
||||
|
|
|
|||
Loading…
Reference in a new issue