【问题标题】:How to pass the address of a template kernel function to a CUDA function?如何将模板内核函数的地址传递给 CUDA 函数?
【发布时间】:2013-07-10 16:56:07
【问题描述】:

我想将接受 CUDA 内核函数指针的 CUDA 运行时 API 函数与内核模板一起使用。

我可以在没有模板的情况下执行以下操作:

__global__ myKernel()
{
  ...
}

void myFunc(const char* kernel_ptr)
{
  ...
  // use API functions like
  cudaFuncGetAttributes(&attrib, kernel_ptr);
  ...
}

int main()
{
  myFunc(myKernel);
}

但是,当内核是模板时,上述方法不起作用。

另一个例子:

#include "cuda_runtime.h"
#include "device_launch_parameters.h"
#include <stdio.h>

template<typename T>
__global__ void addKernel(T *c, const T *a, const T *b)
{
    int i = threadIdx.x;
    c[i] = a[i] + b[i];
}

int main()
{
    cudaFuncAttributes attrib;
    cudaError_t err;

    //OK:
    err = cudaFuncGetAttributes(&attrib, addKernel<float>); // works fine
    printf("result: %s, reg1: %d\n", cudaGetErrorString(err), attrib.numRegs);

    //NOT OK:
    //try to get function ptr to pass as an argument:
    const char* ptr = addKernel<float>; // compile error
    err = cudaFuncGetAttributes(&attrib, ptr);
    printf("result: %s, reg2: %d\n", cudaGetErrorString(err), attrib.numRegs);
}

以上导致编译错误:

错误:没有函数模板“addKernel”的实例与 所需类型

编辑: 到目前为止,我发现的唯一解决方法是将 myFunc 中的内容(参见第一个代码示例)放入一个宏中,这很丑陋,但它不需要传递指针参数并且工作正常:

#define MY_FUNC(kernel) \
  { \
     ...\
     cudaFuncGetAttributes( &attrib, kernel ); \
     ...\
  }

用法:

MY_FUNC( myKernel<float> )

【问题讨论】:

  • 编辑了带有错误处理的代码示例,以防有人想尝试使用它。

标签: templates cuda


【解决方案1】:

addKernel&lt;void&gt;的类型不是char *,是函数类型。

相反,获取addKernel&lt;float&gt; 的地址,如下所示:

typedef void (*fun_ptr)(float*,const float *, const float*);
fun_ptr ptr = addKernel<float>; // compile error
err = cudaFuncGetAttributes(&attrib, ptr);

【讨论】:

  • 我在 Robert 的回答中添加了一条评论(因为它是列表中的第一个),这也与您的解决方案有关。
【解决方案2】:

参考“另一个示例”中包含的代码:

改变这个:

const char* ptr = addKernel<float>; // compile error

到这里:

void (*ptr)(float *, const float *, const float *) = addKernel<float>;

而且我相信它会正确编译和运行。

我不知道它在您尝试做的整体范围内是否有用。

EDIT 回答 cmets 中的问题:

一旦我从函数中“提取”了指针,我就可以将它转换为另一种类型。试试吧。例如,下面的代码也可以:

void (*ptr)(float *, const float *, const float *) = addKernel<float>;
const char *ptr1 = (char *)ptr;
err = cudaFuncGetAttributes(&attrib, ptr1);

所以,为了回答你的问题,可以将你的函数指针转换为const char*,如果你想的话,一旦你有了你的函数指针。

顺便说一句,您作为答案发布的代码在 gcc 4.1.2 和 gcc 4.4.6 上为我引发了编译错误:

$ nvcc -arch=sm_20 -O3 -o t201 t201.cu
t201.cu: In function âint main()â:
t201.cu:25: error: address of overloaded function with no contextual type information
t201.cu:29: error: address of overloaded function with no contextual type information
$

如果我删除这两行中的&amp;,也会出现错误:

$ nvcc -arch=sm_20 -O3 -o t201 t201.cu
t201.cu: In function âint main()â:
t201.cu:25: error: insufficient contextual information to determine type
t201.cu:29: error: insufficient contextual information to determine type
$

因此,就从 A 点到 B 点需要哪些步骤而言,其中一些可能取决于编译器。

【讨论】:

  • 与 Jared 的回答相同。我猜他是在我还在编辑/思考我的答案时发布的。
  • 这确实有效,但是,我仍然不明白。 cudaFuncGetAttributes 的第二个参数是 cuda_runtima_api.h 中的 const char*。 cuda_runtime.h 中还有一个模板化的版本,但在函数定义中,它只是转换为 (const char*)。那么为什么我不能将我的函数指针转换为 const char* 并将其传递给我的基金,而 cuda 似乎能够做到呢?
  • 好的,根据你们两个的解决方案,我能够找到一个通用的解决方案:stackoverflow.com/questions/5277547/… 所以一个 void 函数指针可以做到:void (ptr)() = (void()())(&addKernel);
  • 我接受了 Jared 的回答,因为我从 cmets 中扣除了他是第一个给出正确解决方案的人。如果有人能提供答案,我仍然会很高兴 CUDA 如何使用 const char* 做到这一点。
  • 现在这很奇怪:对我来说,使用 Visual Studio 2010,无论有没有 '&' 都可以使用 /Za 选项来禁用 Microsoft 特定的语言扩展。
【解决方案3】:

编辑:添加了基于 cuda 运行时和 Robert Crovella 的答案的模板化版本。

这是一个使用 void 函数指针和模板的完整工作示例。

#include "cuda_runtime.h"
#include "device_launch_parameters.h"
#include <stdio.h>

template <typename T>
__global__ void addKernel(T *c, const T *a, const T *b)
{
    int i = threadIdx.x;
    c[i] = a[i] + b[i];
}

cudaError_t func1( cudaFuncAttributes* attrib, void (*ptr)() )
{
    return cudaFuncGetAttributes(attrib, ptr);
}

cudaError_t func2( cudaFuncAttributes* attrib, const char* ptr )
{
    return cudaFuncGetAttributes(attrib, ptr);
}

template <typename T>
cudaError_t func2( cudaFuncAttributes* attrib, T ptr )
{
    return func2( attrib, (const char*) ptr);
}

int main()
{
    cudaFuncAttributes attrib;
    cudaError_t err;

    void (*ptr2)() = (void(*)())(addKernel<float>);  // OK on Visual Studio
    err = func1(&attrib, ptr2);
    printf("result: %s, reg1: %d\n", cudaGetErrorString(err), attrib.numRegs);

    err = func2(&attrib, addKernel<double> ); // OK nice and standard
    printf("result: %s, reg2: %d\n", cudaGetErrorString(err), attrib.numRegs);
}

【讨论】:

    猜你喜欢
    • 2015-10-20
    • 2019-12-07
    • 1970-01-01
    • 2014-03-11
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2017-12-25
    相关资源
    最近更新 更多