【发布时间】:2019-06-19 22:23:38
【问题描述】:
我可以在以下情况下使用tf.matmul(A, B) 进行批量矩阵乘法:
-
A.shape == (..., a, b)和 -
B.shape == (..., b, c),
... 相同。
但我想要一个额外的广播:
-
A.shape == (a, b, 2, d)和 B.shape == (a, 1, d, c)result.shape == (a, b, 2, c)
我希望结果是a x b 在(2, d) 和(d, c) 之间的矩阵乘法批次。
如何做到这一点?
测试代码:
import tensorflow as tf
import numpy as np
a = 3
b = 4
c = 5
d = 6
x_shape = (a, b, 2, d)
y_shape = (a, d, c)
z_shape = (a, b, 2, c)
x = np.random.uniform(0, 1, x_shape)
y = np.random.uniform(0, 1, y_shape)
z = np.empty(z_shape)
with tf.Session() as sess:
for i in range(b):
x_now = x[:, i, :, :]
z[:, i, :, :] = sess.run(
tf.matmul(x_now, y)
)
print(z)
【问题讨论】:
-
B和y有不同的形状?我不知道tf,但numpyA@B有效。 -
是的。对于
numpy、x @ y[:, np.newaxis, :, :]有效。这也适用于张量流。我不知道@在 GPU 上的 tensorflow 中的效率如何。
标签: python numpy tensorflow matrix-multiplication