【发布时间】:2017-09-15 13:00:14
【问题描述】:
我有一个列表 list_of_arrays 的 3D numpy 数组,我想通过模板传递给 C 函数
int my_func_c(double **data, int **shape, int n_arrays)
这样
data[i] : pointer to the numpy array values in list_of_arrays[i]
shape[i] : pointer to the shape of the array in list_of_arrays[i] e.g. [2,3,4]
如何使用 cython 接口函数调用my_func_c?
我的第一个想法是做类似下面的事情(可行),但我觉得有一种更好的方法,只使用 numpy 数组而不进行分配和释放。
# my_func_c.pyx
import numpy as np
cimport numpy as np
cimport cython
from libc.stdlib cimport malloc, free
cdef extern from "my_func.c":
double my_func_c(double **data, int **shape, int n_arrays)
def my_func(list list_of_arrays):
cdef int n_arrays = len(list_of_arrays)
cdef double **data = <double **> malloc(n_arrays*sizeof(double *))
cdef int **shape = <int **> malloc(n_arrays*sizeof(int *))
cdef double x;
cdef np.ndarray[double, ndim=3, mode="c"] temp
for i in range(n_arrays):
temp = list_of_arrays[i]
data[i] = &temp[0,0,0]
shape[i] = <int *> malloc(3*sizeof(int))
for j in range(3):
shape[i][j] = list_of_arrays[i].shape[j]
x = my_func_c(data, shape, n_arrays)
# Free memory
for i in range(n_arrays):
free(shape[i])
free(data)
free(shape)
return x
注意
要查看一个工作示例,我们可以使用一个非常简单的函数来计算列表中所有数组的乘积。
# my_func.c
double my_func_c(double **data, int **shape, int n_arrays) {
int array_idx, i0, i1, i2;
double prod = 1.0;
// Loop over all arrays
for (array_idx=0; array_idx<n_arrays; array_idx++) {
for (i0=0; i0<shape[array_idx][0]; i0++) {
for (i1=0; i1<shape[array_idx][1]; i1++) {
for (i2=0; i2<shape[array_idx][2]; i2++) {
prod = prod*data[array_idx][i0*shape[array_idx][1]*shape[array_idx][2] + i1*shape[array_idx][2] + i2];
}
}
}
}
return prod;
}
创建setup.py 文件,
# setup.py
from distutils.core import setup
from Cython.Build import cythonize
import numpy as np
setup(
name='my_func',
ext_modules = cythonize("my_func_c.pyx"),
include_dirs=[np.get_include()]
)
编译
python3 setup.py build_ext --inplace
最后我们可以运行一个简单的测试
# test.py
import numpy as np
from my_func_c import my_func
a = [1+np.random.rand(3,1,2), 1+np.random.rand(4,5,2), 1+np.random.rand(1,2,3)]
print('Numpy product: {}'.format(np.prod([i.prod() for i in a])))
print('my_func product: {}'.format(my_func(a)))
使用
python3 test.py
【问题讨论】:
-
它怎么不起作用?
-
好吧,我会写一个工作示例。
-
什么是输出
-
问题似乎不够清楚,但我不确定如何编辑它。这种情况下的输出本质上就是我们想要看到的。 cython 函数给出与 numpy 相同的结果(有一些小的舍入误差)。它的随机性每次都会改变。我想要的是一种在 cython 中处理 numpy 数组列表的方法,而无需求助于 mallocing 和释放内存。
-
@rwolst 您可以将指针传递给 numpy 的内置形状数组,而不是逐个元素地复制它们。除此之外,我认为您的方法非常好。