【问题标题】:Vectorising a for loop containing a which statement and a function向量化包含 which 语句和函数的 for 循环
【发布时间】:2015-05-06 02:27:48
【问题描述】:

我正在尝试矢量化的代码的可重现示例。

cutOffs <- seq(1,10,0.2)

plotOutput <- matrix(nrow=length(cutOffs), ncol=2)
colnames(plotOutput) <- c("x","y")
plotOutput[,"y"] <- cutOffs

for(plotPoint in 1:length(cutOffs))
{
  plotOutput[plotPoint, "x"] <-
    nrow(iris[ which(iris$Sepal.Length > cutOffs[plotPoint] &
                   iris$Sepal.Width > cutOffs[plotPoint]), ])
}

plotOutput

我特别想知道的是,是否有办法对这部分进行矢量化。

nrow(iris[ which(iris$Sepal.Length > cutOffs[plotPoint] &
                   iris$Sepal.Width > cutOffs[plotPoint]), ])

假设我要使用 plyr 库或某种形式的应用,可能没有太多的加速,这正是我正在寻找的。从根本上说,我想看看是否有一些我在搜索时忽略或设法错过的矢量化技术。

更新:

Unit: milliseconds
  expr         min          lq        mean      median          uq         max neval
  op() 33663.39700 33663.39700 33663.39700 33663.39700 33663.39700 33663.39700     1
  jr()  3976.53088  3976.53088  3976.53088  3976.53088  3976.53088  3976.53088     1
  dd()  4253.21050  4253.21050  4253.21050  4253.21050  4253.21050  4253.21050     1
 exp()  5085.45331  5085.45331  5085.45331  5085.45331  5085.45331  5085.45331     1
 nic()  8719.82043  8719.82043  8719.82043  8719.82043  8719.82043  8719.82043     1
  sg()    16.66177    16.66177    16.66177    16.66177    16.66177    16.66177     1

我实际在做的更现实的近似是这样的

# generate data
numObs <- 1e5
iris <- data.frame( Sepal.Length = sample(1:numObs), Sepal.Width = sample(1:numObs) )

cutOffs <- 1:(numObs*0.01)

plotOutput <- matrix(nrow=length(cutOffs), ncol=2)
colnames(plotOutput) <- c("x","y")
plotOutput[,"y"] <- cutOffs

按照人们喜欢的任何特定方法进行。

一般来说,它会用于 50,000 - 200,000 点的数据集。

与使用相比有了很大的飞跃

sum(Sepal.Length > cutOffs[plotPoint] & Sepal.Width > cutOffs[plotPoint])

这是我最初缺少的一种更优化的方法。

然而,到目前为止,最好的答案是 sgibb 的 sg()。关键是要意识到它只是重要的每一行中两个值中的最低值。一旦实现了精神上的飞跃,就只剩下一个向量需要处理,并且向量化相当简单。

# cutOff should be lower than the lowest of Sepal.Length & Sepal.Width
  m <- pmin(iris$Sepal.Length, iris$Sepal.Width)

【问题讨论】:

    标签: r vectorization


    【解决方案1】:

    我想添加另一个答案:

    sg <- function() {
      # cutOff should be lower than the lowest of Sepal.Length & Sepal.Width
      m <- pmin(iris$Sepal.Length, iris$Sepal.Width)
      ms <- sort.int(m)
      # use `findInterval` to find all the indices 
      # (equal to "how many numbers below") lower than the threshold
      plotOutput[,"x"] <- length(ms)-findInterval(cutOffs, ms)
      plotOutput
    }
    

    这种方法避免了 forouter 循环,并且比 @nicola 的方法快 4 倍:

    microbenchmark(sg(), nic(), dd())
    #Unit: microseconds
    #  expr     min       lq     mean   median       uq      max neval
    #  sg()  88.726 104.5805 127.3172 123.2895 144.2690  232.441   100
    # nic() 474.315 526.7780 625.0021 602.3685 706.7530  997.412   100
    #  dd() 669.841 736.7800 887.4873 847.7730 976.6445 2800.930   100
    
    identical(sg(), dd())
    # [1] TRUE
    

    【讨论】:

    • 确实不错findInterval (+1)。这也是我的出发点,但我惨遭失败,最终得到了一个更加复杂的cut 代码。
    【解决方案2】:

    你可以使用outer:

    plotOutput[,"x"]<-colSums(outer(1:nrow(iris),1:length(cutOffs),function(x,y) iris$Sepal.Length[x] > cutOffs[y] & iris$Sepal.Width[x] > cutOffs[y]))
    

    【讨论】:

      【解决方案3】:

      这不会删除 for 循环,但我认为它会给您一些加速 - 随意进行基准测试,让我们知道它如何与您的真实数据进行比较:

      for(i in seq_along(cutOffs)) {
        x <- cutOffs[i]
        plotOutput[i, "x"] <- with(iris, sum(Sepal.Length > x & Sepal.Width > x))
      }
      

      这是使用样本数据的小基准(可以说很小,但可能会给出一些指示):

      library(microbenchmark)
      microbenchmark(op(), jr(), dd(), exp(), nic())
      Unit: microseconds
        expr      min        lq    median        uq       max neval
        op() 6745.428 7079.8185 7378.9330 9188.0175 11936.173   100
        jr() 1335.931 1405.2030 1466.9180 1728.6595  4692.748   100
        dd()  684.786  711.6005  758.7395  923.6670  4473.725   100
       exp() 1928.083 2066.0395 2165.6985 2392.7030  5392.475   100
       nic()  383.007  402.5495  439.3835  541.6395   851.488   100
      

      基准测试中使用的函数定义如下:

      op <- function(){
        for(plotPoint in 1:length(cutOffs))
        {
          plotOutput[plotPoint, "x"] <-
            nrow(iris[ which(iris$Sepal.Length > cutOffs[plotPoint] &
                               iris$Sepal.Width > cutOffs[plotPoint]), ])
        }
        plotOutput
      }
      
      jr <- function() {
        cbind(x = sapply(cutOffs, counts), y = plotOutput[,"y"])
      }
      
      dd <- function() {
        for(i in seq_along(cutOffs)) {
          x <- cutOffs[i]
          plotOutput[i, "x"] <- with(iris, sum(Sepal.Length > x & Sepal.Width > x))
        }
        plotOutput
      }
      
      exp <- function() {
        data_frame(y=cutOffs) %>% 
          rowwise() %>% 
          mutate(x = sum(iris$Sepal.Length > y & iris$Sepal.Width > y))
      }
      
      nic <- function() {
        plotOutput[,"x"]<-colSums(outer(1:nrow(iris),1:length(cutOffs),function(x,y) iris$Sepal.Length[x] > cutOffs[y] & iris$Sepal.Width[x] > cutOffs[y]))
      }
      

      编辑说明:@nicola 包含的方法现在是最快的

      【讨论】:

      • 虽然我喜欢@nicola 的智能解决方案,但我更喜欢dd,因为outer 很长时间会占用大量内存cutOffs
      【解决方案4】:

      您可以使用dplyr

      library(dplyr)
      data_frame(y=cutOffs) %>% 
          rowwise() %>% 
          mutate(x = sum(iris$Sepal.Length > y & iris$Sepal.Width > y))
      

      【讨论】:

        【解决方案5】:

        我猜是这样的:

        counts <- function(x) sum(iris$Sepal.Length > x & iris$Sepal.Width > x ) 
        cbind(x = sapply(cutOffs, counts), y = plotOutput[,"y"])
        

        只是为了检查:

        res <- cbind(x=sapply(cutOffs,counts), y=plotOutput[,"y"])
        identical(plotOutput,res)
        [1] TRUE
        

        【讨论】:

          【解决方案6】:

          基于pmincuttable的另一种可能性

          brk <- c(cutOffs, Inf)
          rev(cumsum(rev(table(cut(pmin(iris$Sepal.Length, iris$Sepal.Width), brk)))))
          

          如果您想“从内到外”处理代码,可能会更容易使用一个较小的示例:

          set.seed(1)
          df <- data.frame(x = sample(1:10, 6), y = sample(1:10, 6))
          cutOffs <- seq(from = 2, to = 8, by = 2)
          brk <- c(cutOffs, Inf)
          
          rev(cumsum(rev(table(cut(pmin(df$x, df$y), brk)))))
          #  (2,4]   (4,6]   (6,8] (8,Inf] 
          #      4       2       1       0 
          

          即,两个值 > 2 的四行,两个值 > 4 的两行,等等

          【讨论】:

            猜你喜欢
            • 1970-01-01
            • 1970-01-01
            • 1970-01-01
            • 1970-01-01
            • 1970-01-01
            • 2014-08-30
            • 1970-01-01
            • 1970-01-01
            • 2013-12-27
            相关资源
            最近更新 更多