Generalize seasonality representation (Python)

This commit is contained in:
bl 2017-07-04 11:06:03 -07:00
parent 825108b226
commit b3017c025f
2 changed files with 88 additions and 71 deletions

View file

@ -78,7 +78,6 @@ class Prophet(object):
parameters, which will include uncertainty in seasonality.
uncertainty_samples: Number of simulated draws used to estimate
uncertainty intervals.
daily_seasonality: Boolean, fit daily seasonality
"""
def __init__(
@ -96,7 +95,6 @@ class Prophet(object):
mcmc_samples=0,
interval_width=0.80,
uncertainty_samples=1000,
daily_seasonality=False,
):
self.growth = growth
@ -134,6 +132,7 @@ class Prophet(object):
self.y_scale = None
self.t_scale = None
self.changepoints_t = None
self.seasonalities = {}
self.stan_fit = None
self.params = {}
self.history = None
@ -358,81 +357,91 @@ class Prophet(object):
# Add a column of zeros in case no seasonality is used.
pd.DataFrame({'zeros': np.zeros(df.shape[0])})
]
# Seasonality features
if self.yearly_seasonality > 0:
for name, (period, series_order) in self.seasonalities.items():
seasonal_features.append(self.make_seasonality_features(
df['ds'],
365.25,
self.yearly_seasonality,
'yearly',
))
if self.weekly_seasonality > 0:
seasonal_features.append(self.make_seasonality_features(
df['ds'],
7,
self.weekly_seasonality,
'weekly',
))
if self.daily_seasonality > 0:
seasonal_features.append(self.make_seasonality_features(
df['ds'],
1,
self.daily_seasonality,
'daily',
period,
series_order,
name,
))
if self.holidays is not None:
seasonal_features.append(self.make_holiday_features(df['ds']))
return pd.concat(seasonal_features, axis=1)
def parse_seasonality_args(self, name, arg, auto_disable, default_order):
    """Resolve a built-in seasonality setting to a fourier order.

    Parameters
    ----------
    name: string name of the seasonality component.
    arg: 'auto', True, False, or number of fourier components as provided.
    auto_disable: bool if seasonality should be disabled when 'auto'.
    default_order: int default fourier order

    Returns
    -------
    Number of fourier components, or 0 for disabled.
    """
    # Identity checks on purpose: the integers 1 and 0 are not treated as
    # booleans here, they fall through to the int() conversion at the end.
    if arg is True:
        return default_order
    if arg is False:
        return 0
    if arg == 'auto':
        if name in self.seasonalities:
            # An entry with this name is already registered, so the
            # built-in component of the same name stays off.
            logger.info(
                'Found custom seasonality named "{name}", '
                'disabling built-in {name} seasonality.'.format(name=name)
            )
            return 0
        if auto_disable:
            logger.info(
                'Disabling {name} seasonality. Run prophet with '
                '{name}_seasonality=True to override this.'.format(
                    name=name)
            )
            return 0
        return default_order
    # Anything else is taken as an explicit number of fourier components.
    return int(arg)
def set_auto_seasonalities(self):
    """Set seasonalities that were left on auto.

    Turns on yearly seasonality if there is >=2 years of history.
    Turns on weekly seasonality if there is >=2 weeks of history, and the
    spacing between dates in the history is <7 days.
    Turns on daily seasonality if there is >=2 days of history, and the
    spacing between dates in the history is <1 day.

    Every enabled component is stored in self.seasonalities as
    name -> (period in days, fourier order). The yearly_seasonality,
    weekly_seasonality and daily_seasonality settings are only read here,
    never overwritten, so an 'auto' setting reaches
    parse_seasonality_args unchanged.
    """
    first = self.history['ds'].min()
    last = self.history['ds'].max()
    span = last - first

    # Smallest non-zero spacing between consecutive rows. The leading NaT
    # produced by diff() passes the mask but is skipped by min().
    # (Series.nonzero() is not used: it was removed from pandas.)
    dt = self.history['ds'].diff()
    min_dt = dt[dt != pd.Timedelta(0)].min()

    yearly_disable = span < pd.Timedelta(days=730)
    weekly_disable = ((span < pd.Timedelta(weeks=2)) or
                      (min_dt >= pd.Timedelta(weeks=1)))
    daily_disable = ((span < pd.Timedelta(days=2)) or
                     (min_dt >= pd.Timedelta(days=1)))

    # (name, period in days, user setting, disable-on-auto, default order)
    builtin = (
        ('yearly', 365.25, self.yearly_seasonality, yearly_disable, 10),
        ('weekly', 7, self.weekly_seasonality, weekly_disable, 3),
        ('daily', 1, self.daily_seasonality, daily_disable, 4),
    )
    for name, period, setting, disable, default_order in builtin:
        fourier_order = self.parse_seasonality_args(
            name, setting, disable, default_order)
        if fourier_order > 0:
            self.seasonalities[name] = (period, fourier_order)
@staticmethod
def linear_growth_init(df):

View file

@ -252,43 +252,51 @@ class TestProphet(TestCase):
self.assertEqual(future.iloc[i]['ds'], correct[i])
def test_auto_weekly_seasonality(self):
    """Check which weekly entry ends up in m.seasonalities after fit."""
    # Should be enabled
    N = 15
    train = DATA.head(N)
    m = Prophet()
    self.assertEqual(m.weekly_seasonality, 'auto')
    m.fit(train)
    self.assertIn('weekly', m.seasonalities)
    self.assertEqual(m.seasonalities['weekly'], (7, 3))
    # Should be disabled due to too short history
    N = 9
    train = DATA.head(N)
    m = Prophet()
    m.fit(train)
    self.assertNotIn('weekly', m.seasonalities)
    # An explicit True overrides the short-history rule
    m = Prophet(weekly_seasonality=True)
    m.fit(train)
    self.assertIn('weekly', m.seasonalities)
    # Should be disabled due to weekly spacing
    train = DATA.iloc[::7, :]
    m = Prophet()
    m.fit(train)
    self.assertNotIn('weekly', m.seasonalities)
    # An explicit number is used as the fourier order
    m = Prophet(weekly_seasonality=2)
    m.fit(DATA)
    self.assertEqual(m.seasonalities['weekly'], (7, 2))
def test_auto_yearly_seasonality(self):
    """Check which yearly entry ends up in m.seasonalities after fit."""
    # Should be enabled
    m = Prophet()
    self.assertEqual(m.yearly_seasonality, 'auto')
    m.fit(DATA)
    self.assertIn('yearly', m.seasonalities)
    self.assertEqual(m.seasonalities['yearly'], (365.25, 10))
    # Should be disabled due to too short history
    N = 240
    train = DATA.head(N)
    m = Prophet()
    m.fit(train)
    self.assertNotIn('yearly', m.seasonalities)
    # An explicit True overrides the short-history rule
    m = Prophet(yearly_seasonality=True)
    m.fit(train)
    self.assertIn('yearly', m.seasonalities)
    # An explicit number is used as the fourier order
    m = Prophet(yearly_seasonality=7)
    m.fit(DATA)
    self.assertEqual(m.seasonalities['yearly'], (365.25, 7))
DATA = pd.read_csv(StringIO("""