mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-06-11 00:49:35 +00:00
Fix issue in Python when changepoint_range=1
This commit is contained in:
parent
cb0b47994b
commit
68ff9e577d
3 changed files with 7 additions and 1 deletions
|
|
@ -139,6 +139,8 @@ test_that("set_changepoint_range", {
|
|||
expect_equal(length(cp), m$n.changepoints)
|
||||
expect_true(min(cp) > 0)
|
||||
expect_true(max(cp) <= history$t[ceiling(0.4 * length(history$t))])
|
||||
expect_error(prophet(history, changepoint.range = -0.1))
|
||||
expect_error(prophet(history, changepoint.range = 2))
|
||||
})
|
||||
|
||||
test_that("get_zero_changepoints", {
|
||||
|
|
|
|||
|
|
@ -352,7 +352,7 @@ class Prophet(object):
|
|||
)
|
||||
if self.n_changepoints > 0:
|
||||
cp_indexes = (
|
||||
np.linspace(0, hist_size, self.n_changepoints + 1)
|
||||
np.linspace(0, hist_size - 1, self.n_changepoints + 1)
|
||||
.round()
|
||||
.astype(np.int)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -169,6 +169,10 @@ class TestProphet(TestCase):
|
|||
self.assertTrue(cp.min() > 0)
|
||||
cp_indx = int(np.ceil(0.4 * history.shape[0]))
|
||||
self.assertTrue(cp.max() <= history['t'].values[cp_indx])
|
||||
with self.assertRaises(ValueError):
|
||||
m = Prophet(changepoint_range=-0.1)
|
||||
with self.assertRaises(ValueError):
|
||||
m = Prophet(changepoint_range=2)
|
||||
|
||||
def test_get_zero_changepoints(self):
|
||||
m = Prophet(n_changepoints=0)
|
||||
|
|
|
|||
Loading…
Reference in a new issue