【问题标题】:Find longest rolling sum under/equal a threshold for each group为每组查找低于/等于阈值的最长滚动总和
【发布时间】:2022-01-19 23:13:43
【问题描述】:

我正在尝试找到最有效的方法来识别值低于特定阈值的向量的最长滚动和。例如,如果我们有1:106 的阈值,那么3 是我们最长的滚动和。

我有下面的代码和我为此制作的功能,但显然非常慢。想知道是否有更有效或已经实现的算法来识别总和低于阈值的最长运行。

library(dplyr)
library(zoo)

set.seed(1)

df <- data.frame(
  g = rep(letters[1:10], each = 100),
  x = runif(1000)
)

longest_rollsum <- function(x, threshold) {
  for (i in 1:length(x)) {
    rs <- rollsum(x, i, na.pad = TRUE, align = "right")
    if (!any(rs <= threshold, na.rm = TRUE)) {
      return(i - 1)
    }
  }
  return(i)
}

df %>%
  group_by(g) %>%
  summarize(longest = longest_rollsum(x, 2))
#> # A tibble: 10 × 2
#>    g     longest
#>    <chr>   <dbl>
#>  1 a           7
#>  2 b           6
#>  3 c           9
#>  4 d           6
#>  5 e           6
#>  6 f           8
#>  7 g           6
#>  8 h           7
#>  9 i          11
#> 10 j           9

【问题讨论】:

  • 您确定您的预期输出正确吗?根据您的样本数据,g = "a" 的前 3 个条目的 cumsum ≤ 2。Wy 是该字母 7 而不是 3 的 longest
  • 你可以在 Rcpp 中写一个短路函数来提高速度。例如,对于 1:10,您从 k=1 开始,因为第一个小于 6,您放弃其余的并检查 k=2,因为 1+2 小于 6,您跳过其余的并检查 k= 3 和 1:3 是 >=6,然后做下一个,直到所有向量都检查无误

标签: r rolling-sum


【解决方案1】:

你可以写一个更快的 R 函数:

longest_rollsum_R <- function(x, threshold) {
  R_rollsum <- function(k){
    for(i in seq(length(x) - k))
      if(sum(x[i:(i+k)]) < threshold) return(FALSE)
    TRUE
  }
  for (i in 1:length(x)) if(R_rollsum(i)) return(i)    
}

现在把这个和上面的比较一下,

Unit: milliseconds
             expr       min        lq      mean    median        uq       max neval
      rollsum(df) 23.440017 25.407785 29.007619 27.906883 31.014434 45.516888   100
 rollsum_rcpp(df)  4.046400  4.499688  5.253406  4.734718  5.596438 14.079618   100
   rollsum_R1(df)  3.798058  4.194639  5.100568  4.710468  5.280267 14.749520   100

变化似乎不大。但是当我们改变阈值时它是相当大的:

Unit: milliseconds
                 expr        min         lq       mean    median         uq       max neval
      rollsum(df, 20) 111.336885 130.055676 142.683567 138.11347 147.231358 306.06438   100
 rollsum_rcpp(df, 20)  11.640328  13.170309  15.166030  14.03039  16.060333  31.23998   100
   rollsum_R1(df, 20)   5.993384   7.128607   8.125868   7.54488   8.140206  19.86842   100

您也可以在 Rcpp 中编写自己的代码来解决短路问题,这比目前给出的两种方法更快。

在您的工作目录中将以下内容另存为longest_rollsum.cpp

#include <Rcpp.h>
using namespace Rcpp;


// [[Rcpp::export]]

int longest_rollsum_C(NumericVector x, double threshold){
  auto rollsum_t = [&x, threshold](int k) {
    for (int i = 0; i< x.length() - k; i++) 
      if(sum(x[seq(i, i+k)]) < threshold) return false;
    return true;
   };
  for (int i = 0; i<x.length(); i++) if(rollsum_t(i)) return i;
  return 0;
}

在 R 中:获取上述文件

Rcpp::sourceCpp("longest_rollsum.cpp")
rollsum_R <- function(df){
  df %>% 
     group_by(g) %>%
    summarise(longest = longest_rollsum_C(x, 2))
}

microbenchmark::microbenchmark(rollsum(df), rollsum_rcpp(df), rollsum_R(df))
Unit: milliseconds
             expr       min        lq      mean    median        uq      max neval
      rollsum(df) 24.052665 25.018864 26.985276 25.453187 27.479305 37.49629   100
 rollsum_rcpp(df)  4.077397  4.352724  4.755942  4.572804  4.902468 13.10230   100
    rollsum_R(df)  2.271907  2.529000  2.871154  2.714801  2.955849 10.62107   100

【讨论】:

  • 啊,太好了,非常感谢!
【解决方案2】:

提高性能的一种方法是使用 RcppRoll 包中的 roll_sum() 函数,而不是 zoo 包中的 rollsum()

library(tidyverse)
library(zoo)
#> 
#> Attaching package: 'zoo'
#> The following objects are masked from 'package:base':
#> 
#>     as.Date, as.Date.numeric

set.seed(1)

df <- data.frame(
  g = rep(letters[1:10], each = 100),
  x = runif(1000)
)

longest_rollsum <- function(x, threshold) {
  for (i in 1:length(x)) {
    rs <- zoo::rollsum(x, i, na.pad = TRUE, align = "right")
    if (!any(rs <= threshold, na.rm = TRUE)) {
      return(i - 1)
    }
  }
  return(i)
}

longest_rollsum_rcpp <- function(x, threshold) {
  for (i in 1:length(x)) {
    rs <- RcppRoll::roll_sum(x, i, align = "right", fill=NA)
    if (!any(rs <= threshold, na.rm = TRUE)) {
      return(i - 1)
    }
  }
  return(i)
}

rollsum <- function(df){
  df %>%
    group_by(g) %>%
    summarize(longest = longest_rollsum(x, 2))
}

rollsum_rcpp <- function(df){
  df %>%
    group_by(g) %>%
    summarize(longest = longest_rollsum_rcpp(x, 2))
}

library(microbenchmark)
res <- microbenchmark(rollsum(df), rollsum_rcpp(df))
autoplot(res)
#> Coordinate system already present. Adding new coordinate system, which will replace the existing one.

all_equal(rollsum(df), rollsum_rcpp(df))
#> [1] TRUE

reprex package (v2.0.1) 于 2021-12-17 创建

【讨论】:

    猜你喜欢
    • 2020-12-23
    • 2018-04-03
    • 2021-11-19
    • 2018-03-14
    • 1970-01-01
    • 2023-03-23
    • 1970-01-01
    • 2011-11-01
    • 2021-05-22
    相关资源
    最近更新 更多