mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-06-03 23:49:47 +00:00
Handle constant y in history
This commit is contained in:
parent
0b4ec4a9b3
commit
e4ec600da4
4 changed files with 47 additions and 2 deletions
|
|
@ -286,6 +286,9 @@ setup_dataframe <- function(m, df, initialize_scales = FALSE) {
|
|||
|
||||
if (initialize_scales) {
|
||||
m$y.scale <- max(abs(df$y))
|
||||
if (m$y.scale == 0) {
|
||||
m$y.scale <- 1
|
||||
}
|
||||
m$start <- min(df$ds)
|
||||
m$t.scale <- time_diff(max(df$ds), m$start, "secs")
|
||||
}
|
||||
|
|
@ -703,7 +706,12 @@ fit.prophet <- function(m, df, ...) {
|
|||
)
|
||||
}
|
||||
|
||||
if (m$mcmc.samples > 0) {
|
||||
if (min(history$y) == max(history$y)) {
|
||||
# Nothing to fit.
|
||||
m$params <- stan_init()
|
||||
m$params$sigma_obs <- 0.
|
||||
n.iteration <- 1.
|
||||
} else if (m$mcmc.samples > 0) {
|
||||
stan.fit <- rstan::sampling(
|
||||
model,
|
||||
data = dat,
|
||||
|
|
|
|||
|
|
@ -46,6 +46,19 @@ test_that("fit_predict_duplicates", {
|
|||
expect_error(predict(m, future), NA)
|
||||
})
|
||||
|
||||
test_that("fit_predict_constant_history", {
|
||||
skip_if_not(Sys.getenv('R_ARCH') != '/i386')
|
||||
train2 <- train
|
||||
train2$y <- 20
|
||||
m <- prophet(train2)
|
||||
fcst <- predict(m, future)
|
||||
expect_equal(tail(fcst$yhat, 1), 20)
|
||||
train2$y <- 0
|
||||
m <- prophet(train2)
|
||||
fcst <- predict(m, future)
|
||||
expect_equal(tail(fcst$yhat, 1), 0)
|
||||
})
|
||||
|
||||
test_that("setup_dataframe", {
|
||||
history <- train
|
||||
m <- prophet(history, fit = FALSE)
|
||||
|
|
|
|||
|
|
@ -227,6 +227,8 @@ class Prophet(object):
|
|||
|
||||
if initialize_scales:
|
||||
self.y_scale = df['y'].abs().max()
|
||||
if self.y_scale == 0:
|
||||
self.y_scale = 1
|
||||
self.start = df['ds'].min()
|
||||
self.t_scale = df['ds'].max() - self.start
|
||||
for name, props in self.extra_regressors.items():
|
||||
|
|
@ -726,7 +728,13 @@ class Prophet(object):
|
|||
'sigma_obs': 1,
|
||||
}
|
||||
|
||||
if self.mcmc_samples > 0:
|
||||
if history['y'].min() == history['y'].max():
|
||||
# Nothing to fit.
|
||||
self.params = stan_init()
|
||||
self.params['sigma_obs'] = 0.
|
||||
for par in self.params:
|
||||
self.params[par] = np.array([self.params[par]])
|
||||
elif self.mcmc_samples > 0:
|
||||
stan_fit = model.sampling(
|
||||
dat,
|
||||
init=stan_init,
|
||||
|
|
|
|||
|
|
@ -79,6 +79,22 @@ class TestProphet(TestCase):
|
|||
forecaster.fit(train)
|
||||
forecaster.predict(future)
|
||||
|
||||
def test_fit_predict_constant_history(self):
|
||||
N = DATA.shape[0]
|
||||
train = DATA.head(N // 2).copy()
|
||||
train['y'] = 20
|
||||
future = pd.DataFrame({'ds': DATA['ds'].tail(N // 2)})
|
||||
m = Prophet()
|
||||
m.fit(train)
|
||||
fcst = m.predict(future)
|
||||
self.assertEqual(fcst['yhat'].values[-1], 20)
|
||||
train['y'] = 0
|
||||
future = pd.DataFrame({'ds': DATA['ds'].tail(N // 2)})
|
||||
m = Prophet()
|
||||
m.fit(train)
|
||||
fcst = m.predict(future)
|
||||
self.assertEqual(fcst['yhat'].values[-1], 0)
|
||||
|
||||
def test_setup_dataframe(self):
|
||||
m = Prophet()
|
||||
N = DATA.shape[0]
|
||||
|
|
|
|||
Loading…
Reference in a new issue