mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-05-25 22:26:34 +00:00
Move built-in country holidays to a method
This commit is contained in:
parent
ebf7b3da0a
commit
92f955d25a
2 changed files with 107 additions and 57 deletions
|
|
@ -68,7 +68,6 @@ class Prophet(object):
|
|||
lower_window=-2 will include 2 days prior to the date as holidays. Also
|
||||
optionally can have a column prior_scale specifying the prior scale for
|
||||
that holiday.
|
||||
append_holidays: country name or abbreviation; must be string
|
||||
seasonality_mode: 'additive' (default) or 'multiplicative'.
|
||||
seasonality_prior_scale: Parameter modulating the strength of the
|
||||
seasonality model. Larger values allow the model to fit larger seasonal
|
||||
|
|
@ -101,7 +100,6 @@ class Prophet(object):
|
|||
weekly_seasonality='auto',
|
||||
daily_seasonality='auto',
|
||||
holidays=None,
|
||||
append_holidays=None,
|
||||
seasonality_mode='additive',
|
||||
seasonality_prior_scale=10.0,
|
||||
holidays_prior_scale=10.0,
|
||||
|
|
@ -136,13 +134,6 @@ class Prophet(object):
|
|||
holidays['ds'] = pd.to_datetime(holidays['ds'])
|
||||
self.holidays = holidays
|
||||
|
||||
if append_holidays is not None:
|
||||
if not (
|
||||
isinstance(append_holidays, str)
|
||||
):
|
||||
raise ValueError("append_holidays must be a string")
|
||||
self.append_holidays = append_holidays
|
||||
|
||||
self.seasonality_mode = seasonality_mode
|
||||
self.seasonality_prior_scale = float(seasonality_prior_scale)
|
||||
self.changepoint_prior_scale = float(changepoint_prior_scale)
|
||||
|
|
@ -152,7 +143,7 @@ class Prophet(object):
|
|||
self.interval_width = interval_width
|
||||
self.uncertainty_samples = uncertainty_samples
|
||||
|
||||
# Set during fitting
|
||||
# Set during fitting or by other methods
|
||||
self.start = None
|
||||
self.y_scale = None
|
||||
self.logistic_floor = False
|
||||
|
|
@ -160,6 +151,7 @@ class Prophet(object):
|
|||
self.changepoints_t = None
|
||||
self.seasonalities = {}
|
||||
self.extra_regressors = OrderedDict({})
|
||||
self.country_holidays = None
|
||||
self.stan_fit = None
|
||||
self.params = {}
|
||||
self.history = None
|
||||
|
|
@ -224,10 +216,10 @@ class Prophet(object):
|
|||
name in self.holidays['holiday'].unique()):
|
||||
raise ValueError(
|
||||
'Name "{}" already used for a holiday.'.format(name))
|
||||
if (check_holidays and self.append_holidays is not None and
|
||||
name in get_holiday_names(self.append_holidays)):
|
||||
if (check_holidays and self.country_holidays is not None and
|
||||
name in get_holiday_names(self.country_holidays)):
|
||||
raise ValueError(
|
||||
'Name "{}" is a holiday name in {}.'.format(name, self.append_holidays))
|
||||
'Name "{}" is a holiday name in {}.'.format(name, self.country_holidays))
|
||||
if check_seasonalities and name in self.seasonalities:
|
||||
raise ValueError(
|
||||
'Name "{}" already used for a seasonality.'.format(name))
|
||||
|
|
@ -430,12 +422,57 @@ class Prophet(object):
|
|||
]
|
||||
return pd.DataFrame(features, columns=columns)
|
||||
|
||||
def make_holiday_features(self, dates):
|
||||
def construct_holiday_dataframe(self, dates):
|
||||
"""Construct a dataframe of holiday dates.
|
||||
|
||||
Will combine self.holidays with the built-in country holidays
|
||||
corresponding to input dates, if self.country_holidays is set.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dates: pd.Series containing timestamps used for computing seasonality.
|
||||
|
||||
Returns
|
||||
-------
|
||||
"""
|
||||
all_holidays = pd.DataFrame()
|
||||
if self.holidays is not None:
|
||||
all_holidays = pd.concat((all_holidays, self.holidays))
|
||||
if self.country_holidays is not None:
|
||||
year_list = list({x.year for x in dates})
|
||||
country_holidays_df = make_holidays_df(
|
||||
year_list=year_list, country=self.country_holidays
|
||||
)
|
||||
all_holidays = pd.concat((all_holidays, country_holidays_df), sort=False)
|
||||
all_holidays.reset_index(drop=True, inplace=True)
|
||||
# If the model has already been fit with a certain set of holidays,
|
||||
# make sure we are using those same ones.
|
||||
if self.train_holiday_names is not None:
|
||||
# Remove holiday names didn't show up in fit
|
||||
index_to_drop = all_holidays.index[
|
||||
np.logical_not(
|
||||
all_holidays.holiday.isin(self.train_holiday_names)
|
||||
)
|
||||
]
|
||||
all_holidays = all_holidays.drop(index_to_drop)
|
||||
# Add holiday names in fit but not in predict with ds as NA
|
||||
holidays_to_add = pd.DataFrame({
|
||||
'holiday': self.train_holiday_names[
|
||||
np.logical_not(self.train_holiday_names.isin(all_holidays.holiday))
|
||||
]
|
||||
})
|
||||
all_holidays = pd.concat((all_holidays, holidays_to_add), sort=False)
|
||||
all_holidays.reset_index(drop=True, inplace=True)
|
||||
return all_holidays
|
||||
|
||||
def make_holiday_features(self, dates, holidays):
|
||||
"""Construct a dataframe of holiday features.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dates: pd.Series containing timestamps used for computing seasonality.
|
||||
holidays: pd.Dataframe containing holidays, as returned by
|
||||
construct_holiday_dataframe.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
|
@ -443,32 +480,6 @@ class Prophet(object):
|
|||
prior_scale_list: List of prior scales for each holiday column.
|
||||
holiday_names: List of names of holidays
|
||||
"""
|
||||
# Concatenate holidays and append_holidays
|
||||
all_holidays = self.holidays
|
||||
if self.append_holidays is not None:
|
||||
year_list = list({x.year for x in dates})
|
||||
append_holidays_df = make_holidays_df(
|
||||
year_list=year_list,
|
||||
country=self.append_holidays)
|
||||
all_holidays = pd.concat((all_holidays, append_holidays_df), sort=False)
|
||||
all_holidays.reset_index(drop=True, inplace=True)
|
||||
# Make fit and predict holidays components match
|
||||
if self.train_holiday_names is not None:
|
||||
train_holidays = self.train_holiday_names
|
||||
# Remove holiday names didn't show up in fit
|
||||
index_to_drop = all_holidays.index[
|
||||
np.logical_not(
|
||||
all_holidays.holiday.isin(train_holidays))]
|
||||
all_holidays = all_holidays.drop(index_to_drop)
|
||||
# Add holiday names show up in fit but not in predict with ds as NA
|
||||
holidays_to_add = pd.DataFrame(
|
||||
{'holiday':
|
||||
train_holidays[
|
||||
np.logical_not(
|
||||
train_holidays.isin(all_holidays.holiday))]})
|
||||
all_holidays = pd.concat((all_holidays, holidays_to_add), sort=False)
|
||||
all_holidays.reset_index(drop=True, inplace=True)
|
||||
|
||||
# Holds columns of our future matrix.
|
||||
expanded_holidays = defaultdict(lambda: np.zeros(dates.shape[0]))
|
||||
prior_scales = {}
|
||||
|
|
@ -476,7 +487,7 @@ class Prophet(object):
|
|||
# Strip to just dates.
|
||||
row_index = pd.DatetimeIndex(dates.apply(lambda x: x.date()))
|
||||
|
||||
for _ix, row in all_holidays.iterrows():
|
||||
for _ix, row in holidays.iterrows():
|
||||
dt = row.ds.date()
|
||||
try:
|
||||
lw = int(row.get('lower_window', 0))
|
||||
|
|
@ -635,6 +646,44 @@ class Prophet(object):
|
|||
}
|
||||
return self
|
||||
|
||||
def add_country_holidays(self, country_name):
|
||||
"""Add in built-in holidays for the specified country.
|
||||
|
||||
These holidays will be included in addition to any specified on model
|
||||
initialization.
|
||||
|
||||
Holidays will be calculated for arbitrary date ranges in the history
|
||||
and future. See the online documentation for the list of countries with
|
||||
built-in holidays.
|
||||
|
||||
Built-in country holidays can only be set for a single country.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
country_name: Name of the country, like 'UnitedStates' or 'US'
|
||||
|
||||
Returns
|
||||
-------
|
||||
The prophet object.
|
||||
"""
|
||||
if self.history is not None:
|
||||
raise Exception(
|
||||
"Country holidays must be added prior to model fitting."
|
||||
)
|
||||
# Validate names.
|
||||
for name in get_holiday_names(country_name):
|
||||
# Allow merging with existing holidays
|
||||
self.validate_column_name(name, check_holidays=False)
|
||||
# Set the holidays.
|
||||
if self.country_holidays is not None:
|
||||
logger.warning(
|
||||
'Changing country holidays from {} to {}'.format(
|
||||
self.country_holidays, country_name
|
||||
)
|
||||
)
|
||||
self.country_holidays = country_name
|
||||
return self
|
||||
|
||||
def make_all_seasonality_features(self, df):
|
||||
"""Dataframe with seasonality features.
|
||||
|
||||
|
|
@ -672,9 +721,10 @@ class Prophet(object):
|
|||
modes[props['mode']].append(name)
|
||||
|
||||
# Holiday features
|
||||
if self.holidays is not None or self.append_holidays is not None:
|
||||
holidays = self.construct_holiday_dataframe(df['ds'])
|
||||
if len(holidays) > 0:
|
||||
features, holiday_priors, holiday_names = (
|
||||
self.make_holiday_features(df['ds'])
|
||||
self.make_holiday_features(df['ds'], holidays)
|
||||
)
|
||||
seasonal_features.append(features)
|
||||
prior_scales.extend(holiday_priors)
|
||||
|
|
|
|||
|
|
@ -281,7 +281,7 @@ class TestProphet(TestCase):
|
|||
df = pd.DataFrame({
|
||||
'ds': pd.date_range('2016-12-20', '2016-12-31')
|
||||
})
|
||||
feats, priors, names = model.make_holiday_features(df['ds'])
|
||||
feats, priors, names = model.make_holiday_features(df['ds'], model.holidays)
|
||||
# 11 columns generated even though only 8 overlap
|
||||
self.assertEqual(feats.shape, (df.shape[0], 2))
|
||||
self.assertEqual((feats.sum(0) - np.array([1.0, 1.0])).sum(), 0)
|
||||
|
|
@ -295,7 +295,7 @@ class TestProphet(TestCase):
|
|||
'upper_window': [10],
|
||||
})
|
||||
m = Prophet(holidays=holidays)
|
||||
feats, priors, names = m.make_holiday_features(df['ds'])
|
||||
feats, priors, names = m.make_holiday_features(df['ds'], m.holidays)
|
||||
# 12 columns generated even though only 8 overlap
|
||||
self.assertEqual(feats.shape, (df.shape[0], 12))
|
||||
self.assertEqual(priors, list(10. * np.ones(12)))
|
||||
|
|
@ -309,7 +309,7 @@ class TestProphet(TestCase):
|
|||
'prior_scale': [5., 5.],
|
||||
})
|
||||
m = Prophet(holidays=holidays)
|
||||
feats, priors, names = m.make_holiday_features(df['ds'])
|
||||
feats, priors, names = m.make_holiday_features(df['ds'], m.holidays)
|
||||
self.assertEqual(priors, [5., 5.])
|
||||
self.assertEqual(names, ['xmas'])
|
||||
# 2 different priors
|
||||
|
|
@ -322,7 +322,7 @@ class TestProphet(TestCase):
|
|||
})
|
||||
holidays2 = pd.concat((holidays, holidays2))
|
||||
m = Prophet(holidays=holidays2)
|
||||
feats, priors, names = m.make_holiday_features(df['ds'])
|
||||
feats, priors, names = m.make_holiday_features(df['ds'], m.holidays)
|
||||
pn = zip(priors, [s.split('_delim_')[0] for s in feats.columns])
|
||||
for t in pn:
|
||||
self.assertIn(t, [(8., 'seans-bday'), (5., 'xmas')])
|
||||
|
|
@ -335,7 +335,7 @@ class TestProphet(TestCase):
|
|||
holidays2 = pd.concat((holidays, holidays2))
|
||||
feats, priors, names = Prophet(
|
||||
holidays=holidays2, holidays_prior_scale=4
|
||||
).make_holiday_features(df['ds'])
|
||||
).make_holiday_features(df['ds'], holidays2)
|
||||
self.assertEqual(set(priors), {4., 5.})
|
||||
# Check incompatible priors
|
||||
holidays = pd.DataFrame({
|
||||
|
|
@ -346,7 +346,7 @@ class TestProphet(TestCase):
|
|||
'prior_scale': [5., 6.],
|
||||
})
|
||||
with self.assertRaises(ValueError):
|
||||
Prophet(holidays=holidays).make_holiday_features(df['ds'])
|
||||
Prophet(holidays=holidays).make_holiday_features(df['ds'], holidays)
|
||||
|
||||
def test_fit_with_holidays(self):
|
||||
holidays = pd.DataFrame({
|
||||
|
|
@ -358,28 +358,28 @@ class TestProphet(TestCase):
|
|||
model = Prophet(holidays=holidays, uncertainty_samples=0)
|
||||
model.fit(DATA).predict()
|
||||
|
||||
def test_fit_predict_with_append_holidays(self):
|
||||
def test_fit_predict_with_country_holidays(self):
|
||||
holidays = pd.DataFrame({
|
||||
'ds': pd.to_datetime(['2012-06-06', '2013-06-06']),
|
||||
'holiday': ['seans-bday'] * 2,
|
||||
'lower_window': [0] * 2,
|
||||
'upper_window': [1] * 2,
|
||||
})
|
||||
append_holidays = 'US'
|
||||
# Test with holidays and append_holidays
|
||||
model = Prophet(holidays=holidays,
|
||||
append_holidays=append_holidays,
|
||||
uncertainty_samples=0)
|
||||
# Test with holidays and country_holidays
|
||||
model = Prophet(holidays=holidays, uncertainty_samples=0)
|
||||
model.add_country_holidays(country_name='US')
|
||||
model.fit(DATA).predict()
|
||||
# There are training holidays missing in the test set
|
||||
train = DATA.head(154)
|
||||
future = DATA.tail(355)
|
||||
model = Prophet(append_holidays=append_holidays, uncertainty_samples=0)
|
||||
model = Prophet(uncertainty_samples=0)
|
||||
model.add_country_holidays(country_name='US')
|
||||
model.fit(train).predict(future)
|
||||
# There are test holidays missing in the training set
|
||||
train = DATA.tail(355)
|
||||
future = DATA2
|
||||
model = Prophet(append_holidays=append_holidays, uncertainty_samples=0)
|
||||
model = Prophet(uncertainty_samples=0)
|
||||
model.add_country_holidays(country_name='US')
|
||||
model.fit(train).predict(future)
|
||||
|
||||
def test_make_future_dataframe(self):
|
||||
|
|
|
|||
Loading…
Reference in a new issue