From c756d20100a2855cb49f2405883c6975ec174b12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olof=20H=C3=B6jvall?= Date: Thu, 18 Apr 2019 01:47:16 +0200 Subject: [PATCH] Added a interactive Plotly plot of the forecast (#915) --- python/fbprophet/plot.py | 155 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 155 insertions(+) diff --git a/python/fbprophet/plot.py b/python/fbprophet/plot.py index 96c1086..2f385a0 100644 --- a/python/fbprophet/plot.py +++ b/python/fbprophet/plot.py @@ -25,6 +25,12 @@ try: except ImportError: logger.error('Importing matplotlib failed. Plotting will not work.') +try: + import plotly.graph_objs as go + import plotly.offline +except ImportError: + logger.error('Importing plotly failed. Interactive plots will not work.') + def plot( m, fcst, ax=None, uncertainty=True, plot_cap=True, xlabel='ds', ylabel='y', @@ -495,3 +501,152 @@ def plot_cross_validation_metric( ax.set_xlabel('Horizon ({})'.format(dt_names[i])) ax.set_ylabel(metric) return fig + + +def plot_plotly(m, fcst, uncertainty=True, plot_cap=True, trend=False, changepoints=False, + changepoints_threshold=0.01, xlabel='ds', ylabel='y'): + """Plot the Prophet forecast with Plotly offline. + + Requires plotly.offline.init_notebook_mode() to have been run, + see https://plot.ly/python/offline/ for details + + Parameters + ---------- + m: Prophet model. + fcst: pd.DataFrame output of m.predict. + uncertainty: Optional boolean to plot uncertainty intervals. + plot_cap: Optional boolean indicating if the capacity should be shown + in the figure, if available. + trend: Optional boolean to plot trend + changepoints: Optional boolean to plot changepoints + changepoints_threshold: Threshold on trend change magnitude for significance. + xlabel: Optional label name on X-axis + ylabel: Optional label name on Y-axis + + Returns + ------- + A Plotly plot. + """ + prediction_color = '#0072B2' + error_color = 'rgba(0, 114, 178, 0.2)' # '#0072B2' with 0.2 opacity + actual_color = 'black' + cap_color = 'black' + trend_color = '#B23B00' + line_width = 2 + marker_size = 4 + + data = [] + # Add actual + data.append(go.Scatter( + name='Actual', + x=m.history['ds'], + y=m.history['y'], + marker=dict(color=actual_color, size=marker_size), + mode='markers' + )) + # Add lower bound + if uncertainty: + data.append(go.Scatter( + x=fcst['ds'], + y=fcst['yhat_lower'], + mode='lines', + line=dict(width=0), + hoverinfo='skip' + )) + # Add prediction + data.append(go.Scatter( + name='Predicted', + x=fcst['ds'], + y=fcst['yhat'], + mode='lines', + line=dict(color=prediction_color, width=line_width), + fillcolor=error_color, + fill='tonexty' if uncertainty else 'none' + )) + # Add upper bound + if uncertainty: + data.append(go.Scatter( + x=fcst['ds'], + y=fcst['yhat_upper'], + mode='lines', + line=dict(width=0), + fillcolor=error_color, + fill='tonexty', + hoverinfo='skip' + )) + # Add caps + if 'cap' in fcst and plot_cap: + data.append(go.Scatter( + name='Cap', + x=fcst['ds'], + y=fcst['cap'], + mode='lines', + line=dict(color=cap_color, dash='dash', width=line_width), + )) + if m.logistic_floor and 'floor' in fcst and plot_cap: + data.append(go.Scatter( + name='Floor', + x=fcst['ds'], + y=fcst['floor'], + mode='lines', + line=dict(color=cap_color, dash='dash', width=line_width), + )) + # Add trend + if trend: + data.append(go.Scatter( + name='Trend', + x=fcst['ds'], + y=fcst['trend'], + mode='lines', + line=dict(color=trend_color, width=line_width), + )) + # Add changepoints + if changepoints: + signif_changepoints = m.changepoints[ + np.abs(np.nanmean(m.params['delta'], axis=0)) >= changepoints_threshold + ] + data.append(go.Scatter( + x=signif_changepoints, + y=fcst.loc[fcst['ds'].isin(signif_changepoints), 'trend'], + marker=dict(size=50, symbol='line-ns-open', color=trend_color, + line=dict(width=line_width)), + mode='markers', + hoverinfo='skip' + )) + + layout = dict( + showlegend=False, + yaxis=dict( + title=ylabel + ), + xaxis=dict( + title=xlabel, + type='date', + rangeselector=dict( + buttons=list([ + dict(count=7, + label='1w', + step='day', + stepmode='backward'), + dict(count=1, + label='1m', + step='month', + stepmode='backward'), + dict(count=6, + label='6m', + step='month', + stepmode='backward'), + dict(count=1, + label='1y', + step='year', + stepmode='backward'), + dict(step='all') + ]) + ), + rangeslider=dict( + visible=True + ), + ), + ) + fig = go.Figure(data=data, layout=layout) + return plotly.offline.iplot(fig)