这是使用 TensorFlow 的一种方法:
import tensorflow as tf
data = tf.placeholder(tf.float32, [None, None])
n = tf.placeholder(tf.int32, [])
eye = tf.eye(n)
mult = data[:, tf.newaxis, :, tf.newaxis] * eye[tf.newaxis, :, tf.newaxis, :]
result = tf.reshape(mult, n * tf.shape(data))
with tf.Session() as sess:
a = sess.run(result, feed_dict={data: [[1, 2], [3, 4]], n: 3})
print(a)
输出:
[[1. 0. 0. 2. 0. 0.]
[0. 1. 0. 0. 2. 0.]
[0. 0. 1. 0. 0. 2.]
[3. 0. 0. 4. 0. 0.]
[0. 3. 0. 0. 4. 0.]
[0. 0. 3. 0. 0. 4.]]
顺便说一句,你可以在 NumPy 中做基本相同的事情,这应该比你当前的解决方案更快:
import numpy as np
data = np.array([[1, 2], [3, 4]])
n = 3
eye = np.eye(n)
mult = data[:, np.newaxis, :, np.newaxis] * eye[np.newaxis, :, np.newaxis, :]
result = np.reshape(mult, (n * data.shape[0], n * data.shape[1]))
print(result)
# The output is the same as above
编辑:
我会尝试给出一些关于为什么/如何工作的直觉,如果它太长,对不起。这并不难,但我认为解释起来有点棘手。也许更容易看出下面的乘法是如何工作的
import numpy as np
data = np.array([[1, 2], [3, 4]])
n = 3
eye = np.eye(n)
mult1 = data[:, :, np.newaxis, np.newaxis] * eye[np.newaxis, np.newaxis, :, :]
现在,mult1 是一种“矩阵矩阵”。如果我给出两个索引,我将得到原始索引中对应元素的对角矩阵:
print(mult1[0, 0])
# [[1. 0. 0.]
# [0. 1. 0.]
# [0. 0. 1.]]
所以你可以说这个矩阵可以像这样可视化:
| 1 0 0 | | 2 0 0 |
| 0 1 0 | | 0 2 0 |
| 0 0 1 | | 0 0 2 |
| 3 0 0 | | 4 0 0 |
| 0 3 0 | | 0 4 0 |
| 0 0 3 | | 0 0 4 |
然而这是骗人的,因为如果你试图把它重塑成最终的形状,结果就不是正确的:
print(np.reshape(mult1, (n * data.shape[0], n * data.shape[1])))
# [[1. 0. 0. 0. 1. 0.]
# [0. 0. 1. 2. 0. 0.]
# [0. 2. 0. 0. 0. 2.]
# [3. 0. 0. 0. 3. 0.]
# [0. 0. 3. 4. 0. 0.]
# [0. 4. 0. 0. 0. 4.]]
原因是重塑(概念上)首先“展平”阵列,然后给出新的形状。但是这种情况下的扁平数组不是你需要的:
print(mult1.ravel())
# [1. 0. 0. 0. 1. 0. 0. 0. 1. 2. 0. 0. 0. 2. 0. ...
你看,它首先遍历第一个子矩阵,然后是第二个,等等。但你想要的是它首先遍历第一个子矩阵的第一行,然后是第二个子矩阵的第一行,然后是第二行第一个子矩阵等。所以基本上你想要这样的东西:
- 取前两个子矩阵(带有
1 和2 的子矩阵)
- 获取所有第一行(
[1, 0, 0] 和 [2, 0, 0])。
然后继续剩下的。因此,如果您考虑一下,我们首先遍历轴 0(“矩阵矩阵”的行),然后是 2(每个子矩阵的行),然后是 1(“矩阵矩阵”的列),最后是 3(子矩阵的列) )。所以我们可以重新排序轴来做到这一点:
mult2 = mult1.transpose((0, 2, 1, 3))
print(np.reshape(mult2, (n * data.shape[0], n * data.shape[1])))
# [[1. 0. 0. 2. 0. 0.]
# [0. 1. 0. 0. 2. 0.]
# [0. 0. 1. 0. 0. 2.]
# [3. 0. 0. 4. 0. 0.]
# [0. 3. 0. 0. 4. 0.]
# [0. 0. 3. 0. 0. 4.]]
而且它有效!所以在我发布的解决方案中,为了避免转置,我只是做乘法,所以轴的顺序就是这样:
mult = data[
:, # Matrix-of-matrices rows
np.newaxis, # Submatrix rows
:, # Matrix-of-matrices columns
np.newaxis # Submatrix columns
] * eye[
np.newaxis, # Matrix-of-matrices rows
:, # Submatrix rows
np.newaxis, # Matrix-of-matrices columns
: # Submatrix columns
]
我希望这能让它更清晰一些。老实说,特别是在这种情况下,我可以很快想出解决方案,因为不久前我必须解决一个类似的问题,我猜你最终会建立对这些事情的直觉。