【问题标题】:What is an ndarray of shape () returned after np.squeeze from shape (1,1)?np.squeeze 从形状 (1,1) 返回后的形状 () 的 ndarray 是什么?
【发布时间】:2017-08-24 13:20:41
【问题描述】:

我正在计算 NN 的成本函数。我对从 numpy.dot 获得的 (1,1) 答案执行 numpy.squeeze。然后我得到一个形状为 (0,1) 的 ndarray。

什么是形状 () 的 ndarray,形状 (1,) 的 ndarray 与形状 (5) 中的一个有何不同?

【问题讨论】:

  • 试试cost.item()cost[()]
  • 提示:您可以去掉很多括号以提高可读性:cost2 = -1.0 / m * (np.dot(Y, np.log(A).T) + np.dot(1.0 - Y, np.log(1.0 - A).T))

标签: python-2.7


【解决方案1】:
  • 形状为 (1, 1) 的 ndarray 类似于 [[3]],类似于 1x1 矩阵。
  • 形状为(1,) 的ndarray 类似于[3],类似于大小为1 的向量。
  • 形状为() 的ndarray,也称为标量,类似于3

区别很微妙,因为由于广播规则,标量和数组通常可以毫无问题地组合,但是您不能索引标量,而您可以索引大小为 1 的向量或大小为 1x1 的矩阵。另一方面,标量通常可以像原始 Python 值一样使用,例如 intfloat。如果您不想使用标量,您可以将axis 参数传递给np.squeeze 以确保某些维度不会被压缩,或者使用np.atleast_1d 以确保您传递的任何内容至少具有一个维度。您还可以使用 np.isscalar 检查某物是否为标量。

【讨论】:

    【解决方案2】:

    dot 与 (1,n) 和 (n,1) 生成 (1,1) 数组:

    In [1221]: x = np.ones((3,1))
    In [1222]: xx = np.dot(x.T,x)
    In [1223]: xx.shape
    Out[1223]: (1, 1)
    In [1224]: xx
    Out[1224]: array([[ 3.]])
    

    item 可用于从数组中提取该值:

    Out[1227]: 3.0
    In [1228]: type(_)
    Out[1228]: float
    

    您也可以通过索引选择项目,尽管type 会有所不同:

    In [1229]: xx[0,0]
    Out[1229]: 3.0
    In [1230]: type(_)
    Out[1230]: numpy.float64
    

    对于许多用途,floatnp.float64 之间的区别并不重要。

    squeeze 删除所有尺寸为 1 的尺寸。在这种情况下,结果是一个 0d 数组。 item 仍然有效。使用正确大小的索引(即长度为 0 的元组)也可以使用索引:

    In [1231]: xx0 = np.squeeze(xx)
    In [1232]: xx0.shape
    Out[1232]: ()
    In [1233]: xx0.item()
    Out[1233]: 3.0
    In [1234]: xx0[()]
    Out[1234]: 3.0
    In [1235]: type(_)
    Out[1235]: numpy.float64
    

    np.float64 的类继承是:

    In [1236]: _.__mro__
    Out[1236]: 
    (numpy.float64,
     numpy.floating,
     numpy.inexact,
     numpy.number,
     numpy.generic,
     float,
     object)
    

    所以isinstance float 仍然会返回 true

    In [1237]: isinstance(xx0.item(),float)
    Out[1237]: True
    In [1238]: isinstance(xx0[()],float)
    Out[1238]: True
    In [1239]: isinstance(xx[0,0],float)
    Out[1239]: True
    

    我不会依赖于所有 numpy dtypes。

    【讨论】:

      猜你喜欢
      • 2018-05-13
      • 2022-11-17
      • 2017-05-14
      • 1970-01-01
      • 1970-01-01
      • 2020-10-01
      • 1970-01-01
      • 2020-10-11
      • 1970-01-01
      相关资源
      最近更新 更多