【问题标题】:R - Faster sorting of matrix product(R - 矩阵乘积的更快排序)
【发布时间】:2017-06-20 22:49:19
【问题描述】:

有 2 个矩阵 A 和 B。A 的大小为 2M*50(2e6 行、50 列),B 的大小为 20k*50(2e4 行、50 列)。对于 A %*% t(B) 的每一行,我想找出得分最高的 10 个元素的列索引(按得分降序排列)。我想知道是否有比下面这种实现更快的方法:

library(parallel)
library(pbapply)

set.seed(1)

A <- matrix(runif(2e6 * 50), nrow = 2e6)
B <- matrix(runif(2e4 * 50), nrow = 2e4)
n <- 10

# Return the indices of the `n` largest entries of `score`, ordered by
# decreasing score. When several entries tie with the n-th largest value,
# only as many of them as are needed to reach `n` are kept, drawn at random.
top_n_index <- function(score, n) {
  # n-th largest value via a partial sort (cheaper than a full sort).
  nth_score <- -sort(-score, partial = n)[n]
  above <- which(score > nth_score)
  at_nth <- which(score == nth_score)
  if (length(at_nth) != 1) {
    at_nth <- sample(at_nth, n - length(above))
  }
  top <- c(above, at_nth)
  top[order(score[top], decreasing = TRUE)]
}

# Score A against B in blocks of rows: one tcrossprod() per block is a single
# matrix-matrix product, instead of one vector-matrix product plus a freshly
# built t(B) for every row of A. Each block's score matrix holds
# block_size * nrow(B) doubles (~160 MB here), so keep block_size moderate.
block_size <- 1000
row_ids <- seq_len(nrow(A))
row_blocks <- split(row_ids, ceiling(row_ids / block_size))

cl <- makeCluster(detectCores())
# set.seed() above only seeds this (master) session; seed the workers as well
# so the random tie-breaking inside top_n_index() is reproducible.
clusterSetRNGStream(cl, 1)
# NOTE: this copies A (2e6 * 50 doubles, ~800 MB) into every worker.
clusterExport(cl, c("A", "B", "n", "top_n_index"))

top_per_block <- pblapply(row_blocks, function(rows) {
  scores <- tcrossprod(A[rows, , drop = FALSE], B)
  vapply(
    seq_len(nrow(scores)),
    function(i) top_n_index(scores[i, ], n),
    integer(n)
  )
}, cl = cl)

stopCluster(cl)

# n x nrow(A) matrix: column j holds the top-n indices into B for row j of A.
Z <- matrix(unlist(top_per_block, use.names = FALSE), nrow = n)

【问题讨论】:

    标签: r sorting matrix


    【解决方案1】:

    快速改进:将 A %*% t(B) 替换为 tcrossprod(A, B)。tcrossprod(A, B) 计算的正是 A %*% t(B),但无需在每次调用时显式构造转置矩阵 t(B);在下方的基准测试中,单行的中位耗时从约 9.7 ms 降至约 4.0 ms。

    # Number of top-scoring rows of B to keep for each row of A.
    n <- 10

    # Baseline version: score row `x` of A against every row of B with an
    # explicit transpose, then return the indices of the `n` best scores in
    # decreasing order (ties at the n-th score are resolved at random).
    func1 <- function(x) {
      s <- A[x, ] %*% t(B)

      # Value of the n-th largest score, located with a partial sort.
      cutoff <- -sort(-s, partial = n)[n]

      winners <- which(s > cutoff)
      tied <- which(s == cutoff)
      # Several entries share the cutoff value: keep just enough of them
      # (chosen at random) to bring the total up to n.
      if (length(tied) != 1) {
        tied <- sample(tied, n - length(winners))
      }

      picked <- c(winners, tied)
      picked[sort(s[picked], decreasing = TRUE, index.return = TRUE)$ix]
    }
    
    # Same contract as func1, but tcrossprod(a, B) computes a %*% t(B)
    # without materialising the transpose of B on every call.
    func2 <- function(x) {
      s <- tcrossprod(A[x, ], B)

      # Value of the n-th largest score, located with a partial sort.
      cutoff <- -sort(-s, partial = n)[n]

      winners <- which(s > cutoff)
      tied <- which(s == cutoff)
      # Several entries share the cutoff value: keep just enough of them
      # (chosen at random) to bring the total up to n.
      if (length(tied) != 1) {
        tied <- sample(tied, n - length(winners))
      }

      picked <- c(winners, tied)
      picked[sort(s[picked], decreasing = TRUE, index.return = TRUE)$ix]
    }
    
    # Both implementations must pick the same top-n indices for row 1.
    all.equal(func1(1), func2(1))
    # TRUE

    # The microbenchmark package is not attached anywhere in this snippet,
    # so call it through its namespace. Timings are per single row of A.
    microbenchmark::microbenchmark(func1(1), func2(1))
    # Unit: milliseconds
    # expr      min       lq     mean   median        uq       max neval
    # func1(1) 6.527077 9.254476 9.757431 9.726585 10.311310 11.932170   100
    # func2(1) 3.365654 3.721711 4.036532 3.998387  4.246175  5.405226   100
    

    【讨论】:

      猜你喜欢
      • 2022-01-12
      • 2018-03-15
      • 2016-06-17
      • 2020-04-29
      • 1970-01-01
      • 1970-01-01
      • 2015-07-24
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多