【问题标题】:Transform arma::cube subview into NumericVector to use sugar将 arma::cube 子视图转换为 NumericVector 以使用糖
【发布时间】:2018-12-21 11:14:20
【问题描述】:

我将一个 3D 数组从 R 传递到 C++ 并遇到了类型转换问题。我们如何将 RcppArmadillo 中的 arma::cube subviews 转换为 NumericVectors 以使用 Rcpp 中的糖函数(如 which_min)对其进行操作?

假设您有一个带有一些数字条目的 3D 立方体 Q。我的目标是获取每行i 和每个第三维k 的列条目最小值的索引。在 R 语法中,这是which.min(Q[i,,k])

例如i = 1k = 1

cube Q = randu<cube>(3,3,3);
which_min(Q.slice(1).row(1)); // this fails

我认为转换为 NumericVector 可以解决问题,但此转换失败

which_min(as<NumericVector>(Q.slice(1).row(1))); // conversion failed

我怎样才能让它工作?感谢您的帮助。

【问题讨论】:

    标签: c++ rcpp armadillo


    【解决方案1】:

    你有几个选择:

    1. 您可以为此使用 Armadillo 函数,成员函数 .index_min()(请参阅 Armadillo 文档 here)。
    2. 您可以使用Rcpp::wrap(),其中"transforms an arbitrary object into a SEXP"arma::cube subviews 转换为Rcpp::NumericVector,并使用糖函数Rcpp::which_min()

    最初我只是将第一个选项作为答案,因为它似乎是实现目标的更直接的方法,但我添加了第二个选项(在答案的更新中),因为我现在认为任意转换可能是你好奇的部分。

    我将以下 C++ 代码放在一个文件so-answer.cpp

    // [[Rcpp::depends(RcppArmadillo)]]
    #include <RcppArmadillo.h>
    
    // [[Rcpp::export]]
    Rcpp::List index_min_test() {
        arma::cube Q = arma::randu<arma::cube>(3, 3, 3);
        int whichmin = Q.slice(1).row(1).index_min();
        Rcpp::List result = Rcpp::List::create(Rcpp::Named("Q") = Q,
                                               Rcpp::Named("whichmin") = whichmin);
        return result;
    }
    
    // [[Rcpp::export]]
    Rcpp::List which_min_test() {
        arma::cube Q = arma::randu<arma::cube>(3, 3, 3);
        Rcpp::NumericVector x = Rcpp::wrap(Q.slice(1).row(1));
        int whichmin = Rcpp::which_min(x);
        Rcpp::List result = Rcpp::List::create(Rcpp::Named("Q") = Q,
                                               Rcpp::Named("whichmin") = whichmin);
        return result;
    }
    

    我们有一个使用 Armadillo 的 .index_min() 的函数和一个使用 Rcpp::wrap() 来启用 Rcpp::which_min() 的函数。

    然后我使用Rcpp::sourceCpp() 编译它,使函数对 R 可用,并演示使用几个不同的种子调用它们:

    Rcpp::sourceCpp("so-answer.cpp")
    set.seed(1)
    arma <- index_min_test()
    set.seed(1)
    wrap <- which_min_test()
    arma$Q[2, , 2]
    #> [1] 0.2059746 0.3841037 0.7176185
    wrap$Q[2, , 2]
    #> [1] 0.2059746 0.3841037 0.7176185
    arma$whichmin
    #> [1] 0
    wrap$whichmin
    #> [1] 0
    set.seed(2)
    arma <- index_min_test()
    set.seed(2)
    wrap <- which_min_test()
    arma$Q[2, , 2]
    #> [1] 0.5526741 0.1808201 0.9763985
    wrap$Q[2, , 2]
    #> [1] 0.5526741 0.1808201 0.9763985
    arma$whichmin
    #> [1] 1
    wrap$whichmin
    #> [1] 1
    library(microbenchmark)
    microbenchmark(arma = index_min_test(), wrap = which_min_test())
    #> Unit: microseconds
    #>  expr    min      lq     mean  median      uq    max neval cld
    #>  arma 12.981 13.7105 15.09386 14.1970 14.9920 62.907   100   a
    #>  wrap 13.636 14.3490 15.66753 14.7405 15.5415 64.189   100   a
    

    reprex package (v0.2.1) 于 2018 年 12 月 21 日创建

    【讨论】:

    • 谢谢你们!在一般情况下,转换代码是我想知道的。我不知道 .index_min() - 甜,就像一个魅力。
    • @JBJ 没问题,很高兴它有帮助!干杯
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2021-12-18
    • 1970-01-01
    • 1970-01-01
    • 2023-03-04
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多