diff --git a/python/fbprophet/tests/test_diagnostics.py b/python/fbprophet/tests/test_diagnostics.py index be8fa40..46c0ee8 100644 --- a/python/fbprophet/tests/test_diagnostics.py +++ b/python/fbprophet/tests/test_diagnostics.py @@ -105,6 +105,14 @@ class TestDiagnostics(TestCase): self.assertAlmostEqual( ((df_cv1['yhat'] - df_cv2['yhat']) ** 2).sum(), 0.0) + def test_cross_validation_uncertainty_disabled(self): + df = self.__df.copy() + m = Prophet(uncertainty_samples=0) + m.fit(df) + df_cv = diagnostics.cross_validation( + m, horizon='4 days', period='4 days', initial='115 days') + self.assertListEqual(['ds', 'yhat', 'y', 'cutoff'], df_cv.columns.tolist()) + def test_performance_metrics(self): m = Prophet() m.fit(self.__df) diff --git a/python/fbprophet/tests/test_prophet.py b/python/fbprophet/tests/test_prophet.py index 4675824..b052e84 100644 --- a/python/fbprophet/tests/test_prophet.py +++ b/python/fbprophet/tests/test_prophet.py @@ -101,10 +101,9 @@ class TestProphet(TestCase): m = Prophet(uncertainty_samples=0) m.fit(train) fcst = m.predict(future) - self.assertNotIn('yhat_lower', list(fcst.columns)) - self.assertNotIn('yhat_upper', list(fcst.columns)) - self.assertNotIn('trend_lower', list(fcst.columns)) - self.assertNotIn('trend_upper', list(fcst.columns)) + self.assertListEqual(['ds', 'trend', 'additive_terms', 'weekly', + 'multiplicative_terms', 'yhat'], fcst.columns.tolist()) + def test_setup_dataframe(self): m = Prophet()