mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-05-26 22:35:48 +00:00
Fix bugs that were introduced into R cross validation
This commit is contained in:
parent
aeaf0ad2e1
commit
4fcecdb3df
2 changed files with 6 additions and 5 deletions
|
|
@ -138,8 +138,9 @@ cross_validation <- function(
|
|||
future <- df.predict[columns]
|
||||
yhat <- stats::predict(m, future)
|
||||
# Merge yhat, y, and cutoff.
|
||||
df.c <- dplyr::inner_join(df.predict, yhat$predict_columns, by = "ds")
|
||||
df.c <- dplyr::select(df.c, y, yhat$predict_columns)
|
||||
df.c <- dplyr::inner_join(df.predict, yhat[predict_columns], by = "ds")
|
||||
df.c <- df.c[c(predict_columns, "y")]
|
||||
df.c <- dplyr::select(df.c, y, predict_columns)
|
||||
df.c$cutoff <- cutoff
|
||||
predicts <- rbind(predicts, df.c)
|
||||
}
|
||||
|
|
@ -240,7 +241,7 @@ performance_metrics <- function(df, metrics = NULL, rolling_window = 0.1) {
|
|||
if (is.null(metrics)) {
|
||||
metrics <- valid_metrics
|
||||
}
|
||||
if (!('yhat_lower' %in% df) | (!('yhat_upper' %in% df)) & ('coverage' %in% metrics)){
|
||||
if (!('yhat_lower' %in% colnames(df)) | (!('yhat_upper' %in% colnames(df))) & ('coverage' %in% metrics)){
|
||||
metrics <- valid_metrics[valid_metrics != 'coverage']
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ test_that("cross_validation", {
|
|||
m, horizon = 4, units = "days", period = 10, initial = 115)
|
||||
expect_equal(length(unique(df.cv$cutoff)), 3)
|
||||
expect_equal(max(df.cv$ds - df.cv$cutoff), horizon)
|
||||
expect_true(min(df.cv$cutoff) >= ts + initial)
|
||||
expect_true(as.Date(min(df.cv$cutoff)) >= ts + initial)
|
||||
dc <- diff(df.cv$cutoff)
|
||||
dc <- min(dc[dc > 0])
|
||||
expect_true(dc >= period)
|
||||
|
|
@ -131,7 +131,7 @@ test_that("performance_metrics", {
|
|||
expect_null(df_horizon)
|
||||
# List of metrics containing non valid metrics
|
||||
expect_error(
|
||||
performance_metrics(df, metrics = c('mse', 'error_metric')),
|
||||
performance_metrics(df_cv, metrics = c('mse', 'error_metric')),
|
||||
'Valid values for metrics are: mse, rmse, mae, mape, coverage'
|
||||
)
|
||||
})
|
||||
|
|
|
|||
Loading…
Reference in a new issue