diff --git a/R/R/diagnostics.R b/R/R/diagnostics.R index f221aa2..709bb7c 100644 --- a/R/R/diagnostics.R +++ b/R/R/diagnostics.R @@ -216,10 +216,11 @@ prophet_copy <- function(m, cutoff = NULL) { #' #' Computes a suite of performance metrics on the output of cross-validation. #' By default the following metrics are included: -#' 'mse': mean squared error -#' 'rmse': root mean squared error -#' 'mae': mean absolute error -#' 'mape': mean percent error +#' 'mse': mean squared error, +#' 'rmse': root mean squared error, +#' 'mae': mean absolute error, +#' 'mape': mean absolute percent error, +#' 'mdape': median absolute percent error, #' 'coverage': coverage of the upper and lower intervals #' #' A subset of these can be specified by passing a list of names as the @@ -244,7 +245,7 @@ prophet_copy <- function(m, cutoff = NULL) { #' #' @param df The dataframe returned by cross_validation. #' @param metrics An array of performance metrics to compute. If not provided, -#' will use c('mse', 'rmse', 'mae', 'mape', 'coverage'). +#' will use c('mse', 'rmse', 'mae', 'mape', 'mdape', 'coverage'). #' @param rolling_window Proportion of data to use in each rolling window for #' computing the metrics. Should be in [0, 1] to average. #' @@ -275,6 +276,10 @@ performance_metrics <- function(df, metrics = NULL, rolling_window = 0.1) { message('Skipping MAPE because y close to 0') metrics <- metrics[metrics != 'mape'] } + if (('mdape' %in% metrics) & (min(abs(df_m$y)) < 1e-8)) { + message('Skipping MDAPE because y close to 0') + metrics <- metrics[metrics != 'mdape'] + } if (length(metrics) == 0) { return(NULL) } @@ -351,6 +356,64 @@ rolling_mean_by_h <- function(x, h, w, name) { return(res) } + +#' Compute a rolling median of x, after first aggregating by h +#' +#' Right-aligned. Computes a single median for each unique value of h. Each median +#' is over at least w samples. +#' +#' For each h where there are fewer than w samples, we take samples from the previous h, +#' moving backwards. 
#' Compute a rolling median of x, after first aggregating by h
#'
#' Right-aligned. Computes a single median for each unique value of h. Each
#' median is over at least w samples.
#'
#' For each h where there are fewer than w samples, we take samples from the
#' previous h, moving backwards. (In other words, we ~ assume that the x's are
#' shuffled within each h.) A horizon is dropped from the result when fewer
#' than w samples exist at or before it.
#'
#' The inputs are stably sorted by h before windowing, so callers do not have
#' to pre-sort; input that is already sorted by h is processed unchanged.
#'
#' @param x Array.
#' @param h Array of horizon for each value in x.
#' @param w Integer window size (number of elements).
#' @param name String name for metric in result dataframe.
#'
#' @return Dataframe with columns horizon and name, the rolling median of x.
#'  Has zero rows (but still both columns) when no horizon has w samples.
rolling_median_by_h <- function(x, h, w, name) {
  stopifnot(length(x) == length(h))

  # "Previous horizon" is only meaningful when samples are ordered by h.
  # order() leaves ties in their original order, so sorted input is untouched.
  ord <- order(h)
  x <- x[ord]
  h <- h[ord]

  # Position of the first and last sample of every distinct horizon. Both
  # vectors are ascending and aligned because h is sorted.
  first_idx <- which(!duplicated(h))
  last_idx <- which(!duplicated(h, fromLast = TRUE))

  # A window always contains the whole horizon; when that is fewer than w
  # samples it is extended backwards until it holds w of them. ceiling()
  # mirrors "keep adding one sample while length < w" for non-integer w.
  start_idx <- pmin(first_idx, last_idx - ceiling(w) + 1)

  # A start before the first sample means we ran out of earlier horizons.
  keep <- which(start_idx >= 1)

  medians <- vapply(
    keep,
    function(j) stats::median(x[start_idx[j]:last_idx[j]]),
    numeric(1)
  )

  res <- data.frame(horizon = h[last_idx[keep]])
  res[[name]] <- medians
  res
}
#' Median absolute percent error
#'
#' Computes the absolute percent error |(y - yhat) / y| for every row and
#' either returns it unaggregated or takes its rolling median by horizon.
#'
#' @param df Cross-validation results dataframe. The columns y, yhat and
#'  horizon are read.
#' @param w Aggregation window size. A negative value skips aggregation and
#'  returns one row per row of df.
#'
#' @return Dataframe with columns horizon and mdape.
#'
#' @keywords internal
mdape <- function(df, w) {
  abs_pct_error <- abs((df$y - df$yhat) / df$y)
  if (w >= 0) {
    # Aggregate: one rolling median per horizon.
    rolling_median_by_h(
      x = abs_pct_error, h = df$horizon, w = w, name = 'mdape'
    )
  } else {
    # No aggregation requested: hand back the raw per-row errors.
    data.frame(horizon = df$horizon, mdape = abs_pct_error)
  }
}
By default the following metrics are included: -'mse': mean squared error -'rmse': root mean squared error -'mae': mean absolute error -'mape': mean percent error +'mse': mean squared error, +'rmse': root mean squared error, +'mae': mean absolute error, +'mape': mean percent error, +'mdape': median percent error, 'coverage': coverage of the upper and lower intervals } \details{ diff --git a/R/man/rolling_median_by_h.Rd b/R/man/rolling_median_by_h.Rd new file mode 100644 index 0000000..8c50c5a --- /dev/null +++ b/R/man/rolling_median_by_h.Rd @@ -0,0 +1,27 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/diagnostics.R +\name{rolling_median_by_h} +\alias{rolling_median_by_h} +\title{Compute a rolling median of x, after first aggregating by h} +\usage{ +rolling_median_by_h(x, h, w, name) +} +\arguments{ +\item{x}{Array.} + +\item{h}{Array of horizon for each value in x.} + +\item{w}{Integer window size (number of elements).} + +\item{name}{String name for metric in result dataframe.} +} +\value{ +Dataframe with columns horizon and name, the rolling median of x. +} +\description{ +Right-aligned. Computes a single median for each unique value of h. Each median +is over at least w samples. +} +\details{ +For each h where there are fewer than w samples, we take samples from the previous h, +} diff --git a/R/tests/testthat/test_diagnostics.R b/R/tests/testthat/test_diagnostics.R index 6fc9a01..6e2f631 100644 --- a/R/tests/testthat/test_diagnostics.R +++ b/R/tests/testthat/test_diagnostics.R @@ -150,7 +150,7 @@ test_that("performance_metrics", { expect_true(all( sort(colnames(df_horizon)) == sort(c('coverage', 'mse', 'horizon')) )) - # Skip MAPE + # Skip MAPE and MDAPE df_cv$y[1] <- 0. 
# Checks rolling_median_by_h on one-sample-per-horizon data, on horizons of
# uneven size, and with a window that spans the whole input.
test_that("rolling_median", {
  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  values <- 0:9
  horizons <- 0:9

  # Window of one sample: every horizon reports its own value.
  res <- prophet:::rolling_median_by_h(
    x = values, h = horizons, w = 1, name = 'x'
  )
  expect_equal(values, res$x)
  expect_equal(horizons, res$horizon)

  # Window of four: the first three horizons cannot be filled and are dropped.
  res <- prophet:::rolling_median_by_h(
    x = values, h = horizons, w = 4, name = 'x'
  )
  expected_x <- values[4:10] - 1.5
  expect_equal(expected_x, res$x)
  expect_equal(3:9, res$horizon)

  # Horizons of uneven size: small ones borrow samples from earlier horizons.
  horizons <- c(1., 2., 3., 4., 4., 4., 4., 4., 7., 7.)
  expected_x <- c(1., 5., 8.)
  expected_h <- c(3., 4., 7.)
  res <- prophet:::rolling_median_by_h(
    x = values, h = horizons, w = 3, name = 'x'
  )
  expect_equal(expected_x, res$x)
  expect_equal(expected_h, res$horizon)

  # Window as large as the input: only the final horizon survives.
  res <- prophet:::rolling_median_by_h(
    x = values, h = horizons, w = 10, name = 'x'
  )
  expect_equal(c(7.), res$horizon)
  expect_equal(c(4.5), res$x)
})