好吧,如果您知道需要提取哪些子矩阵,tf.slice() 是最佳选择。
文档是here
对于您提供的示例,使用 tf.slice() 的解决方案是:
import tensorflow as tf
x = [[0, 0, 1, 1],
[0, 0, 1, 1],
[1, 1, 0, 0],
[1, 1, 0, 0]]
X = tf.Variable(x)
s1 = tf.slice(X, [2,0], [2,2])
s1 = tf.slice(X, [0,2], [2,2])
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
print(sess.run([s1, s1]))
此代码呈现以下结果:
[array([[1, 1], [1, 1]], dtype=int32),
array([[1, 1], [1, 1]], dtype=int32)]
编辑:
对于更自动且不那么冗长的方式,您可以使用 tensorflow 中的 getitem 属性并像对 npArray 进行切片一样对其进行切片。
代码可能是这样的:
import tensorflow as tf
var = [[0, 0, 1, 1],
[0, 0, 1, 1],
[1, 1, 0, 0],
[1, 1, 0, 0]]
X = tf.Variable(var)
slices = [[0,2], [2,0]]
s = []
for sli in slices:
y = sli[0]
x = sli[1]
s.append(X[y:y+2, x:x+2])
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
print(sess.run(s))
此代码呈现以下结果:
[array([[1, 1], [1, 1]], dtype=int32),
array([[1, 1], [1, 1]], dtype=int32)]