mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-05-16 21:00:16 +00:00
Added a interactive Plotly plot of the forecast (#915)
This commit is contained in:
parent
a31a93480a
commit
c756d20100
1 changed files with 155 additions and 0 deletions
|
|
@ -25,6 +25,12 @@ try:
|
|||
except ImportError:
|
||||
logger.error('Importing matplotlib failed. Plotting will not work.')
|
||||
|
||||
try:
|
||||
import plotly.graph_objs as go
|
||||
import plotly.offline
|
||||
except ImportError:
|
||||
logger.error('Importing plotly failed. Interactive plots will not work.')
|
||||
|
||||
|
||||
def plot(
|
||||
m, fcst, ax=None, uncertainty=True, plot_cap=True, xlabel='ds', ylabel='y',
|
||||
|
|
@ -495,3 +501,152 @@ def plot_cross_validation_metric(
|
|||
ax.set_xlabel('Horizon ({})'.format(dt_names[i]))
|
||||
ax.set_ylabel(metric)
|
||||
return fig
|
||||
|
||||
|
||||
def plot_plotly(m, fcst, uncertainty=True, plot_cap=True, trend=False, changepoints=False,
|
||||
changepoints_threshold=0.01, xlabel='ds', ylabel='y'):
|
||||
"""Plot the Prophet forecast with Plotly offline.
|
||||
|
||||
Requires plotly.offline.init_notebook_mode() to have been run,
|
||||
see https://plot.ly/python/offline/ for details
|
||||
|
||||
Parameters
|
||||
----------
|
||||
m: Prophet model.
|
||||
fcst: pd.DataFrame output of m.predict.
|
||||
uncertainty: Optional boolean to plot uncertainty intervals.
|
||||
plot_cap: Optional boolean indicating if the capacity should be shown
|
||||
in the figure, if available.
|
||||
trend: Optional boolean to plot trend
|
||||
changepoints: Optional boolean to plot changepoints
|
||||
changepoints_threshold: Threshold on trend change magnitude for significance.
|
||||
xlabel: Optional label name on X-axis
|
||||
ylabel: Optional label name on Y-axis
|
||||
|
||||
Returns
|
||||
-------
|
||||
A Plotly plot.
|
||||
"""
|
||||
prediction_color = '#0072B2'
|
||||
error_color = 'rgba(0, 114, 178, 0.2)' # '#0072B2' with 0.2 opacity
|
||||
actual_color = 'black'
|
||||
cap_color = 'black'
|
||||
trend_color = '#B23B00'
|
||||
line_width = 2
|
||||
marker_size = 4
|
||||
|
||||
data = []
|
||||
# Add actual
|
||||
data.append(go.Scatter(
|
||||
name='Actual',
|
||||
x=m.history['ds'],
|
||||
y=m.history['y'],
|
||||
marker=dict(color=actual_color, size=marker_size),
|
||||
mode='markers'
|
||||
))
|
||||
# Add lower bound
|
||||
if uncertainty:
|
||||
data.append(go.Scatter(
|
||||
x=fcst['ds'],
|
||||
y=fcst['yhat_lower'],
|
||||
mode='lines',
|
||||
line=dict(width=0),
|
||||
hoverinfo='skip'
|
||||
))
|
||||
# Add prediction
|
||||
data.append(go.Scatter(
|
||||
name='Predicted',
|
||||
x=fcst['ds'],
|
||||
y=fcst['yhat'],
|
||||
mode='lines',
|
||||
line=dict(color=prediction_color, width=line_width),
|
||||
fillcolor=error_color,
|
||||
fill='tonexty' if uncertainty else 'none'
|
||||
))
|
||||
# Add upper bound
|
||||
if uncertainty:
|
||||
data.append(go.Scatter(
|
||||
x=fcst['ds'],
|
||||
y=fcst['yhat_upper'],
|
||||
mode='lines',
|
||||
line=dict(width=0),
|
||||
fillcolor=error_color,
|
||||
fill='tonexty',
|
||||
hoverinfo='skip'
|
||||
))
|
||||
# Add caps
|
||||
if 'cap' in fcst and plot_cap:
|
||||
data.append(go.Scatter(
|
||||
name='Cap',
|
||||
x=fcst['ds'],
|
||||
y=fcst['cap'],
|
||||
mode='lines',
|
||||
line=dict(color=cap_color, dash='dash', width=line_width),
|
||||
))
|
||||
if m.logistic_floor and 'floor' in fcst and plot_cap:
|
||||
data.append(go.Scatter(
|
||||
name='Floor',
|
||||
x=fcst['ds'],
|
||||
y=fcst['floor'],
|
||||
mode='lines',
|
||||
line=dict(color=cap_color, dash='dash', width=line_width),
|
||||
))
|
||||
# Add trend
|
||||
if trend:
|
||||
data.append(go.Scatter(
|
||||
name='Trend',
|
||||
x=fcst['ds'],
|
||||
y=fcst['trend'],
|
||||
mode='lines',
|
||||
line=dict(color=trend_color, width=line_width),
|
||||
))
|
||||
# Add changepoints
|
||||
if changepoints:
|
||||
signif_changepoints = m.changepoints[
|
||||
np.abs(np.nanmean(m.params['delta'], axis=0)) >= changepoints_threshold
|
||||
]
|
||||
data.append(go.Scatter(
|
||||
x=signif_changepoints,
|
||||
y=fcst.loc[fcst['ds'].isin(signif_changepoints), 'trend'],
|
||||
marker=dict(size=50, symbol='line-ns-open', color=trend_color,
|
||||
line=dict(width=line_width)),
|
||||
mode='markers',
|
||||
hoverinfo='skip'
|
||||
))
|
||||
|
||||
layout = dict(
|
||||
showlegend=False,
|
||||
yaxis=dict(
|
||||
title=ylabel
|
||||
),
|
||||
xaxis=dict(
|
||||
title=xlabel,
|
||||
type='date',
|
||||
rangeselector=dict(
|
||||
buttons=list([
|
||||
dict(count=7,
|
||||
label='1w',
|
||||
step='day',
|
||||
stepmode='backward'),
|
||||
dict(count=1,
|
||||
label='1m',
|
||||
step='month',
|
||||
stepmode='backward'),
|
||||
dict(count=6,
|
||||
label='6m',
|
||||
step='month',
|
||||
stepmode='backward'),
|
||||
dict(count=1,
|
||||
label='1y',
|
||||
step='year',
|
||||
stepmode='backward'),
|
||||
dict(step='all')
|
||||
])
|
||||
),
|
||||
rangeslider=dict(
|
||||
visible=True
|
||||
),
|
||||
),
|
||||
)
|
||||
fig = go.Figure(data=data, layout=layout)
|
||||
return plotly.offline.iplot(fig)
|
||||
|
|
|
|||
Loading…
Reference in a new issue