From e665430adcd7690a1ea7565803f34043596045fe Mon Sep 17 00:00:00 2001 From: RaymondMcT Date: Tue, 24 May 2022 08:28:58 -0400 Subject: [PATCH] Improved execution time of rolling_mean_by_h (#2142) --- R/R/diagnostics.R | 59 ++++++++++++++++++++--------------- python/prophet/diagnostics.py | 50 ++++++++++++++--------------- 2 files changed, 58 insertions(+), 51 deletions(-) diff --git a/R/R/diagnostics.R b/R/R/diagnostics.R index 1e9aae4..ff44594 100644 --- a/R/R/diagnostics.R +++ b/R/R/diagnostics.R @@ -347,38 +347,45 @@ rolling_mean_by_h <- function(x, h, w, name) { df <- data.frame(x=x, h=h) df2 <- df %>% dplyr::group_by(h) %>% - dplyr::summarise(mean = mean(x), n = dplyr::n()) + dplyr::summarise(sum = sum(x), n = dplyr::n()) - xm <- df2$mean + xs <- df2$sum ns <- df2$n hs <- df2$h - - res <- data.frame(horizon=c()) - res[[name]] <- c() + + trailing_i <- length(hs) + x_sum <- 0 + n_sum <- 0 + # We don't know output size but it is bounded by length(hs) + res_x <- vector("double", length=length(hs)) + # Start from the right and work backwards - i <- length(hs) - while (i > 0) { - # Construct a mean of at least w samples - n <- ns[i] - xbar <- xm[i] - j <- i - 1 - while ((n < w) & (j > 0)) { - # Include points from the previous horizon. All of them if still less - # than w, otherwise just enough to get to w. - n2 <- min(w - n, ns[j]) - xbar <- xbar * (n / (n + n2)) + xm[j] * (n2 / (n + n2)) - n <- n + n2 - j <- j - 1 + for(i in length(hs):1) { + x_sum <- x_sum + xs[i] + n_sum <- n_sum + ns[i] + while (n_sum >= w) { + # Include points from the previous horizon. All of them if still + # less than w, otherwise weight the mean by the difference + excess_n <- n_sum - w + excess_x <- excess_n * xs[i]/ ns[i] + res_x[trailing_i] <- (x_sum - excess_x) / w + x_sum <- x_sum - xs[trailing_i] + n_sum <- n_sum - ns[trailing_i] + trailing_i <- trailing_i - 1 } - if (n < w) { - # Ran out of horizons before enough points. - break - } - res.i <- data.frame(horizon=hs[i]) - res.i[[name]] <- xbar - res <- rbind(res.i, res) - i <- i - 1 } + + # R handles subsetting weirdly + if(trailing_i == 0) { + res_h <- hs + } else { + res_h <- hs[-(1:trailing_i)] + res_x <- res_x[-(1:trailing_i)] + } + + res <- data.frame(horizon=res_h) + res[[name]] <- res_x + return(res) } diff --git a/python/prophet/diagnostics.py b/python/prophet/diagnostics.py index c25be97..e6445a3 100644 --- a/python/prophet/diagnostics.py +++ b/python/prophet/diagnostics.py @@ -420,37 +420,37 @@ def rolling_mean_by_h(x, h, w, name): # Aggregate over h df = pd.DataFrame({'x': x, 'h': h}) df2 = ( - df.groupby('h').agg(['mean', 'count']).reset_index().sort_values('h') + df.groupby('h').agg(['sum', 'count']).reset_index().sort_values('h') ) - xm = df2['x']['mean'].values + xs = df2['x']['sum'].values ns = df2['x']['count'].values - hs = df2['h'].values + hs = df2.h.values + + trailing_i = len(df2) - 1 + x_sum = 0 + n_sum = 0 + # We don't know output size but it is bounded by len(df2) + res_x = np.empty(len(df2)) - res_h = [] - res_x = [] # Start from the right and work backwards - i = len(hs) - 1 - while i >= 0: - # Construct a mean of at least w samples. - n = int(ns[i]) - xbar = float(xm[i]) - j = i - 1 - while ((n < w) and j >= 0): + for i in range(len(df2) - 1, -1, -1): + x_sum += xs[i] + n_sum += ns[i] + while n_sum >= w: # Include points from the previous horizon. All of them if still - # less than w, otherwise just enough to get to w. - n2 = min(w - n, ns[j]) - xbar = xbar * (n / (n + n2)) + xm[j] * (n2 / (n + n2)) - n += n2 - j -= 1 - if n < w: - # Ran out of horizons before enough points. - break - res_h.append(hs[i]) - res_x.append(xbar) - i -= 1 - res_h.reverse() - res_x.reverse() + # less than w, otherwise weight the mean by the difference + excess_n = n_sum - w + excess_x = excess_n * xs[i] / ns[i] + res_x[trailing_i] = (x_sum - excess_x)/ w + x_sum -= xs[trailing_i] + n_sum -= ns[trailing_i] + trailing_i -= 1 + + res_h = hs[(trailing_i + 1):] + res_x = res_x[(trailing_i + 1):] + return pd.DataFrame({'horizon': res_h, name: res_x}) + def rolling_median_by_h(x, h, w, name):