Added a interactive Plotly plot of the forecast (#915)

This commit is contained in:
Olof Höjvall 2019-04-18 01:47:16 +02:00 committed by Ben Letham
parent a31a93480a
commit c756d20100

View file

@ -25,6 +25,12 @@ try:
except ImportError:
logger.error('Importing matplotlib failed. Plotting will not work.')
try:
import plotly.graph_objs as go
import plotly.offline
except ImportError:
logger.error('Importing plotly failed. Interactive plots will not work.')
def plot(
m, fcst, ax=None, uncertainty=True, plot_cap=True, xlabel='ds', ylabel='y',
@ -495,3 +501,152 @@ def plot_cross_validation_metric(
ax.set_xlabel('Horizon ({})'.format(dt_names[i]))
ax.set_ylabel(metric)
return fig
def plot_plotly(m, fcst, uncertainty=True, plot_cap=True, trend=False, changepoints=False,
changepoints_threshold=0.01, xlabel='ds', ylabel='y'):
"""Plot the Prophet forecast with Plotly offline.
Requires plotly.offline.init_notebook_mode() to have been run,
see https://plot.ly/python/offline/ for details
Parameters
----------
m: Prophet model.
fcst: pd.DataFrame output of m.predict.
uncertainty: Optional boolean to plot uncertainty intervals.
plot_cap: Optional boolean indicating if the capacity should be shown
in the figure, if available.
trend: Optional boolean to plot trend
changepoints: Optional boolean to plot changepoints
changepoints_threshold: Threshold on trend change magnitude for significance.
xlabel: Optional label name on X-axis
ylabel: Optional label name on Y-axis
Returns
-------
A Plotly plot.
"""
prediction_color = '#0072B2'
error_color = 'rgba(0, 114, 178, 0.2)' # '#0072B2' with 0.2 opacity
actual_color = 'black'
cap_color = 'black'
trend_color = '#B23B00'
line_width = 2
marker_size = 4
data = []
# Add actual
data.append(go.Scatter(
name='Actual',
x=m.history['ds'],
y=m.history['y'],
marker=dict(color=actual_color, size=marker_size),
mode='markers'
))
# Add lower bound
if uncertainty:
data.append(go.Scatter(
x=fcst['ds'],
y=fcst['yhat_lower'],
mode='lines',
line=dict(width=0),
hoverinfo='skip'
))
# Add prediction
data.append(go.Scatter(
name='Predicted',
x=fcst['ds'],
y=fcst['yhat'],
mode='lines',
line=dict(color=prediction_color, width=line_width),
fillcolor=error_color,
fill='tonexty' if uncertainty else 'none'
))
# Add upper bound
if uncertainty:
data.append(go.Scatter(
x=fcst['ds'],
y=fcst['yhat_upper'],
mode='lines',
line=dict(width=0),
fillcolor=error_color,
fill='tonexty',
hoverinfo='skip'
))
# Add caps
if 'cap' in fcst and plot_cap:
data.append(go.Scatter(
name='Cap',
x=fcst['ds'],
y=fcst['cap'],
mode='lines',
line=dict(color=cap_color, dash='dash', width=line_width),
))
if m.logistic_floor and 'floor' in fcst and plot_cap:
data.append(go.Scatter(
name='Floor',
x=fcst['ds'],
y=fcst['floor'],
mode='lines',
line=dict(color=cap_color, dash='dash', width=line_width),
))
# Add trend
if trend:
data.append(go.Scatter(
name='Trend',
x=fcst['ds'],
y=fcst['trend'],
mode='lines',
line=dict(color=trend_color, width=line_width),
))
# Add changepoints
if changepoints:
signif_changepoints = m.changepoints[
np.abs(np.nanmean(m.params['delta'], axis=0)) >= changepoints_threshold
]
data.append(go.Scatter(
x=signif_changepoints,
y=fcst.loc[fcst['ds'].isin(signif_changepoints), 'trend'],
marker=dict(size=50, symbol='line-ns-open', color=trend_color,
line=dict(width=line_width)),
mode='markers',
hoverinfo='skip'
))
layout = dict(
showlegend=False,
yaxis=dict(
title=ylabel
),
xaxis=dict(
title=xlabel,
type='date',
rangeselector=dict(
buttons=list([
dict(count=7,
label='1w',
step='day',
stepmode='backward'),
dict(count=1,
label='1m',
step='month',
stepmode='backward'),
dict(count=6,
label='6m',
step='month',
stepmode='backward'),
dict(count=1,
label='1y',
step='year',
stepmode='backward'),
dict(step='all')
])
),
rangeslider=dict(
visible=True
),
),
)
fig = go.Figure(data=data, layout=layout)
return plotly.offline.iplot(fig)