mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-06-08 00:23:09 +00:00
Generalize seasonality representation (Python)
This commit is contained in:
parent
825108b226
commit
b3017c025f
2 changed files with 88 additions and 71 deletions
|
|
@ -78,7 +78,6 @@ class Prophet(object):
|
|||
parameters, which will include uncertainty in seasonality.
|
||||
uncertainty_samples: Number of simulated draws used to estimate
|
||||
uncertainty intervals.
|
||||
daily_seasonality: Boolean, fit daily seasonality
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -96,7 +95,6 @@ class Prophet(object):
|
|||
mcmc_samples=0,
|
||||
interval_width=0.80,
|
||||
uncertainty_samples=1000,
|
||||
daily_seasonality=False,
|
||||
):
|
||||
self.growth = growth
|
||||
|
||||
|
|
@ -134,6 +132,7 @@ class Prophet(object):
|
|||
self.y_scale = None
|
||||
self.t_scale = None
|
||||
self.changepoints_t = None
|
||||
self.seasonalities = {}
|
||||
self.stan_fit = None
|
||||
self.params = {}
|
||||
self.history = None
|
||||
|
|
@ -358,81 +357,91 @@ class Prophet(object):
|
|||
# Add a column of zeros in case no seasonality is used.
|
||||
pd.DataFrame({'zeros': np.zeros(df.shape[0])})
|
||||
]
|
||||
|
||||
# Seasonality features
|
||||
if self.yearly_seasonality > 0:
|
||||
for name, (period, series_order) in self.seasonalities.items():
|
||||
seasonal_features.append(self.make_seasonality_features(
|
||||
df['ds'],
|
||||
365.25,
|
||||
self.yearly_seasonality,
|
||||
'yearly',
|
||||
))
|
||||
|
||||
if self.weekly_seasonality > 0:
|
||||
seasonal_features.append(self.make_seasonality_features(
|
||||
df['ds'],
|
||||
7,
|
||||
self.weekly_seasonality,
|
||||
'weekly',
|
||||
))
|
||||
|
||||
if self.daily_seasonality > 0:
|
||||
seasonal_features.append(self.make_seasonality_features(
|
||||
df['ds'],
|
||||
1,
|
||||
self.daily_seasonality,
|
||||
'daily',
|
||||
period,
|
||||
series_order,
|
||||
name,
|
||||
))
|
||||
|
||||
if self.holidays is not None:
|
||||
seasonal_features.append(self.make_holiday_features(df['ds']))
|
||||
return pd.concat(seasonal_features, axis=1)
|
||||
|
||||
def parse_seasonality_args(self, name, arg, auto_disable, default_order):
|
||||
"""Get number of fourier components for built-in seasonalities.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: string name of the seasonality component.
|
||||
arg: 'auto', True, False, or number of fourier components as provided.
|
||||
auto_disable: bool if seasonality should be disabled when 'auto'.
|
||||
default_order: int default fourier order
|
||||
|
||||
Returns
|
||||
-------
|
||||
Number of fourier components, or 0 for disabled.
|
||||
"""
|
||||
if arg == 'auto':
|
||||
fourier_order = 0
|
||||
if name in self.seasonalities:
|
||||
logger.info(
|
||||
'Found custom seasonality named "{name}", '
|
||||
'disabling built-in {name} seasonality.'.format(name=name)
|
||||
)
|
||||
elif auto_disable:
|
||||
logger.info(
|
||||
'Disabling {name} seasonality. Run prophet with '
|
||||
'{name}_seasonality=True to override this.'.format(
|
||||
name=name)
|
||||
)
|
||||
else:
|
||||
fourier_order = default_order
|
||||
elif arg is True:
|
||||
fourier_order = default_order
|
||||
elif arg is False:
|
||||
fourier_order = 0
|
||||
else:
|
||||
fourier_order = int(arg)
|
||||
return fourier_order
|
||||
|
||||
def set_auto_seasonalities(self):
|
||||
"""Set seasonalities that were left on auto.
|
||||
|
||||
Turns on yearly seasonality if there is >=2 years of history.
|
||||
Turns on weekly seasonality if there is >=2 weeks of history, and the
|
||||
spacing between dates in the history is <7 days.
|
||||
Turns on daily seasonality if there is >=2 days of history, and the
|
||||
spacing between dates in the history is <1 day.
|
||||
"""
|
||||
first = self.history['ds'].min()
|
||||
last = self.history['ds'].max()
|
||||
if self.yearly_seasonality == 'auto':
|
||||
if last - first < pd.Timedelta(days=730):
|
||||
self.yearly_seasonality = 0
|
||||
logger.info('Disabling yearly seasonality. Run prophet with '
|
||||
'yearly_seasonality=True to override this.')
|
||||
else:
|
||||
self.yearly_seasonality = 10
|
||||
elif self.yearly_seasonality is True:
|
||||
self.yearly_seasonality = 10
|
||||
|
||||
if self.weekly_seasonality == 'auto':
|
||||
dt = self.history['ds'].diff()
|
||||
min_dt = dt.iloc[dt.nonzero()[0]].min()
|
||||
if ((last - first < pd.Timedelta(weeks=2)) or
|
||||
(min_dt >= pd.Timedelta(weeks=1))):
|
||||
self.weekly_seasonality = 0
|
||||
logger.info('Disabling weekly seasonality. Run prophet with '
|
||||
'weekly_seasonality=True to override this.')
|
||||
else:
|
||||
self.weekly_seasonality = 3
|
||||
elif self.weekly_seasonality is True:
|
||||
self.weekly_seasonality = 3
|
||||
|
||||
if self.daily_seasonality == 'auto':
|
||||
# disabled by default but if the average difference is <1 day
|
||||
# then we assume there is intra-day modeling
|
||||
dt = self.history['ds'].diff()
|
||||
min_dt = dt.iloc[dt.nonzero()[0]].min()
|
||||
if (min_dt< pd.Timedelta(days=1)):
|
||||
self.daily_seasonality = 4
|
||||
logger.info('Enabling daily seasonality. Run prophet with '
|
||||
'daily_seasonality=False to override this.')
|
||||
else:
|
||||
self.daily_seasonality = 0
|
||||
elif self.daily_seasonality is True:
|
||||
self.daily_seasonality = 4
|
||||
dt = self.history['ds'].diff()
|
||||
min_dt = dt.iloc[dt.nonzero()[0]].min()
|
||||
|
||||
# Yearly seasonality
|
||||
yearly_disable = last - first < pd.Timedelta(days=730)
|
||||
fourier_order = self.parse_seasonality_args(
|
||||
'yearly', self.yearly_seasonality, yearly_disable, 10)
|
||||
if fourier_order > 0:
|
||||
self.seasonalities['yearly'] = (365.25, fourier_order)
|
||||
|
||||
# Weekly seasonality
|
||||
weekly_disable = ((last - first < pd.Timedelta(weeks=2)) or
|
||||
(min_dt >= pd.Timedelta(weeks=1)))
|
||||
fourier_order = self.parse_seasonality_args(
|
||||
'weekly', self.weekly_seasonality, weekly_disable, 3)
|
||||
if fourier_order > 0:
|
||||
self.seasonalities['weekly'] = (7, fourier_order)
|
||||
|
||||
# Daily seasonality
|
||||
daily_disable = ((last - first < pd.Timedelta(days=2)) or
|
||||
(min_dt >= pd.Timedelta(days=1)))
|
||||
fourier_order = self.parse_seasonality_args(
|
||||
'daily', self.daily_seasonality, daily_disable, 4)
|
||||
if fourier_order > 0:
|
||||
self.seasonalities['daily'] = (1, fourier_order)
|
||||
|
||||
@staticmethod
|
||||
def linear_growth_init(df):
|
||||
|
|
|
|||
|
|
@ -252,43 +252,51 @@ class TestProphet(TestCase):
|
|||
self.assertEqual(future.iloc[i]['ds'], correct[i])
|
||||
|
||||
def test_auto_weekly_seasonality(self):
|
||||
# Should be True
|
||||
# Should be enabled
|
||||
N = 15
|
||||
train = DATA.head(N)
|
||||
m = Prophet()
|
||||
self.assertEqual(m.weekly_seasonality, 'auto')
|
||||
m.fit(train)
|
||||
self.assertEqual(m.weekly_seasonality, True)
|
||||
# Should be False due to too short history
|
||||
self.assertIn('weekly', m.seasonalities)
|
||||
self.assertEqual(m.seasonalities['weekly'], (7, 3))
|
||||
# Should be disabled due to too short history
|
||||
N = 9
|
||||
train = DATA.head(N)
|
||||
m = Prophet()
|
||||
m.fit(train)
|
||||
self.assertEqual(m.weekly_seasonality, False)
|
||||
self.assertNotIn('weekly', m.seasonalities)
|
||||
m = Prophet(weekly_seasonality=True)
|
||||
m.fit(train)
|
||||
self.assertEqual(m.weekly_seasonality, True)
|
||||
self.assertIn('weekly', m.seasonalities)
|
||||
# Should be False due to weekly spacing
|
||||
train = DATA.iloc[::7, :]
|
||||
m = Prophet()
|
||||
m.fit(train)
|
||||
self.assertEqual(m.weekly_seasonality, False)
|
||||
self.assertNotIn('weekly', m.seasonalities)
|
||||
m = Prophet(weekly_seasonality=2)
|
||||
m.fit(DATA)
|
||||
self.assertEqual(m.seasonalities['weekly'], (7, 2))
|
||||
|
||||
def test_auto_yearly_seasonality(self):
|
||||
# Should be True
|
||||
# Should be enabled
|
||||
m = Prophet()
|
||||
self.assertEqual(m.yearly_seasonality, 'auto')
|
||||
m.fit(DATA)
|
||||
self.assertEqual(m.yearly_seasonality, True)
|
||||
# Should be False due to too short history
|
||||
self.assertIn('yearly', m.seasonalities)
|
||||
self.assertEqual(m.seasonalities['yearly'], (365.25, 10))
|
||||
# Should be disabled due to too short history
|
||||
N = 240
|
||||
train = DATA.head(N)
|
||||
m = Prophet()
|
||||
m.fit(train)
|
||||
self.assertEqual(m.yearly_seasonality, False)
|
||||
self.assertNotIn('yearly', m.seasonalities)
|
||||
m = Prophet(yearly_seasonality=True)
|
||||
m.fit(train)
|
||||
self.assertEqual(m.yearly_seasonality, True)
|
||||
self.assertIn('yearly', m.seasonalities)
|
||||
m = Prophet(yearly_seasonality=7)
|
||||
m.fit(DATA)
|
||||
self.assertEqual(m.seasonalities['yearly'], (365.25, 7))
|
||||
|
||||
|
||||
DATA = pd.read_csv(StringIO("""
|
||||
|
|
|
|||
Loading…
Reference in a new issue