【问题标题】:How to store a boolean mask as an attribute of a Cython class?如何将布尔掩码存储为 Cython 类的属性?
【发布时间】:2020-02-03 17:54:29
【问题描述】:

我未能将布尔掩码保存为 Cython 类的属性。在实际代码中,我需要这个掩码来更有效地执行任务。下面是一个示例代码:

core.pyx

import numpy as np
cimport numpy as np

cdef class MyClass():
    cdef public np.uint8_t[:] mask # uint8 has the same data structure of a boolean array
    cdef public np.float64_t[:] data

    def __init__(self, size):
        self.data = np.random.rand(size).astype(np.float64)
        self.mask = np.zeros(size, np.uint8)

script.py

import numpy as np
import pyximport
pyximport.install(setup_args={'include_dirs': np.get_include()})

from core import MyClass

mc = MyClass(1000000)
mc.mask = np.asarray(mc.data) > 0.5 

错误

当我运行script.py 时,它成功编译了 Cython,但抛出了错误:

Traceback (most recent call last):
  File "script.py", line 8, in <module>
    mc.mask = np.asarray(mc.data) > 0.5
  File "core.pyx", line 6, in core.MyClass.mask.__set__
    cdef public np.uint8_t[:] mask
ValueError: Does not understand character buffer dtype format string ('?')

解决方法

我目前的解决方法是将掩码传递给我需要的所有函数,例如使用cast=True

cpdef func(MyClass mc, np.ndarray[np.uint8_t, ndim=1, cast=True] mask):
    return np.asarray(mc.data)[mask]

问题

对于如何将掩码存储在 Cython 类中是否有任何想法?

【问题讨论】:

    标签: python arrays numpy boolean cython


    【解决方案1】:

    所以我不相信内存视图实际上支持布尔索引。因此,要索引数组,您总是需要这样做

    np.asarray(mc.data)[mask]
    # or
    mc.data.base[mask] # if you're sure it's always a view of something that supports boolean indexing)
    

    我认为这不会随着@ead 提到的 Cython 更新而改变。我怀疑这样做的原因是分配可能相当容易(mc.data[mask] = x),但mc.data[mask] 应该返回什么类型并不明显——它不是内存视图。

    因此,无论你做什么,都会涉及一些乱七八糟的代码。


    对于memoryview的部分Assignment可以用

    mc.mask = (np.asarray(mc.data) > 0.5).view(np.uint8)
    

    并将其返回到一个 Numpy 布尔数组:

    np.asarray(mc.mask).view(np.bool)
    

    两者都不应该涉及复制。


    如果是我设计这个,我会保持 memoryviews 非公开(仅供 Cython 使用)并具有仅保存 Python 接口的底层 Numpy 数组的普通对象属性。您可以使用property 使它们保持同步(并进行强制转换):

    cdef class MyClass:
        cdef np.uint8_t[:] mask_mview
        cdef object _mask
    
        @property
        def mask(self):
            return np.asarray(self._mask).view(np.bool)
    
        @mask.setter
        def mask(self, value):
            self._mask = value
            self.mask_view = value.view(np.uint8)
    
        # and the same for data
    

    这样你就有了一个 memoryview 用于 memoryviews 擅长的事情(在 Cython 中逐个元素快速迭代),访问 Python 的普通 Numpy 数组,并且两者保持同步(至少通过Python 接口)。

    【讨论】:

      【解决方案2】:

      您最好的选择(如果您不想使用解决方法)可能是等待 Cython 0.29.14 发布。这个问题was fixed 很可能是0.29.14 的一部分。

      以下最小示例

      %%cython
      import numpy as np
      cimport numpy as np
      cdef np.uint8_t[:] mask  = np.random.rand(20)>.5
      

      将无法正常导入

      ValueError: 不理解字符缓冲区 dtype 格式字符串 ('?')

      对于 Cython 0.29.13,但使用来自 0.29.x-branch on github(或 master)的当前状态。

      【讨论】:

      • 很高兴听到这个消息,随身携带这些面具的解决方法很烦人,谢谢
      • @SaulloG.P.Castro 如果有帮助,mc.mask = (np.asarray(mc.data) &gt; 0.5).view(np.uint8_t) 应该正确分配(作为不同的解决方法)
      • @DavidW 看起来很有希望,您能否详细说明一下新的答案?您可以使用这样存储的掩码进行精美的索引吗?
      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2020-11-19
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多