【问题标题】:Rcpp function to find the median, given a vector of values and their frequencies给定值向量及其频率的 Rcpp 函数查找中值
【发布时间】:2016-03-20 05:09:07
【问题描述】:

我正在编写一个函数来查找一组值的中位数。数据显示为唯一值的向量(称为“值”)和频率向量(“频率”)。通常频率非常高,因此将它们粘贴出来会占用大量内存。我有一个缓慢的 R 实现,它是我代码中的主要瓶颈,所以我正在编写一个自定义 Rcpp 函数以在 R/Bioconductor 包中使用。 Bioconductor 的网站建议不要使用 C++11,所以这对我来说是个问题。

我的问题在于尝试根据值的顺序将两个向量排序在一起。在 R 中,我们可以只使用 order() 函数。尽管遵循了有关此问题的建议,但我似乎无法使其正常工作:C++ sorting and keeping track of indexes

以下几行是问题所在:

   // sort vector based on order of values
 IntegerVector idx_ord = std::sort(idx.begin(), idx.end(),
    bool (int i1, int i2) {return values[i1] < values[i2];});

这里是完整的功能,任何人都感兴趣。任何进一步的提示将不胜感激:

    #include <Rcpp.h>
using namespace Rcpp;

// [[Rcpp::export]]
double median_freq(NumericVector values, IntegerVector freqs) {
    int len = freqs.size();
    if (any(freqs!=0)){
        int med = 0;
        return med;
    }
    // filter out the zeros pre-sorting
    IntegerVector non_zeros;
    for (int i = 0; i < len; i++){
        if(freqs[i] != 0){
            non_zeros.push_back(i);
        }
    }
    freqs = freqs[non_zeros];
    values = values[non_zeros];
    // find the order of values
    // create integer vector of indices
    IntegerVector idx(len);
    for (int i = 0; i < len; ++i) idx[i] = i;

    // sort vector based on order of values
 IntegerVector idx_ord = std::sort(idx.begin(), idx.end(),
    bool (int i1, int i2) {return values[i1] < values[i2];});

    //apply to freqs and values
    freqs = freqs[idx_ord];
    values=values[idx_ord];
    IntegerVector cum_freqs(len);
    cum_freqs[0] = freqs[0];
    for (int i = 1; i < len; ++i) cum_freqs[i] = freqs[i] + cum_freqs[i-1];
    int total_freqs = cum_freqs[len-1];
    // split into odd and even frequencies and calculate the median
    if (total_freqs % 2 == 1) {
        int med_ind = (total_freqs + 1)/2 - 1; // C++ indexes from 0
        int i = 0;
        while ((i < len) && cum_freqs[i] < med_ind){
            i++;
        }
        double ret = values[i];
        return ret;
    } else {
        int med_ind_1 = total_freqs/2 - 1; // C++ indexes from 0
        int med_ind_2 = med_ind_1 + 1; // C++ indexes from 0
        int i = 0;
        while ((i < len) && cum_freqs[i] < med_ind_1){
            i++;
        }
        double ret_1 = values[i];
        i = 0;
        while ((i < len) && cum_freqs[i] < med_ind_2){
            i++;
        }
        double ret_2 = values[i];
        double ret = (ret_1 + ret_2)/2;
        return ret;
    }
}

对于任何使用 RUnit 测试框架的人,这里有一些基本的单元测试:

test_median_freq <- function(){
    checkEquals(median_freq(1:10,1:10),7)
    checkEquals(median_freq(1:10,rep(1,10)),5.5)
    checkEquals(median_freq(2:6,c(1,2,1,45,2)),5)
}

谢谢!

【问题讨论】:

    标签: c++ r sorting rcpp


    【解决方案1】:

    我实际上会将值和频率组合成std::pair&lt;double, int&gt;,然后使用std::sort 对它们进行排序;这样,您始终可以将值及其频率保持在一起。这使您可以编写更简洁的代码,因为没有额外的一组索引浮动:

    #include <Rcpp.h>
    using namespace Rcpp;
    
    // [[Rcpp::export]]
    double median_freq(NumericVector values, IntegerVector freqs) {
      const int len = freqs.size();
      std::vector<std::pair<double, int> > allDat;
      int freqSum = 0;
      for (int i=0; i < len; ++i) {
        allDat.push_back(std::pair<double, int>(values[i], freqs[i]));
        freqSum += freqs[i];
      }
      std::sort(allDat.begin(), allDat.end());
      int accum = 0;
      for (int i=0; i < len; ++i) {
        accum += allDat[i].second;
        if (freqSum % 2 == 0) {
          if (accum > freqSum / 2) {
            return allDat[i].first;
          } else if (accum == freqSum / 2) {
            return (allDat[i].first + allDat[i+1].first) / 2;
          }
        } else {
          if (accum >= (freqSum+1)/2) {
            return allDat[i].first;
          }
        }
      }
      return NA_REAL;  // Should not be reached
    }
    

    在 R 中尝试一下:

    median_freq(1:10, 1:10)
    # [1] 7
    median_freq(1:10,rep(1,10))
    # [1] 5.5
    median_freq(2:6,c(1,2,1,45,2))
    # [1] 5
    

    我们还可以编写一个简单的 R 实现来确定我们从使用 Rcpp 获得的效率提升:

    med.freq.r <- function(values, freqs) {
      ord <- order(values)
      values <- values[ord]
      freqs <- freqs[ord]
      s <- sum(freqs)
      cs <- cumsum(freqs)
      idx <- min(which(cs >= s/2))
      if (s %% 2 == 0 && cs[idx] == s/2) {
        (values[idx] + values[idx+1]) / 2
      } else {
        values[idx]
      }
    }
    med.freq.r(1:10, 1:10)
    # [1] 7
    med.freq.r(1:10,rep(1,10))
    # [1] 5.5
    med.freq.r(2:6,c(1,2,1,45,2))
    # [1] 5
    

    为了进行基准测试,让我们看一组非常大的值:

    set.seed(144)
    values <- rnorm(1000000)
    freqs <- sample(1:100, 1000000, replace=TRUE)
    all.equal(median_freq(values, freqs), med.freq.r(values, freqs))
    # [1] TRUE
    library(microbenchmark)
    microbenchmark(median_freq(values, freqs), med.freq.r(values, freqs))
    # Unit: milliseconds
    #                        expr      min       lq     mean   median       uq      max neval
    #  median_freq(values, freqs) 128.5322 131.6095 146.8360 145.6389 159.6117 165.0306    10
    #   med.freq.r(values, freqs) 715.2187 744.5709 776.0539 765.9178 817.7157 855.1898    10
    

    对于 100 万个条目,Rcpp 解决方案比 R 解决方案快约 5 倍;考虑到编译开销,只有在处理非常大的向量或者这是一个经常重复的选项时,这种性能才有吸引力。

    线性时间方法

    一般来说,我们知道如何在不排序的情况下计算中位数(有关详细信息,请查看http://www.cc.gatech.edu/~mihail/medianCMU.pdf)。虽然该算法比排序和迭代更精细一些,但它可以产生显着的加速:

    double fast_median_freq(NumericVector values, IntegerVector freqs) {
      const int len = freqs.size();
      std::vector<std::pair<double, int> > allDat;
      int freqSum = 0;
      for (int i=0; i < len; ++i) {
        allDat.push_back(std::pair<double, int>(values[i], freqs[i]));
        freqSum += freqs[i];
      }
    
      int target = freqSum / 2;
      int low = 0;
      int high = len-1;
      while (true) {
        // Random pivot; move to the end
        int rnd = low + (rand() % (high-low+1));
        std::swap(allDat[rnd], allDat[high]);
    
        // In-place pivot
        int highPos = low;  // Start of values higher than pivot
        int lowSum = 0;  // Sum of frequencies of elements below pivot
        for (int pos=low; pos < high; ++pos) {
          if (allDat[pos].first <= allDat[high].first) {
            lowSum += allDat[pos].second;
            std::swap(allDat[highPos], allDat[pos]);
            ++highPos;
          }
        }
        std::swap(allDat[highPos], allDat[high]);  // Move pivot to "highPos"
    
        // If we found the element then return; o/w recurse on proper side
        if (lowSum >= target) {
          // Recurse on lower elements
          high = highPos - 1;
        } else if (lowSum + allDat[highPos].second >= target) {
          // Return
          if (target < lowSum + allDat[highPos].second || freqSum % 2 == 1) {
            return allDat[highPos].first;
          } else {
            double nextHighest = std::min_element(allDat.begin() + highPos+1, allDat.begin() + len-1)->first;
            return (allDat[highPos].first + nextHighest) / 2;
          }
        } else {
          // Recurse on higher elements
          low = highPos + 1;
          target -= (lowSum + allDat[highPos].second);
        }
      }
    }
    

    基准测试:

    all.equal(median_freq(values, freqs), fast_median_freq(values, freqs))
    [1] TRUE
    microbenchmark(median_freq(values, freqs), med.freq.r(values, freqs), fast_median_freq(values, freqs), times=10)
    # Unit: milliseconds
    #                             expr       min        lq      mean    median        uq       max neval
    #       median_freq(values, freqs) 119.57989 122.48622 130.47841 130.48811 132.75421 146.36136    10
    #        med.freq.r(values, freqs) 665.72803 690.15016 708.05729 702.65885 731.83936 749.36834    10
    #  fast_median_freq(values, freqs)  24.37572  29.39641  31.86144  31.77459  34.88418  36.81606    10
    

    线性方法的速度比 sort-then-iterate Rcpp 解决方案快 4 倍,比基本 R 解决方案快 20 倍。

    【讨论】:

    • 这很漂亮。我认为 sort 将按第一个参数排序并忽略频率。我现在无法对其进行测试,但会在可以时将其标记为已回答。谢谢。
    • @Tom 是的,它按配对的第一个元素排序,如果出现平局,它将按配对的第二个元素排序(请参阅en.cppreference.com/w/cpp/utility/pair/operator_cmp)。
    • 这篇论文太酷了!再次感谢您的帮助,这真的很棒。
    • @Tom 如果此答案解决了您的问题,请考虑接受和/或投票。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2013-05-29
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多