【问题标题】:Python type checking numpy arrays include their dtypePython类型检查numpy数组包括它们的dtype
【发布时间】:2021-05-02 15:02:00
【问题描述】:

我可以使用以下方法验证我的函数接收到正确类型的输入:

def foo(x: np.ndarray, y: float):
    return x * y

确保如果我尝试将此函数与不是 np.ndarrayx 一起使用,我什至会在运行代码之前收到错误。

我不知道的是如何验证数组类型。例如:

 def return_valid_points_only(points: np.ndarray, valid: np.ndarray):
    assert points.shape == valid.shape
    return points[valid]

我想检查valid 不仅是np.ndarray,而且还是valid.dtype == bool

对于这个例子,如果 valid 将提供 0 和 1 来表示有效性,程序不会失败,我会得到可怕的结果。

谢谢

【问题讨论】:

  • 尝试 valid.dtype == bool?
  • 你用什么来检查类型? mypy? IDE?
  • edit - 抱歉,IDE 是自己做的
  • 这些类型检查是为了让其他程序员更容易理解函数+如果它错误地使用它(从错误类型发送函数参数)Pychrm 让它当场知道
  • 我想我们依赖 PyCharm 的功能来满足您的要求,而不是 Python。

标签: python numpy static-typing


【解决方案1】:

Python 的全部意义在于请求宽恕,而不是许可。这意味着即使在您的第一个定义中,def foo(x: np.ndarray, y: float): 也确实依赖用户来遵守提示,除非您使用的是 mypy 之类的东西。

您可以在此处采用多种方法,通常是同时使用。一种是以与传入的输入一起工作的方式编写函数,这可能意味着失败或强制无效输入。另一种方法是仔细记录您的代码,以便用户做出明智的决定。第二种方法尤为重要,但我将重点介绍第一种。

Numpy 会为您完成大部分检查。例如,与其期望一个数组,不如强制一个数组:

x = np.asanyarray(x)

np.asanyarray 通常是array(a, dtype, copy=False, order=order, subok=True) 的别名。你可以为y做类似的事情:

y = np.asanyarray(y).item()

这将允许任何类似数组,只要它有一个元素,无论是否是标量。另一种方法是尊重 numpy 将数组一起广播的能力,因此如果用户将y 作为x.shape[-1] 元素的列表传入。

对于您的第二个功能,您有几个选择。一种选择是允许花哨的索引。因此,如果用户传入索引列表与布尔掩码,您可以同时使用两者。另一方面,如果您坚持使用布尔掩码,则可以检查或强制 dtype。

如果您检查,请记住,如果数组大小不匹配,numpy 索引操作将为您引发错误。您只需要检查类型本身:

points = np.asanyarray(points)
valid = np.asanyarray(valid)
if valid.dtype != bool:
    raise ValueError('valid argument must be a boolean mask')

如果您选择强制,则允许用户使用零和一,但不会不必要地复制有效输入:

valid = np.asanyarray(valid, bool)

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2018-09-19
    • 1970-01-01
    • 2019-06-27
    • 2019-11-18
    • 2020-07-02
    • 2012-06-03
    • 1970-01-01
    • 2017-11-14
    相关资源
    最近更新 更多