mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-07-03 03:59:00 +00:00
Limit n_changepoints to number of observations.
This commit is contained in:
parent
79d0793ce4
commit
0b4ec4a9b3
4 changed files with 56 additions and 15 deletions
|
|
@ -329,10 +329,16 @@ set_changepoints <- function(m) {
|
|||
}
|
||||
}
|
||||
} else {
|
||||
# Place potential changepoints evenly through the first 80 pcnt of
|
||||
# the history.
|
||||
hist.size <- floor(nrow(m$history) * .8)
|
||||
if (m$n.changepoints + 1 > hist.size) {
|
||||
m$n.changepoints <- hist.size - 1
|
||||
warning('n.changepoints greater than number of observations. Using ',
|
||||
m$n.changepoints)
|
||||
}
|
||||
if (m$n.changepoints > 0) {
|
||||
# Place potential changepoints evenly through the first 80 pcnt of
|
||||
# the history.
|
||||
cp.indexes <- round(seq.int(1, floor(nrow(m$history) * .8),
|
||||
cp.indexes <- round(seq.int(1, hist.size,
|
||||
length.out = (m$n.changepoints + 1))[-1])
|
||||
m$changepoints <- m$history$ds[cp.indexes]
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -101,6 +101,21 @@ test_that("get_zero_changepoints", {
|
|||
expect_equal(ncol(mat), 1)
|
||||
})
|
||||
|
||||
test_that("override_n_changepoints", {
|
||||
history <- train[1:20,]
|
||||
m <- prophet(history, fit = FALSE)
|
||||
|
||||
out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE)
|
||||
m <- out$m
|
||||
history <- out$df
|
||||
m$history <- history
|
||||
|
||||
m <- prophet:::set_changepoints(m)
|
||||
expect_equal(m$n.changepoints, 15)
|
||||
cp <- m$changepoints.t
|
||||
expect_equal(length(cp), 15)
|
||||
})
|
||||
|
||||
test_that("fourier_series_weekly", {
|
||||
mat <- prophet:::fourier_series(DATA$ds, 7, 3)
|
||||
true.values <- c(0.9165623, 0.3998920, 0.7330519, -0.6801727, -0.3302791,
|
||||
|
|
|
|||
|
|
@ -277,19 +277,27 @@ class Prophet(object):
|
|||
too_low = min(self.changepoints) < self.history['ds'].min()
|
||||
too_high = max(self.changepoints) > self.history['ds'].max()
|
||||
if too_low or too_high:
|
||||
raise ValueError('Changepoints must fall within training data.')
|
||||
elif self.n_changepoints > 0:
|
||||
# Place potential changepoints evenly through first 80% of history
|
||||
max_ix = np.floor(self.history.shape[0] * 0.8)
|
||||
cp_indexes = (
|
||||
np.linspace(0, max_ix, self.n_changepoints + 1)
|
||||
.round()
|
||||
.astype(np.int)
|
||||
)
|
||||
self.changepoints = self.history.ix[cp_indexes]['ds'].tail(-1)
|
||||
raise ValueError(
|
||||
'Changepoints must fall within training data.')
|
||||
else:
|
||||
# set empty changepoints
|
||||
self.changepoints = []
|
||||
# Place potential changepoints evenly through first 80% of history
|
||||
hist_size = np.floor(self.history.shape[0] * 0.8)
|
||||
if self.n_changepoints + 1 > hist_size:
|
||||
self.n_changepoints = hist_size - 1
|
||||
logger.info(
|
||||
'n_changepoints greater than number of observations.'
|
||||
'Using {}.'.format(self.n_changepoints)
|
||||
)
|
||||
if self.n_changepoints > 0:
|
||||
cp_indexes = (
|
||||
np.linspace(0, hist_size, self.n_changepoints + 1)
|
||||
.round()
|
||||
.astype(np.int)
|
||||
)
|
||||
self.changepoints = self.history.ix[cp_indexes]['ds'].tail(-1)
|
||||
else:
|
||||
# set empty changepoints
|
||||
self.changepoints = []
|
||||
if len(self.changepoints) > 0:
|
||||
self.changepoints_t = np.sort(np.array(
|
||||
(self.changepoints - self.start) / self.t_scale))
|
||||
|
|
|
|||
|
|
@ -130,6 +130,18 @@ class TestProphet(TestCase):
|
|||
self.assertEqual(mat.shape[0], N // 2)
|
||||
self.assertEqual(mat.shape[1], 1)
|
||||
|
||||
def test_override_n_changepoints(self):
|
||||
m = Prophet()
|
||||
history = DATA.head(20).copy()
|
||||
|
||||
history = m.setup_dataframe(history, initialize_scales=True)
|
||||
m.history = history
|
||||
|
||||
m.set_changepoints()
|
||||
self.assertEqual(m.n_changepoints, 15)
|
||||
cp = m.changepoints_t
|
||||
self.assertEqual(cp.shape[0], 15)
|
||||
|
||||
def test_fourier_series_weekly(self):
|
||||
mat = Prophet.fourier_series(DATA['ds'], 7, 3)
|
||||
# These are from the R forecast package directly.
|
||||
|
|
|
|||
Loading…
Reference in a new issue