mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-05-22 22:01:14 +00:00
Add include_legend flag to m.plot() (#1858)
This commit is contained in:
parent
c72ed7abcd
commit
899b4fac13
2 changed files with 13 additions and 8 deletions
|
|
@ -1577,7 +1577,7 @@ class Prophet(object):
|
|||
return pd.DataFrame({'ds': dates})
|
||||
|
||||
def plot(self, fcst, ax=None, uncertainty=True, plot_cap=True,
|
||||
xlabel='ds', ylabel='y', figsize=(10, 6)):
|
||||
xlabel='ds', ylabel='y', figsize=(10, 6), include_legend=False):
|
||||
"""Plot the Prophet forecast.
|
||||
|
||||
Parameters
|
||||
|
|
@ -1590,6 +1590,7 @@ class Prophet(object):
|
|||
xlabel: Optional label name on X-axis
|
||||
ylabel: Optional label name on Y-axis
|
||||
figsize: Optional tuple width, height in inches.
|
||||
include_legend: Optional boolean to add legend to the plot.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
|
@ -1598,7 +1599,7 @@ class Prophet(object):
|
|||
return plot(
|
||||
m=self, fcst=fcst, ax=ax, uncertainty=uncertainty,
|
||||
plot_cap=plot_cap, xlabel=xlabel, ylabel=ylabel,
|
||||
figsize=figsize
|
||||
figsize=figsize, include_legend=include_legend
|
||||
)
|
||||
|
||||
def plot_components(self, fcst, uncertainty=True, plot_cap=True,
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ except ImportError:
|
|||
|
||||
def plot(
|
||||
m, fcst, ax=None, uncertainty=True, plot_cap=True, xlabel='ds', ylabel='y',
|
||||
figsize=(10, 6)
|
||||
figsize=(10, 6), include_legend=False
|
||||
):
|
||||
"""Plot the Prophet forecast.
|
||||
|
||||
|
|
@ -57,6 +57,7 @@ def plot(
|
|||
xlabel: Optional label name on X-axis
|
||||
ylabel: Optional label name on Y-axis
|
||||
figsize: Optional tuple width, height in inches.
|
||||
include_legend: Optional boolean to add legend to the plot.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
|
@ -68,15 +69,16 @@ def plot(
|
|||
else:
|
||||
fig = ax.get_figure()
|
||||
fcst_t = fcst['ds'].dt.to_pydatetime()
|
||||
ax.plot(m.history['ds'].dt.to_pydatetime(), m.history['y'], 'k.')
|
||||
ax.plot(fcst_t, fcst['yhat'], ls='-', c='#0072B2')
|
||||
ax.plot(m.history['ds'].dt.to_pydatetime(), m.history['y'], 'k.',
|
||||
label='Observed data points')
|
||||
ax.plot(fcst_t, fcst['yhat'], ls='-', c='#0072B2', label='Forecast')
|
||||
if 'cap' in fcst and plot_cap:
|
||||
ax.plot(fcst_t, fcst['cap'], ls='--', c='k')
|
||||
ax.plot(fcst_t, fcst['cap'], ls='--', c='k', label='Maximum capacity')
|
||||
if m.logistic_floor and 'floor' in fcst and plot_cap:
|
||||
ax.plot(fcst_t, fcst['floor'], ls='--', c='k')
|
||||
ax.plot(fcst_t, fcst['floor'], ls='--', c='k', label='Minimum capacity')
|
||||
if uncertainty and m.uncertainty_samples:
|
||||
ax.fill_between(fcst_t, fcst['yhat_lower'], fcst['yhat_upper'],
|
||||
color='#0072B2', alpha=0.2)
|
||||
color='#0072B2', alpha=0.2, label='Uncertainty interval')
|
||||
# Specify formatting to workaround matplotlib issue #12925
|
||||
locator = AutoDateLocator(interval_multiples=False)
|
||||
formatter = AutoDateFormatter(locator)
|
||||
|
|
@ -85,6 +87,8 @@ def plot(
|
|||
ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
|
||||
ax.set_xlabel(xlabel)
|
||||
ax.set_ylabel(ylabel)
|
||||
if include_legend:
|
||||
ax.legend()
|
||||
fig.tight_layout()
|
||||
return fig
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue