Add include_legend flag to m.plot() (#1858)

This commit is contained in:
Vladimir Shargin 2021-04-04 01:01:41 +03:00 committed by GitHub
parent c72ed7abcd
commit 899b4fac13
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 8 deletions

View file

@ -1577,7 +1577,7 @@ class Prophet(object):
return pd.DataFrame({'ds': dates})
def plot(self, fcst, ax=None, uncertainty=True, plot_cap=True,
xlabel='ds', ylabel='y', figsize=(10, 6)):
xlabel='ds', ylabel='y', figsize=(10, 6), include_legend=False):
"""Plot the Prophet forecast.
Parameters
@ -1590,6 +1590,7 @@ class Prophet(object):
xlabel: Optional label name on X-axis
ylabel: Optional label name on Y-axis
figsize: Optional tuple width, height in inches.
include_legend: Optional boolean to add legend to the plot.
Returns
-------
@ -1598,7 +1599,7 @@ class Prophet(object):
return plot(
m=self, fcst=fcst, ax=ax, uncertainty=uncertainty,
plot_cap=plot_cap, xlabel=xlabel, ylabel=ylabel,
figsize=figsize
figsize=figsize, include_legend=include_legend
)
def plot_components(self, fcst, uncertainty=True, plot_cap=True,

View file

@ -41,7 +41,7 @@ except ImportError:
def plot(
m, fcst, ax=None, uncertainty=True, plot_cap=True, xlabel='ds', ylabel='y',
figsize=(10, 6)
figsize=(10, 6), include_legend=False
):
"""Plot the Prophet forecast.
@ -57,6 +57,7 @@ def plot(
xlabel: Optional label name on X-axis
ylabel: Optional label name on Y-axis
figsize: Optional tuple width, height in inches.
include_legend: Optional boolean to add legend to the plot.
Returns
-------
@ -68,15 +69,16 @@ def plot(
else:
fig = ax.get_figure()
fcst_t = fcst['ds'].dt.to_pydatetime()
ax.plot(m.history['ds'].dt.to_pydatetime(), m.history['y'], 'k.')
ax.plot(fcst_t, fcst['yhat'], ls='-', c='#0072B2')
ax.plot(m.history['ds'].dt.to_pydatetime(), m.history['y'], 'k.',
label='Observed data points')
ax.plot(fcst_t, fcst['yhat'], ls='-', c='#0072B2', label='Forecast')
if 'cap' in fcst and plot_cap:
ax.plot(fcst_t, fcst['cap'], ls='--', c='k')
ax.plot(fcst_t, fcst['cap'], ls='--', c='k', label='Maximum capacity')
if m.logistic_floor and 'floor' in fcst and plot_cap:
ax.plot(fcst_t, fcst['floor'], ls='--', c='k')
ax.plot(fcst_t, fcst['floor'], ls='--', c='k', label='Minimum capacity')
if uncertainty and m.uncertainty_samples:
ax.fill_between(fcst_t, fcst['yhat_lower'], fcst['yhat_upper'],
color='#0072B2', alpha=0.2)
color='#0072B2', alpha=0.2, label='Uncertainty interval')
# Specify formatting to workaround matplotlib issue #12925
locator = AutoDateLocator(interval_multiples=False)
formatter = AutoDateFormatter(locator)
@ -85,6 +87,8 @@ def plot(
ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
if include_legend:
ax.legend()
fig.tight_layout()
return fig