【问题标题】:Efficiently writing code instead of using a for loop高效地编写代码而不是使用 for 循环
【发布时间】:2017-08-16 13:56:23
【问题描述】:

以下是一个玩具示例。实际上,我将模拟 6000 次 Monte Carlo 复制的数据,并为每个复制计算 St,并且在每个复制中,l 的长度会很大。如何有效地编写代码以减少运行时间。

time <- c(6,6,6,6,7,9,10,10,11,13,16,17,19,20,22,23,25,32,32,34,35)
cens <- c(1,1,1,0,1,0,1,0,0,1,1,0,0,0,1,1,0,0,0,0,0)
l <- length(time)

n <- NULL
d <- NULL
St <- NULL

n[1] <- sum(time[1]<=time)
d[1] <- sum(time==time[1] & cens==1)
St[1] <- (n[1]-d[1])/n[1]

for(i in 2:l){

   if(time[i]==time[i-1]){

      n[i]  <- n[i-1]
      d[i]  <- d[i-1]
      St[i] <- St[i-1]

 } else{

    n[i] <- sum(time[i]<=time)
    d[i] <- sum(time==time[i] & cens==1)
    St[i] <- St[i-1] * ((n[i]-d[i])/n[i])
 }

 }# end of for loop

 fit <- data.frame(ti=time, ni=n, di=d, St )

编辑:生成数据

set.seed(5)

l <- 500
time <- round(runif(l,3,38))
cens <- round(runif(l,0,1))

n <- NULL
d <- NULL
St <- NULL

n[1] <- sum(time[1]<=time)
d[1] <- sum(time==time[1] & cens==1)
St[1] <- (n[1]-d[1])/n[1]

for(i in 2:l){

   if(time[i]==time[i-1]){

      n[i]  <- n[i-1]
      d[i]  <- d[i-1]
      St[i] <- St[i-1]

 } else{

    n[i] <- sum(time[i]<=time)
    d[i] <- sum(time==time[i] & cens==1)
    St[i] <- St[i-1] * ((n[i]-d[i])/n[i])
 }

 }# end of for loop

 fit <- data.frame(ti=time, ni=n, di=d, St )

【问题讨论】:

  • 一方面预先分配 n、d 和 St、n &lt;- numeric(l) 等等。这样,您不会在每个循环中复制这些向量以填充添加下一个条目,内存已经分配。
  • @lmo 我相信 JIT 字节码编译器现在主要处理这个问题。
  • @Roland 真的吗?这太棒了。是否知道任何描述此行为的文档?
  • @lmo 我不确定我是否在某个地方读到过这篇文章(可能是新闻),但我认为this talk 中提到了它,当我最近进行基准测试时,增长并不比预分配慢。跨度>
  • @Roland 感谢您的链接。有机会会去看看。

标签: r


【解决方案1】:

您应该避免循环并尽可能在编译(矢量化)代码中执行。以下应该是相当快的,因为它是矢量化的:

library(data.table)
library(zoo)

DT <- data.table(time, cens)

#sum cens by time, this is why I use data.table but you could also use dplyr
DT[, d := sum(cens == 1L), by = time]

#calculate n and St
DT[, c("n", "St") := {
 #find time changes
 dn <- c(TRUE, diff(time) > 0)
 #calculate remaining length for time changing points
 nt <- length(time) - which(dn) + 1
 #vector of NA values
 n <- rep(NA, length(time))
 #fill in nt values
 n[dn] <- nt
 #vector of NA values
 St <- rep(NA, length(time))
 #fill in St values for time change points
 St[dn] <- cumprod(((n - d) / n)[dn])
 #last observation carried forward
 list(na.locf(n), na.locf(St))
}]
#     time cens d  n        St
#  1:    6    1 3 21 0.8571429
#  2:    6    1 3 21 0.8571429
#  3:    6    1 3 21 0.8571429
#  4:    6    0 3 21 0.8571429
#  5:    7    1 1 17 0.8067227
#  6:    9    0 0 16 0.8067227
#  7:   10    1 1 15 0.7529412
#  8:   10    0 1 15 0.7529412
#  9:   11    0 0 13 0.7529412
# 10:   13    1 1 12 0.6901961
# 11:   16    1 1 11 0.6274510
# 12:   17    0 0 10 0.6274510
# 13:   19    0 0  9 0.6274510
# 14:   20    0 0  8 0.6274510
# 15:   22    1 1  7 0.5378151
# 16:   23    1 1  6 0.4481793
# 17:   25    0 0  5 0.4481793
# 18:   32    0 0  4 0.4481793
# 19:   32    0 0  4 0.4481793
# 20:   34    0 0  2 0.4481793
# 21:   35    0 0  1 0.4481793
#     time cens d  n        St

【讨论】:

  • DT[, d := sum(cens == 1L), by = time],为什么它是1L而不是1
  • 为了清楚地表明我们正在比较整数值。永远不要与浮点数进行精确比较。
【解决方案2】:
myF <- function(x) {
  require(data.table)

  d <- data.table(time, cens)

  tt <- function(x, y)sum(x <= y)
  tt <- Vectorize(tt, vectorize.args = "x")
  d[, nnew := tt(time, time)]

  tt2 <- function(x, y, cens)sum(x == y & cens == 1)
  tt2 <- Vectorize(tt2, vectorize.args = "x")
  d[, dnew := tt2(time, time, cens)]

  d[, multi := (nnew - dnew) / nnew]
  d[duplicated(time), multi := 1]
  d[, St := cumprod(multi)]
  d[, multi := NULL][, cens := NULL]
  setnames(d, "time", "ti")
  setnames(d, "nnew", "ni")
  setnames(d, "dnew", "di")
  d[]
}
> myF()
    ti ni di        St
 1:  6 21  3 0.8571429
 2:  6 21  3 0.8571429
 3:  6 21  3 0.8571429
 4:  6 21  3 0.8571429
 5:  7 17  1 0.8067227
 6:  9 16  0 0.8067227
 7: 10 15  1 0.7529412
 8: 10 15  1 0.7529412
 9: 11 13  0 0.7529412
10: 13 12  1 0.6901961
11: 16 11  1 0.6274510
12: 17 10  0 0.6274510
13: 19  9  0 0.6274510
14: 20  8  0 0.6274510
15: 22  7  1 0.5378151
16: 23  6  1 0.4481793
17: 25  5  0 0.4481793
18: 32  4  0 0.4481793
19: 32  4  0 0.4481793
20: 34  2  0 0.4481793
21: 35  1  0 0.4481793
    ti ni di        St
> all.equal(fit, as.data.frame(myF()))
[1] TRUE

【讨论】:

    猜你喜欢
    • 2013-02-16
    • 2012-04-26
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2012-04-12
    • 2016-07-05
    相关资源
    最近更新 更多