From 899b4fac13d2cad87aebd19692bea1332e8b4ed1 Mon Sep 17 00:00:00 2001 From: Vladimir Shargin Date: Sun, 4 Apr 2021 01:01:41 +0300 Subject: [PATCH] Add include_legend flag to m.plot() (#1858) --- python/prophet/forecaster.py | 5 +++-- python/prophet/plot.py | 16 ++++++++++------ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/python/prophet/forecaster.py b/python/prophet/forecaster.py index 9f75175..7d02b85 100644 --- a/python/prophet/forecaster.py +++ b/python/prophet/forecaster.py @@ -1577,7 +1577,7 @@ class Prophet(object): return pd.DataFrame({'ds': dates}) def plot(self, fcst, ax=None, uncertainty=True, plot_cap=True, - xlabel='ds', ylabel='y', figsize=(10, 6)): + xlabel='ds', ylabel='y', figsize=(10, 6), include_legend=False): """Plot the Prophet forecast. Parameters @@ -1590,6 +1590,7 @@ class Prophet(object): xlabel: Optional label name on X-axis ylabel: Optional label name on Y-axis figsize: Optional tuple width, height in inches. + include_legend: Optional boolean to add legend to the plot. Returns ------- @@ -1598,7 +1599,7 @@ class Prophet(object): return plot( m=self, fcst=fcst, ax=ax, uncertainty=uncertainty, plot_cap=plot_cap, xlabel=xlabel, ylabel=ylabel, - figsize=figsize + figsize=figsize, include_legend=include_legend ) def plot_components(self, fcst, uncertainty=True, plot_cap=True, diff --git a/python/prophet/plot.py b/python/prophet/plot.py index 10269a9..cb923ca 100644 --- a/python/prophet/plot.py +++ b/python/prophet/plot.py @@ -41,7 +41,7 @@ except ImportError: def plot( m, fcst, ax=None, uncertainty=True, plot_cap=True, xlabel='ds', ylabel='y', - figsize=(10, 6) + figsize=(10, 6), include_legend=False ): """Plot the Prophet forecast. @@ -57,6 +57,7 @@ def plot( xlabel: Optional label name on X-axis ylabel: Optional label name on Y-axis figsize: Optional tuple width, height in inches. + include_legend: Optional boolean to add legend to the plot. Returns ------- @@ -68,15 +69,16 @@ def plot( else: fig = ax.get_figure() fcst_t = fcst['ds'].dt.to_pydatetime() - ax.plot(m.history['ds'].dt.to_pydatetime(), m.history['y'], 'k.') - ax.plot(fcst_t, fcst['yhat'], ls='-', c='#0072B2') + ax.plot(m.history['ds'].dt.to_pydatetime(), m.history['y'], 'k.', + label='Observed data points') + ax.plot(fcst_t, fcst['yhat'], ls='-', c='#0072B2', label='Forecast') if 'cap' in fcst and plot_cap: - ax.plot(fcst_t, fcst['cap'], ls='--', c='k') + ax.plot(fcst_t, fcst['cap'], ls='--', c='k', label='Maximum capacity') if m.logistic_floor and 'floor' in fcst and plot_cap: - ax.plot(fcst_t, fcst['floor'], ls='--', c='k') + ax.plot(fcst_t, fcst['floor'], ls='--', c='k', label='Minimum capacity') if uncertainty and m.uncertainty_samples: ax.fill_between(fcst_t, fcst['yhat_lower'], fcst['yhat_upper'], - color='#0072B2', alpha=0.2) + color='#0072B2', alpha=0.2, label='Uncertainty interval') # Specify formatting to workaround matplotlib issue #12925 locator = AutoDateLocator(interval_multiples=False) formatter = AutoDateFormatter(locator) @@ -85,6 +87,8 @@ def plot( ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2) ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) + if include_legend: + ax.legend() fig.tight_layout() return fig