【问题标题】:What is the meaning of the following operation in numpy?numpy中下面的操作是什么意思?
【发布时间】:2015-06-09 16:04:09
【问题描述】:

我在挖一段 numpy 代码,有一行我完全看不懂:

W[:, :, None] * h[None, :, :] * diff[:, None, :]

其中 Whdiff 是 784x20、20x100 和 784x100 矩阵。乘法结果是 784x20x100 数组,但我不知道这个计算实际上是做什么的,结果是什么意思。

值得一提的是,该行来自机器学习相关代码,W对应神经网络层的权重数组,h是层激活,diff 是网络的目标和假设之间的差异(来自Sida Wang's thesis on transforming autoencoder)。

【问题讨论】:

    标签: python numpy matrix machine-learning matrix-multiplication


    【解决方案1】:

    对于 NumPy 数组,* 对应于逐元素乘法。为了使其工作,这两个数组必须是:

    • 形状相同
    • 这样一个数组可以broadcast 到另一个数组

    如果在配对每个数组的尾随维度时,每对中的长度相等或其中一个长度为 1,则可以将一个数组广播到另一个数组。

    例如,以下数组AB 具有兼容广播的形状:

    A.shape == (20, 1, 3)
    B.shape ==     (4, 3)
    

    (3 等于3 然后A 中的下一个长度是1 可以与任何长度配对。B 的维度少于A 并不重要.)

    为了使两个不兼容的数组可以相互广播,可以在一个或两个数组中插入额外的维度。使用 Nonenp.newaxis 索引维度会在数组中插入长度为 1 的额外维度。


    让我们看一下问题中的示例。 Python 从左到右计算重复乘法:

    • W[:, :, None] 有形状 (784, 20, 1)
    • h[None, :, :] 有形状 ( 1, 20, 100)

    这些形状根据上面的解释是可广播的,乘法返回一个形状为(784, 20, 100)的数组。

    • 最后一次乘法的数组形状,(784, 20, 100)
    • diff[:, None, :] 的形状为 (784, 1, 100)

    这两个数组的这些形状是兼容的,所以第二次乘法成功。返回一个形状为(784, 20, 100) 的数组。

    【讨论】:

    • 谢谢,这很有帮助。
    猜你喜欢
    • 2018-07-31
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2015-11-22
    • 2011-02-09
    • 1970-01-01
    • 2021-10-08
    相关资源
    最近更新 更多