mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-05-14 20:48:08 +00:00
Move plotting into a separate file (Py)
This commit is contained in:
parent
0c30f6efcf
commit
7c95c62cf0
2 changed files with 396 additions and 238 deletions
|
|
@ -14,24 +14,28 @@ from collections import defaultdict
|
|||
from copy import deepcopy
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
|
||||
logging.basicConfig()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from matplotlib import pyplot as plt
|
||||
from matplotlib.dates import MonthLocator, num2date
|
||||
from matplotlib.ticker import FuncFormatter
|
||||
except ImportError:
|
||||
logger.exception('Importing matplotlib failed. Plotting will not work.')
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
# fb-block 1 start
|
||||
from fbprophet.models import prophet_stan_models
|
||||
from fbprophet.plot import (
|
||||
plot,
|
||||
plot_components,
|
||||
plot_forecast_component,
|
||||
seasonality_plot_df,
|
||||
plot_weekly,
|
||||
plot_yearly,
|
||||
plot_seasonality,
|
||||
)
|
||||
# fb-block 1 end
|
||||
|
||||
logging.basicConfig()
|
||||
logger = logging.getLogger(__name__)
|
||||
warnings.filterwarnings("default", category=DeprecationWarning)
|
||||
|
||||
try:
|
||||
import pystan # noqa F401
|
||||
except ImportError:
|
||||
|
|
@ -1251,26 +1255,10 @@ class Prophet(object):
|
|||
-------
|
||||
A matplotlib figure.
|
||||
"""
|
||||
if ax is None:
|
||||
fig = plt.figure(facecolor='w', figsize=(10, 6))
|
||||
ax = fig.add_subplot(111)
|
||||
else:
|
||||
fig = ax.get_figure()
|
||||
fcst_t = fcst['ds'].dt.to_pydatetime()
|
||||
ax.plot(self.history['ds'].dt.to_pydatetime(), self.history['y'], 'k.')
|
||||
ax.plot(fcst_t, fcst['yhat'], ls='-', c='#0072B2')
|
||||
if 'cap' in fcst and plot_cap:
|
||||
ax.plot(fcst_t, fcst['cap'], ls='--', c='k')
|
||||
if self.logistic_floor and 'floor' in fcst and plot_cap:
|
||||
ax.plot(fcst_t, fcst['floor'], ls='--', c='k')
|
||||
if uncertainty:
|
||||
ax.fill_between(fcst_t, fcst['yhat_lower'], fcst['yhat_upper'],
|
||||
color='#0072B2', alpha=0.2)
|
||||
ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
|
||||
ax.set_xlabel(xlabel)
|
||||
ax.set_ylabel(ylabel)
|
||||
fig.tight_layout()
|
||||
return fig
|
||||
return plot(
|
||||
m=self, fcst=fcst, ax=ax, uncertainty=uncertainty,
|
||||
plot_cap=plot_cap, xlabel=xlabel, ylabel=ylabel,
|
||||
)
|
||||
|
||||
def plot_components(self, fcst, uncertainty=True, plot_cap=True,
|
||||
weekly_start=0, yearly_start=0):
|
||||
|
|
@ -1296,225 +1284,60 @@ class Prophet(object):
|
|||
-------
|
||||
A matplotlib figure.
|
||||
"""
|
||||
# Identify components to be plotted
|
||||
components = ['trend']
|
||||
if self.holidays is not None and 'holidays' in fcst:
|
||||
components.append('holidays')
|
||||
components.extend([name for name in self.seasonalities
|
||||
if name in fcst])
|
||||
if len(self.extra_regressors) > 0 and 'extra_regressors' in fcst:
|
||||
components.append('extra_regressors')
|
||||
npanel = len(components)
|
||||
|
||||
fig, axes = plt.subplots(npanel, 1, facecolor='w',
|
||||
figsize=(9, 3 * npanel))
|
||||
|
||||
if npanel == 1:
|
||||
axes = [axes]
|
||||
|
||||
for ax, plot in zip(axes, components):
|
||||
if plot == 'trend':
|
||||
self.plot_forecast_component(
|
||||
fcst, 'trend', ax, uncertainty, plot_cap)
|
||||
elif plot == 'holidays':
|
||||
self.plot_forecast_component(
|
||||
fcst, 'holidays', ax, uncertainty, False)
|
||||
elif plot == 'weekly':
|
||||
self.plot_weekly(
|
||||
ax=ax, uncertainty=uncertainty, weekly_start=weekly_start)
|
||||
elif plot == 'yearly':
|
||||
self.plot_yearly(
|
||||
ax=ax, uncertainty=uncertainty, yearly_start=yearly_start)
|
||||
elif plot == 'extra_regressors':
|
||||
self.plot_forecast_component(
|
||||
fcst, 'extra_regressors', ax, uncertainty, False)
|
||||
else:
|
||||
self.plot_seasonality(
|
||||
name=plot, ax=ax, uncertainty=uncertainty)
|
||||
|
||||
fig.tight_layout()
|
||||
return fig
|
||||
return plot_components(
|
||||
m=self, fcst=fcst, uncertainty=uncertainty, plot_cap=plot_cap,
|
||||
weekly_start=weekly_start, yearly_start=yearly_start,
|
||||
)
|
||||
|
||||
def plot_forecast_component(
|
||||
self, fcst, name, ax=None, uncertainty=True, plot_cap=False):
|
||||
"""Plot a particular component of the forecast.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fcst: pd.DataFrame output of self.predict.
|
||||
name: Name of the component to plot.
|
||||
ax: Optional matplotlib Axes to plot on.
|
||||
uncertainty: Optional boolean to plot uncertainty intervals.
|
||||
plot_cap: Optional boolean indicating if the capacity should be shown
|
||||
in the figure, if available.
|
||||
|
||||
Returns
|
||||
-------
|
||||
a list of matplotlib artists
|
||||
"""
|
||||
artists = []
|
||||
if not ax:
|
||||
fig = plt.figure(facecolor='w', figsize=(10, 6))
|
||||
ax = fig.add_subplot(111)
|
||||
fcst_t = fcst['ds'].dt.to_pydatetime()
|
||||
artists += ax.plot(fcst_t, fcst[name], ls='-', c='#0072B2')
|
||||
if 'cap' in fcst and plot_cap:
|
||||
artists += ax.plot(fcst_t, fcst['cap'], ls='--', c='k')
|
||||
if self.logistic_floor and 'floor' in fcst and plot_cap:
|
||||
ax.plot(fcst_t, fcst['floor'], ls='--', c='k')
|
||||
if uncertainty:
|
||||
artists += [ax.fill_between(
|
||||
fcst_t, fcst[name + '_lower'], fcst[name + '_upper'],
|
||||
color='#0072B2', alpha=0.2)]
|
||||
ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
|
||||
ax.set_xlabel('ds')
|
||||
ax.set_ylabel(name)
|
||||
return artists
|
||||
warnings.warn(
|
||||
'This method will be removed in the next version. '
|
||||
'Please use fbprophet.plot.plot_forecast_component. ',
|
||||
DeprecationWarning,
|
||||
)
|
||||
return plot_forecast_component(
|
||||
self, fcst=fcst, name=name, ax=ax, uncertainty=uncertainty,
|
||||
plot_cap=plot_cap,
|
||||
)
|
||||
|
||||
def seasonality_plot_df(self, ds):
|
||||
"""Prepare dataframe for plotting seasonal components.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ds: List of dates for column ds.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dataframe with seasonal components on ds.
|
||||
"""
|
||||
df_dict = {'ds': ds, 'cap': 1., 'floor': 0.}
|
||||
for name in self.extra_regressors:
|
||||
df_dict[name] = 0.
|
||||
df = pd.DataFrame(df_dict)
|
||||
df = self.setup_dataframe(df)
|
||||
return df
|
||||
warnings.warn(
|
||||
'This method will be removed in the next version. '
|
||||
'Please use fbprophet.plot.seasonality_plot_df. ',
|
||||
DeprecationWarning,
|
||||
)
|
||||
return seasonality_plot_df(self, ds=ds)
|
||||
|
||||
def plot_weekly(self, ax=None, uncertainty=True, weekly_start=0):
|
||||
"""Plot the weekly component of the forecast.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ax: Optional matplotlib Axes to plot on. One will be created if this
|
||||
is not provided.
|
||||
uncertainty: Optional boolean to plot uncertainty intervals.
|
||||
weekly_start: Optional int specifying the start day of the weekly
|
||||
seasonality plot. 0 (default) starts the week on Sunday. 1 shifts
|
||||
by 1 day to Monday, and so on.
|
||||
|
||||
Returns
|
||||
-------
|
||||
a list of matplotlib artists
|
||||
"""
|
||||
artists = []
|
||||
if not ax:
|
||||
fig = plt.figure(facecolor='w', figsize=(10, 6))
|
||||
ax = fig.add_subplot(111)
|
||||
# Compute weekly seasonality for a Sun-Sat sequence of dates.
|
||||
days = (pd.date_range(start='2017-01-01', periods=7) +
|
||||
pd.Timedelta(days=weekly_start))
|
||||
df_w = self.seasonality_plot_df(days)
|
||||
seas = self.predict_seasonal_components(df_w)
|
||||
days = days.weekday_name
|
||||
artists += ax.plot(range(len(days)), seas['weekly'], ls='-',
|
||||
c='#0072B2')
|
||||
if uncertainty:
|
||||
artists += [ax.fill_between(range(len(days)),
|
||||
seas['weekly_lower'], seas['weekly_upper'],
|
||||
color='#0072B2', alpha=0.2)]
|
||||
ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
|
||||
ax.set_xticks(range(len(days)))
|
||||
ax.set_xticklabels(days)
|
||||
ax.set_xlabel('Day of week')
|
||||
ax.set_ylabel('weekly')
|
||||
return artists
|
||||
warnings.warn(
|
||||
'This method will be removed in the next version. '
|
||||
'Please use fbprophet.plot.plot_weekly. ',
|
||||
DeprecationWarning,
|
||||
)
|
||||
return plot_weekly(
|
||||
self, ax=ax, uncertainty=uncertainty, weekly_start=weekly_start,
|
||||
)
|
||||
|
||||
def plot_yearly(self, ax=None, uncertainty=True, yearly_start=0):
|
||||
"""Plot the yearly component of the forecast.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ax: Optional matplotlib Axes to plot on. One will be created if
|
||||
this is not provided.
|
||||
uncertainty: Optional boolean to plot uncertainty intervals.
|
||||
yearly_start: Optional int specifying the start day of the yearly
|
||||
seasonality plot. 0 (default) starts the year on Jan 1. 1 shifts
|
||||
by 1 day to Jan 2, and so on.
|
||||
|
||||
Returns
|
||||
-------
|
||||
a list of matplotlib artists
|
||||
"""
|
||||
artists = []
|
||||
if not ax:
|
||||
fig = plt.figure(facecolor='w', figsize=(10, 6))
|
||||
ax = fig.add_subplot(111)
|
||||
# Compute yearly seasonality for a Jan 1 - Dec 31 sequence of dates.
|
||||
days = (pd.date_range(start='2017-01-01', periods=365) +
|
||||
pd.Timedelta(days=yearly_start))
|
||||
df_y = self.seasonality_plot_df(days)
|
||||
seas = self.predict_seasonal_components(df_y)
|
||||
artists += ax.plot(
|
||||
df_y['ds'].dt.to_pydatetime(), seas['yearly'], ls='-', c='#0072B2')
|
||||
if uncertainty:
|
||||
artists += [ax.fill_between(
|
||||
df_y['ds'].dt.to_pydatetime(), seas['yearly_lower'],
|
||||
seas['yearly_upper'], color='#0072B2', alpha=0.2)]
|
||||
ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
|
||||
months = MonthLocator(range(1, 13), bymonthday=1, interval=2)
|
||||
ax.xaxis.set_major_formatter(FuncFormatter(
|
||||
lambda x, pos=None: '{dt:%B} {dt.day}'.format(dt=num2date(x))))
|
||||
ax.xaxis.set_major_locator(months)
|
||||
ax.set_xlabel('Day of year')
|
||||
ax.set_ylabel('yearly')
|
||||
return artists
|
||||
warnings.warn(
|
||||
'This method will be removed in the next version. '
|
||||
'Please use fbprophet.plot.plot_yearly. ',
|
||||
DeprecationWarning,
|
||||
)
|
||||
return plot_yearly(
|
||||
self, ax=ax, uncertainty=uncertainty, yearly_start=yearly_start,
|
||||
)
|
||||
|
||||
def plot_seasonality(self, name, ax=None, uncertainty=True):
|
||||
"""Plot a custom seasonal component.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: Seasonality name, like 'daily', 'weekly'.
|
||||
ax: Optional matplotlib Axes to plot on. One will be created if
|
||||
this is not provided.
|
||||
uncertainty: Optional boolean to plot uncertainty intervals.
|
||||
|
||||
Returns
|
||||
-------
|
||||
a list of matplotlib artists
|
||||
"""
|
||||
artists = []
|
||||
if not ax:
|
||||
fig = plt.figure(facecolor='w', figsize=(10, 6))
|
||||
ax = fig.add_subplot(111)
|
||||
# Compute seasonality from Jan 1 through a single period.
|
||||
start = pd.to_datetime('2017-01-01 0000')
|
||||
period = self.seasonalities[name]['period']
|
||||
end = start + pd.Timedelta(days=period)
|
||||
plot_points = 200
|
||||
days = pd.to_datetime(np.linspace(start.value, end.value, plot_points))
|
||||
df_y = self.seasonality_plot_df(days)
|
||||
seas = self.predict_seasonal_components(df_y)
|
||||
artists += ax.plot(df_y['ds'].dt.to_pydatetime(), seas[name], ls='-',
|
||||
c='#0072B2')
|
||||
if uncertainty:
|
||||
artists += [ax.fill_between(
|
||||
df_y['ds'].dt.to_pydatetime(), seas[name + '_lower'],
|
||||
seas[name + '_upper'], color='#0072B2', alpha=0.2)]
|
||||
ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
|
||||
xticks = pd.to_datetime(np.linspace(start.value, end.value, 7)
|
||||
).to_pydatetime()
|
||||
ax.set_xticks(xticks)
|
||||
if period <= 2:
|
||||
fmt_str = '{dt:%T}'
|
||||
elif period < 14:
|
||||
fmt_str = '{dt:%m}/{dt:%d} {dt:%R}'
|
||||
else:
|
||||
fmt_str = '{dt:%m}/{dt:%d}'
|
||||
ax.xaxis.set_major_formatter(FuncFormatter(
|
||||
lambda x, pos=None: fmt_str.format(dt=num2date(x))))
|
||||
ax.set_xlabel('ds')
|
||||
ax.set_ylabel(name)
|
||||
return artists
|
||||
warnings.warn(
|
||||
'This method will be removed in the next version. '
|
||||
'Please use fbprophet.plot.plot_seasonality. ',
|
||||
DeprecationWarning,
|
||||
)
|
||||
return plot_seasonality(
|
||||
self, name=name, ax=ax, uncertainty=uncertainty,
|
||||
)
|
||||
|
||||
def copy(self, cutoff=None):
|
||||
"""Copy Prophet object
|
||||
|
|
|
|||
335
python/fbprophet/plot.py
Normal file
335
python/fbprophet/plot.py
Normal file
|
|
@ -0,0 +1,335 @@
|
|||
# Copyright (c) 2017-present, Facebook, Inc.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree. An additional grant
|
||||
# of patent rights can be found in the PATENTS file in the same directory.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
logging.basicConfig()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from matplotlib import pyplot as plt
|
||||
from matplotlib.dates import MonthLocator, num2date
|
||||
from matplotlib.ticker import FuncFormatter
|
||||
except ImportError:
|
||||
logger.error('Importing matplotlib failed. Plotting will not work.')
|
||||
|
||||
|
||||
def plot(
|
||||
m, fcst, ax=None, uncertainty=True, plot_cap=True, xlabel='ds', ylabel='y',
|
||||
):
|
||||
"""Plot the Prophet forecast.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
m: Prophet model.
|
||||
fcst: pd.DataFrame output of m.predict.
|
||||
ax: Optional matplotlib axes on which to plot.
|
||||
uncertainty: Optional boolean to plot uncertainty intervals.
|
||||
plot_cap: Optional boolean indicating if the capacity should be shown
|
||||
in the figure, if available.
|
||||
xlabel: Optional label name on X-axis
|
||||
ylabel: Optional label name on Y-axis
|
||||
|
||||
Returns
|
||||
-------
|
||||
A matplotlib figure.
|
||||
"""
|
||||
if ax is None:
|
||||
fig = plt.figure(facecolor='w', figsize=(10, 6))
|
||||
ax = fig.add_subplot(111)
|
||||
else:
|
||||
fig = ax.get_figure()
|
||||
fcst_t = fcst['ds'].dt.to_pydatetime()
|
||||
ax.plot(m.history['ds'].dt.to_pydatetime(), m.history['y'], 'k.')
|
||||
ax.plot(fcst_t, fcst['yhat'], ls='-', c='#0072B2')
|
||||
if 'cap' in fcst and plot_cap:
|
||||
ax.plot(fcst_t, fcst['cap'], ls='--', c='k')
|
||||
if m.logistic_floor and 'floor' in fcst and plot_cap:
|
||||
ax.plot(fcst_t, fcst['floor'], ls='--', c='k')
|
||||
if uncertainty:
|
||||
ax.fill_between(fcst_t, fcst['yhat_lower'], fcst['yhat_upper'],
|
||||
color='#0072B2', alpha=0.2)
|
||||
ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
|
||||
ax.set_xlabel(xlabel)
|
||||
ax.set_ylabel(ylabel)
|
||||
fig.tight_layout()
|
||||
return fig
|
||||
|
||||
|
||||
def plot_components(
|
||||
m, fcst, uncertainty=True, plot_cap=True, weekly_start=0, yearly_start=0,
|
||||
):
|
||||
"""Plot the Prophet forecast components.
|
||||
|
||||
Will plot whichever are available of: trend, holidays, weekly
|
||||
seasonality, and yearly seasonality.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
m: Prophet model.
|
||||
fcst: pd.DataFrame output of m.predict.
|
||||
uncertainty: Optional boolean to plot uncertainty intervals.
|
||||
plot_cap: Optional boolean indicating if the capacity should be shown
|
||||
in the figure, if available.
|
||||
weekly_start: Optional int specifying the start day of the weekly
|
||||
seasonality plot. 0 (default) starts the week on Sunday. 1 shifts
|
||||
by 1 day to Monday, and so on.
|
||||
yearly_start: Optional int specifying the start day of the yearly
|
||||
seasonality plot. 0 (default) starts the year on Jan 1. 1 shifts
|
||||
by 1 day to Jan 2, and so on.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A matplotlib figure.
|
||||
"""
|
||||
# Identify components to be plotted
|
||||
components = ['trend']
|
||||
if m.holidays is not None and 'holidays' in fcst:
|
||||
components.append('holidays')
|
||||
components.extend([name for name in m.seasonalities
|
||||
if name in fcst])
|
||||
if len(m.extra_regressors) > 0 and 'extra_regressors' in fcst:
|
||||
components.append('extra_regressors')
|
||||
npanel = len(components)
|
||||
|
||||
fig, axes = plt.subplots(npanel, 1, facecolor='w',
|
||||
figsize=(9, 3 * npanel))
|
||||
|
||||
if npanel == 1:
|
||||
axes = [axes]
|
||||
|
||||
for ax, plot_name in zip(axes, components):
|
||||
if plot_name == 'trend':
|
||||
plot_forecast_component(
|
||||
m=m, fcst=fcst, name='trend', ax=ax, uncertainty=uncertainty,
|
||||
plot_cap=plot_cap,
|
||||
)
|
||||
elif plot_name == 'holidays':
|
||||
plot_forecast_component(
|
||||
m=m, fcst=fcst, name='holidays', ax=ax,
|
||||
uncertainty=uncertainty, plot_cap=False,
|
||||
)
|
||||
elif plot_name == 'weekly':
|
||||
plot_weekly(
|
||||
m=m, ax=ax, uncertainty=uncertainty, weekly_start=weekly_start,
|
||||
)
|
||||
elif plot_name == 'yearly':
|
||||
plot_yearly(
|
||||
m=m, ax=ax, uncertainty=uncertainty, yearly_start=yearly_start,
|
||||
)
|
||||
elif plot_name == 'extra_regressors':
|
||||
plot_forecast_component(
|
||||
m=m, fcst=fcst, name='extra_regressors', ax=ax,
|
||||
uncertainty=uncertainty, plot_cap=False,
|
||||
)
|
||||
else:
|
||||
plot_seasonality(
|
||||
m=m, name=plot_name, ax=ax, uncertainty=uncertainty,
|
||||
)
|
||||
|
||||
fig.tight_layout()
|
||||
return fig
|
||||
|
||||
|
||||
def plot_forecast_component(
|
||||
m, fcst, name, ax=None, uncertainty=True, plot_cap=False,
|
||||
):
|
||||
"""Plot a particular component of the forecast.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
m: Prophet model.
|
||||
fcst: pd.DataFrame output of m.predict.
|
||||
name: Name of the component to plot.
|
||||
ax: Optional matplotlib Axes to plot on.
|
||||
uncertainty: Optional boolean to plot uncertainty intervals.
|
||||
plot_cap: Optional boolean indicating if the capacity should be shown
|
||||
in the figure, if available.
|
||||
|
||||
Returns
|
||||
-------
|
||||
a list of matplotlib artists
|
||||
"""
|
||||
artists = []
|
||||
if not ax:
|
||||
fig = plt.figure(facecolor='w', figsize=(10, 6))
|
||||
ax = fig.add_subplot(111)
|
||||
fcst_t = fcst['ds'].dt.to_pydatetime()
|
||||
artists += ax.plot(fcst_t, fcst[name], ls='-', c='#0072B2')
|
||||
if 'cap' in fcst and plot_cap:
|
||||
artists += ax.plot(fcst_t, fcst['cap'], ls='--', c='k')
|
||||
if m.logistic_floor and 'floor' in fcst and plot_cap:
|
||||
ax.plot(fcst_t, fcst['floor'], ls='--', c='k')
|
||||
if uncertainty:
|
||||
artists += [ax.fill_between(
|
||||
fcst_t, fcst[name + '_lower'], fcst[name + '_upper'],
|
||||
color='#0072B2', alpha=0.2)]
|
||||
ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
|
||||
ax.set_xlabel('ds')
|
||||
ax.set_ylabel(name)
|
||||
return artists
|
||||
|
||||
|
||||
def seasonality_plot_df(m, ds):
|
||||
"""Prepare dataframe for plotting seasonal components.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
m: Prophet model.
|
||||
ds: List of dates for column ds.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dataframe with seasonal components on ds.
|
||||
"""
|
||||
df_dict = {'ds': ds, 'cap': 1., 'floor': 0.}
|
||||
for name in m.extra_regressors:
|
||||
df_dict[name] = 0.
|
||||
df = pd.DataFrame(df_dict)
|
||||
df = m.setup_dataframe(df)
|
||||
return df
|
||||
|
||||
|
||||
def plot_weekly(m, ax=None, uncertainty=True, weekly_start=0):
|
||||
"""Plot the weekly component of the forecast.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
m: Prophet model.
|
||||
ax: Optional matplotlib Axes to plot on. One will be created if this
|
||||
is not provided.
|
||||
uncertainty: Optional boolean to plot uncertainty intervals.
|
||||
weekly_start: Optional int specifying the start day of the weekly
|
||||
seasonality plot. 0 (default) starts the week on Sunday. 1 shifts
|
||||
by 1 day to Monday, and so on.
|
||||
|
||||
Returns
|
||||
-------
|
||||
a list of matplotlib artists
|
||||
"""
|
||||
artists = []
|
||||
if not ax:
|
||||
fig = plt.figure(facecolor='w', figsize=(10, 6))
|
||||
ax = fig.add_subplot(111)
|
||||
# Compute weekly seasonality for a Sun-Sat sequence of dates.
|
||||
days = (pd.date_range(start='2017-01-01', periods=7) +
|
||||
pd.Timedelta(days=weekly_start))
|
||||
df_w = seasonality_plot_df(m, days)
|
||||
seas = m.predict_seasonal_components(df_w)
|
||||
days = days.weekday_name
|
||||
artists += ax.plot(range(len(days)), seas['weekly'], ls='-',
|
||||
c='#0072B2')
|
||||
if uncertainty:
|
||||
artists += [ax.fill_between(range(len(days)),
|
||||
seas['weekly_lower'], seas['weekly_upper'],
|
||||
color='#0072B2', alpha=0.2)]
|
||||
ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
|
||||
ax.set_xticks(range(len(days)))
|
||||
ax.set_xticklabels(days)
|
||||
ax.set_xlabel('Day of week')
|
||||
ax.set_ylabel('weekly')
|
||||
return artists
|
||||
|
||||
|
||||
def plot_yearly(m, ax=None, uncertainty=True, yearly_start=0):
|
||||
"""Plot the yearly component of the forecast.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
m: Prophet model.
|
||||
ax: Optional matplotlib Axes to plot on. One will be created if
|
||||
this is not provided.
|
||||
uncertainty: Optional boolean to plot uncertainty intervals.
|
||||
yearly_start: Optional int specifying the start day of the yearly
|
||||
seasonality plot. 0 (default) starts the year on Jan 1. 1 shifts
|
||||
by 1 day to Jan 2, and so on.
|
||||
|
||||
Returns
|
||||
-------
|
||||
a list of matplotlib artists
|
||||
"""
|
||||
artists = []
|
||||
if not ax:
|
||||
fig = plt.figure(facecolor='w', figsize=(10, 6))
|
||||
ax = fig.add_subplot(111)
|
||||
# Compute yearly seasonality for a Jan 1 - Dec 31 sequence of dates.
|
||||
days = (pd.date_range(start='2017-01-01', periods=365) +
|
||||
pd.Timedelta(days=yearly_start))
|
||||
df_y = seasonality_plot_df(m, days)
|
||||
seas = m.predict_seasonal_components(df_y)
|
||||
artists += ax.plot(
|
||||
df_y['ds'].dt.to_pydatetime(), seas['yearly'], ls='-', c='#0072B2')
|
||||
if uncertainty:
|
||||
artists += [ax.fill_between(
|
||||
df_y['ds'].dt.to_pydatetime(), seas['yearly_lower'],
|
||||
seas['yearly_upper'], color='#0072B2', alpha=0.2)]
|
||||
ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
|
||||
months = MonthLocator(range(1, 13), bymonthday=1, interval=2)
|
||||
ax.xaxis.set_major_formatter(FuncFormatter(
|
||||
lambda x, pos=None: '{dt:%B} {dt.day}'.format(dt=num2date(x))))
|
||||
ax.xaxis.set_major_locator(months)
|
||||
ax.set_xlabel('Day of year')
|
||||
ax.set_ylabel('yearly')
|
||||
return artists
|
||||
|
||||
|
||||
def plot_seasonality(m, name, ax=None, uncertainty=True):
|
||||
"""Plot a custom seasonal component.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
m: Prophet model.
|
||||
name: Seasonality name, like 'daily', 'weekly'.
|
||||
ax: Optional matplotlib Axes to plot on. One will be created if
|
||||
this is not provided.
|
||||
uncertainty: Optional boolean to plot uncertainty intervals.
|
||||
|
||||
Returns
|
||||
-------
|
||||
a list of matplotlib artists
|
||||
"""
|
||||
artists = []
|
||||
if not ax:
|
||||
fig = plt.figure(facecolor='w', figsize=(10, 6))
|
||||
ax = fig.add_subplot(111)
|
||||
# Compute seasonality from Jan 1 through a single period.
|
||||
start = pd.to_datetime('2017-01-01 0000')
|
||||
period = m.seasonalities[name]['period']
|
||||
end = start + pd.Timedelta(days=period)
|
||||
plot_points = 200
|
||||
days = pd.to_datetime(np.linspace(start.value, end.value, plot_points))
|
||||
df_y = seasonality_plot_df(m, days)
|
||||
seas = m.predict_seasonal_components(df_y)
|
||||
artists += ax.plot(df_y['ds'].dt.to_pydatetime(), seas[name], ls='-',
|
||||
c='#0072B2')
|
||||
if uncertainty:
|
||||
artists += [ax.fill_between(
|
||||
df_y['ds'].dt.to_pydatetime(), seas[name + '_lower'],
|
||||
seas[name + '_upper'], color='#0072B2', alpha=0.2)]
|
||||
ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
|
||||
xticks = pd.to_datetime(np.linspace(start.value, end.value, 7)
|
||||
).to_pydatetime()
|
||||
ax.set_xticks(xticks)
|
||||
if period <= 2:
|
||||
fmt_str = '{dt:%T}'
|
||||
elif period < 14:
|
||||
fmt_str = '{dt:%m}/{dt:%d} {dt:%R}'
|
||||
else:
|
||||
fmt_str = '{dt:%m}/{dt:%d}'
|
||||
ax.xaxis.set_major_formatter(FuncFormatter(
|
||||
lambda x, pos=None: fmt_str.format(dt=num2date(x))))
|
||||
ax.set_xlabel('ds')
|
||||
ax.set_ylabel(name)
|
||||
return artists
|
||||
Loading…
Reference in a new issue