Add mdape performance metric to R (#1472)

* add test and initial function for mdape in R

* Add separate rolling_median_func and tests

* Modify rolling median function

* fix syntax in rolling median function

* sort by h

* R/diagnostics.R

* update .rd docs and notebook

* Add mdape to performance metrics params docstring
This commit is contained in:
Ryan Nazareth 2020-05-20 20:28:50 +01:00 committed by GitHub
parent 16e632a695
commit f16d9df333
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 211 additions and 34 deletions

View file

@ -216,10 +216,11 @@ prophet_copy <- function(m, cutoff = NULL) {
#'
#' Computes a suite of performance metrics on the output of cross-validation.
#' By default the following metrics are included:
#' 'mse': mean squared error
#' 'rmse': root mean squared error
#' 'mae': mean absolute error
#' 'mape': mean percent error
#' 'mse': mean squared error,
#' 'rmse': root mean squared error,
#' 'mae': mean absolute error,
#' 'mape': mean percent error,
#' 'mdape': median percent error,
#' 'coverage': coverage of the upper and lower intervals
#'
#' A subset of these can be specified by passing a list of names as the
@ -244,7 +245,7 @@ prophet_copy <- function(m, cutoff = NULL) {
#'
#' @param df The dataframe returned by cross_validation.
#' @param metrics An array of performance metrics to compute. If not provided,
#' will use c('mse', 'rmse', 'mae', 'mape', 'coverage').
#' will use c('mse', 'rmse', 'mae', 'mape', 'mdape', 'coverage').
#' @param rolling_window Proportion of data to use in each rolling window for
#' computing the metrics. Should be in [0, 1] to average.
#'
@ -275,6 +276,10 @@ performance_metrics <- function(df, metrics = NULL, rolling_window = 0.1) {
message('Skipping MAPE because y close to 0')
metrics <- metrics[metrics != 'mape']
}
if (('mdape' %in% metrics) & (min(abs(df_m$y)) < 1e-8)) {
message('Skipping MDAPE because y close to 0')
metrics <- metrics[metrics != 'mdape']
}
if (length(metrics) == 0) {
return(NULL)
}
@ -351,6 +356,64 @@ rolling_mean_by_h <- function(x, h, w, name) {
return(res)
}
#' Compute a rolling median of x, after first aggregating by h
#'
#' Right-aligned. Computes a single median for each unique value of h. Each median
#' is over at least w samples.
#'
#' For each h where there are fewer than w samples, we take samples from the previous h,
# moving backwards. (In other words, we ~ assume that the x's are shuffled within each h.)
#'
#' @param x Array.
#' @param h Array of horizon for each value in x.
#' @param w Integer window size (number of elements).
#' @param name String name for metric in result dataframe.
#'
#' @return Dataframe with columns horizon and name, the rolling median of x.
#'
#' @importFrom dplyr "%>%"
rolling_median_by_h <- function(x, h, w, name) {
# Aggregate over h
df <- data.frame(x=x, h=h)
grouped <- df %>% dplyr::group_by(h)
df2 <- grouped %>%
dplyr::summarise(size=dplyr::n()) %>%
dplyr::arrange(h) %>%
dplyr::select(h, size)
hs <- df2$h
res <- data.frame(horizon=c())
res[[name]] <- c()
# Start from the right and work backwards
i <- length(hs)
while (i > 0) {
h_i <- hs[i]
xs <- grouped %>%
dplyr::filter(h==h_i)
xs <- xs$x
next_idx_to_add = which.max(h==h_i) - 1
while ((length(xs) < w) & (next_idx_to_add > 0)) {
# Include points from the previous horizon. All of them if still less
# than w, otherwise just enough to get to w.
xs <- c(x[next_idx_to_add], xs)
next_idx_to_add = next_idx_to_add - 1
}
if (length(xs) < w) {
# Ran out of horizons before enough points.
break
}
res.i <- data.frame(horizon=hs[i])
res.i[[name]] <- median(xs)
res <- rbind(res.i, res)
i <- i - 1
}
return(res)
}
# The functions below specify performance metrics for cross-validation results.
# Each takes as input the output of cross_validation, and returns the statistic
# as a dataframe, given a window size for rolling aggregation.
@ -418,6 +481,24 @@ mape <- function(df, w) {
return(rolling_mean_by_h(x = ape, h = df$horizon, w = w, name = 'mape'))
}
#' Median absolute percent error
#'
#' @param df Cross-validation results dataframe.
#' @param w Aggregation window size.
#'
#' @return Array of median absolute percent errors.
#'
#' @keywords internal
mdape <- function(df, w) {
ape <- abs((df$y - df$yhat) / df$y)
if (w < 0) {
return(data.frame(horizon = df$horizon, mdape = ape))
}
return(rolling_median_by_h(x = ape, h = df$horizon, w = w, name = 'mdape'))
}
#' Coverage
#'
#' @param df Cross-validation results dataframe.

20
R/man/mdape.Rd Normal file
View file

@ -0,0 +1,20 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/diagnostics.R
\name{mdape}
\alias{mdape}
\title{Median absolute percent error}
\usage{
mdape(df, w)
}
\arguments{
\item{df}{Cross-validation results dataframe.}
\item{w}{Aggregation window size.}
}
\value{
Array of median absolute percent errors.
}
\description{
Median absolute percent error
}
\keyword{internal}

View file

@ -10,7 +10,7 @@ performance_metrics(df, metrics = NULL, rolling_window = 0.1)
\item{df}{The dataframe returned by cross_validation.}
\item{metrics}{An array of performance metrics to compute. If not provided,
will use c('mse', 'rmse', 'mae', 'mape', 'coverage').}
will use c('mse', 'rmse', 'mae', 'mape', 'mdape', 'coverage').}
\item{rolling_window}{Proportion of data to use in each rolling window for
computing the metrics. Should be in [0, 1] to average.}
@ -21,10 +21,11 @@ A dataframe with a column for each metric, and column 'horizon'.
\description{
Computes a suite of performance metrics on the output of cross-validation.
By default the following metrics are included:
'mse': mean squared error
'rmse': root mean squared error
'mae': mean absolute error
'mape': mean percent error
'mse': mean squared error,
'rmse': root mean squared error,
'mae': mean absolute error,
'mape': mean percent error,
'mdape': median percent error,
'coverage': coverage of the upper and lower intervals
}
\details{

View file

@ -0,0 +1,27 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/diagnostics.R
\name{rolling_median_by_h}
\alias{rolling_median_by_h}
\title{Compute a rolling median of x, after first aggregating by h}
\usage{
rolling_median_by_h(x, h, w, name)
}
\arguments{
\item{x}{Array.}
\item{h}{Array of horizon for each value in x.}
\item{w}{Integer window size (number of elements).}
\item{name}{String name for metric in result dataframe.}
}
\value{
Dataframe with columns horizon and name, the rolling median of x.
}
\description{
Right-aligned. Computes a single median for each unique value of h. Each median
is over at least w samples.
}
\details{
For each h where there are fewer than w samples, we take samples from the previous h,
}

View file

@ -150,7 +150,7 @@ test_that("performance_metrics", {
expect_true(all(
sort(colnames(df_horizon)) == sort(c('coverage', 'mse', 'horizon'))
))
# Skip MAPE
# Skip MAPE and MDAPE
df_cv$y[1] <- 0.
df_horizon <- performance_metrics(df_cv, metrics = c('coverage', 'mape'))
expect_true(all(
@ -189,6 +189,33 @@ test_that("rolling_mean", {
expect_equal(c(4.5), df$x)
})
test_that("rolling_median", {
skip_if_not(Sys.getenv('R_ARCH') != '/i386')
x <- 0:9
h <- 0:9
df <- prophet:::rolling_median_by_h(x=x, h=h, w=1, name='x')
expect_equal(x, df$x)
expect_equal(h, df$horizon)
df <- prophet:::rolling_median_by_h(x=x, h=h, w=4, name='x')
x.true <- x[4:10] - 1.5
expect_equal(x.true, df$x)
expect_equal(3:9, df$horizon)
h <- c(1., 2., 3., 4., 4., 4., 4., 4., 7., 7.)
x.true <- c(1., 5., 8.)
h.true <- c(3., 4., 7.)
df <- prophet:::rolling_median_by_h(x=x, h=h, w=3, name='x')
expect_equal(x.true, df$x)
expect_equal(h.true, df$horizon)
df <- prophet:::rolling_median_by_h(x=x, h=h, w=10, name='x')
expect_equal(c(7.), df$horizon)
expect_equal(c(4.5), df$x)
})
test_that("copy", {
skip_if_not(Sys.getenv('R_ARCH') != '/i386')
df <- DATA_all

View file

@ -166,9 +166,30 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4a9ebee9abb44b3a97eb4df74beb6346",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"text/html": [
@ -202,45 +223,45 @@
" <tr>\n",
" <th>0</th>\n",
" <td>2010-02-16</td>\n",
" <td>8.956572</td>\n",
" <td>8.460049</td>\n",
" <td>9.460400</td>\n",
" <td>8.957284</td>\n",
" <td>8.480761</td>\n",
" <td>9.415366</td>\n",
" <td>8.242493</td>\n",
" <td>2010-02-15</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2010-02-17</td>\n",
" <td>8.723004</td>\n",
" <td>8.200557</td>\n",
" <td>9.236561</td>\n",
" <td>8.723736</td>\n",
" <td>8.206191</td>\n",
" <td>9.234075</td>\n",
" <td>8.008033</td>\n",
" <td>2010-02-15</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2010-02-18</td>\n",
" <td>8.606823</td>\n",
" <td>8.070835</td>\n",
" <td>9.123754</td>\n",
" <td>8.607496</td>\n",
" <td>8.112153</td>\n",
" <td>9.092314</td>\n",
" <td>8.045268</td>\n",
" <td>2010-02-15</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>2010-02-19</td>\n",
" <td>8.528688</td>\n",
" <td>8.034782</td>\n",
" <td>9.042712</td>\n",
" <td>8.529364</td>\n",
" <td>8.017767</td>\n",
" <td>9.013877</td>\n",
" <td>7.928766</td>\n",
" <td>2010-02-15</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>2010-02-20</td>\n",
" <td>8.270706</td>\n",
" <td>7.754891</td>\n",
" <td>8.739012</td>\n",
" <td>8.271329</td>\n",
" <td>7.751250</td>\n",
" <td>8.775341</td>\n",
" <td>7.745003</td>\n",
" <td>2010-02-15</td>\n",
" </tr>\n",
@ -250,14 +271,14 @@
],
"text/plain": [
" ds yhat yhat_lower yhat_upper y cutoff\n",
"0 2010-02-16 8.956572 8.460049 9.460400 8.242493 2010-02-15\n",
"1 2010-02-17 8.723004 8.200557 9.236561 8.008033 2010-02-15\n",
"2 2010-02-18 8.606823 8.070835 9.123754 8.045268 2010-02-15\n",
"3 2010-02-19 8.528688 8.034782 9.042712 7.928766 2010-02-15\n",
"4 2010-02-20 8.270706 7.754891 8.739012 7.745003 2010-02-15"
"0 2010-02-16 8.957284 8.480761 9.415366 8.242493 2010-02-15\n",
"1 2010-02-17 8.723736 8.206191 9.234075 8.008033 2010-02-15\n",
"2 2010-02-18 8.607496 8.112153 9.092314 8.045268 2010-02-15\n",
"3 2010-02-19 8.529364 8.017767 9.013877 7.928766 2010-02-15\n",
"4 2010-02-20 8.271329 7.751250 8.775341 7.745003 2010-02-15"
]
},
"execution_count": 5,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
@ -313,7 +334,7 @@
" parallel=\"dask\")\n",
"```\n",
"\n",
"The `performance_metrics` utility can be used to compute some useful statistics of the prediction performance (`yhat`, `yhat_lower`, and `yhat_upper` compared to `y`), as a function of the distance from the cutoff (how far into the future the prediction was). The statistics computed are mean squared error (MSE), root mean squared error (RMSE), mean absolute error (MAE), mean absolute percent error (MAPE), and coverage of the `yhat_lower` and `yhat_upper` estimates. These are computed on a rolling window of the predictions in `df_cv` after sorting by horizon (`ds` minus `cutoff`). By default 10% of the predictions will be included in each window, but this can be changed with the `rolling_window` argument."
"The `performance_metrics` utility can be used to compute some useful statistics of the prediction performance (`yhat`, `yhat_lower`, and `yhat_upper` compared to `y`), as a function of the distance from the cutoff (how far into the future the prediction was). The statistics computed are mean squared error (MSE), root mean squared error (RMSE), mean absolute error (MAE), mean absolute percent error (MAPE), median absolute percent error (MDAPE) and coverage of the `yhat_lower` and `yhat_upper` estimates. These are computed on a rolling window of the predictions in `df_cv` after sorting by horizon (`ds` minus `cutoff`). By default 10% of the predictions will be included in each window, but this can be changed with the `rolling_window` argument."
]
},
{