From b4ee8f1b0a08250f95c98014f326f6e22333deea Mon Sep 17 00:00:00 2001 From: iMad Date: Sat, 18 May 2024 13:56:55 +0200 Subject: [PATCH] Include predictions for missing y (NaN) dates in the history (#2530) Co-authored-by: Imad Rahmouni --- python/prophet/forecaster.py | 2 +- python/prophet/tests/test_prophet.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/python/prophet/forecaster.py b/python/prophet/forecaster.py index 47e368f..5905bf3 100644 --- a/python/prophet/forecaster.py +++ b/python/prophet/forecaster.py @@ -1130,7 +1130,7 @@ class Prophet(object): history = df[df['y'].notnull()].copy() if history.shape[0] < 2: raise ValueError('Dataframe has less than 2 non-NaN rows.') - self.history_dates = pd.to_datetime(pd.Series(history['ds'].unique(), name='ds')).sort_values() + self.history_dates = pd.to_datetime(pd.Series(df['ds'].unique(), name='ds')).sort_values() self.history = self.setup_dataframe(history, initialize_scales=True) self.set_auto_seasonalities() diff --git a/python/prophet/tests/test_prophet.py b/python/prophet/tests/test_prophet.py index 83ea4aa..3df052d 100644 --- a/python/prophet/tests/test_prophet.py +++ b/python/prophet/tests/test_prophet.py @@ -244,6 +244,16 @@ class TestProphetDataPrep: assert len(future) == 3 assert np.all(future["ds"].values == correct.values) + def test_make_future_dataframe_include_history(self, daily_univariate_ts, backend): + train = daily_univariate_ts.head(468 // 2).copy() + #cover history with NAs + train.loc[train.sample(10).index, "y"] = np.nan + + forecaster = Prophet(stan_backend=backend) + forecaster.fit(train) + future = forecaster.make_future_dataframe(periods=3, freq="D", include_history=True) + + assert len(future) == train.shape[0] + 3 class TestProphetTrendComponent: def test_invalid_growth_input(self, backend):