参考 https://stackoverflow.com/a/18860925/2062663 中的建议加以改编，您可以尝试如下使用 NumPy 的 `meshgrid`。思路是为四个轴分别生成稀疏（可广播的）索引网格，再直接比较 `i + j` 与 `k + l`，得到布尔掩码：
import numpy as np
def get_mask(shape):
    """Return a 4-D boolean mask whose element ``[i, j, k, l]`` is ``i + j == k + l``.

    Parameters
    ----------
    shape : iterable of int
        Extent of each of the four axes; must contain exactly four entries.

    Returns
    -------
    numpy.ndarray
        Boolean array of the given ``shape``.

    Raises
    ------
    ValueError
        If ``shape`` does not have exactly four entries.
    """
    # Materialise first so generators/lists are accepted just like before,
    # and so the length can be validated with a clear message.
    shape = tuple(shape)
    if len(shape) != 4:
        raise ValueError(f"get_mask expects a 4-D shape, got {shape!r}")
    # sparse=True returns open (broadcastable) index grids, so only the final
    # comparison allocates an array of the full requested shape.
    i, j, k, l = np.meshgrid(*map(np.arange, shape), sparse=True, indexing='ij')
    return i + j == k + l
以下是几个测试用例：
def test_1():
    """A 1x1x1x1 mask has its single element set, since 0 + 0 == 0 + 0."""
    expected = np.array([[[[1]]]])
    actual = get_mask((1, 1, 1, 1))
    # Compare whole arrays explicitly (shape and values). A bare
    # ``assert actual == expected`` only works because the element-wise result
    # happens to contain exactly one element, and it ignores shape mismatches
    # that broadcasting would hide.
    assert np.array_equal(actual, expected)
def test_2():
    """Check every element of a 2x2x2x2 mask against a hand-written table."""
    actual = get_mask((2, 2, 2, 2))
    # expected[i][j] is the 2x2 slice over (k, l); 1 marks i + j == k + l.
    expected = np.array(
        [
            [[[1, 0], [0, 0]], [[0, 1], [1, 0]]],
            [[[0, 1], [1, 0]], [[0, 0], [0, 1]]],
        ]
    )
    matches = actual == expected
    assert matches.all()
def test_4():
    """Spot-check selected entries of a 4x4x4x4 mask."""
    mask = get_mask((4, 4, 4, 4))
    # Each pair is (index, expected value); 1 iff i + j == k + l.
    spot_checks = (
        ((1, 2, 0, 3), 1),
        ((1, 1, 2, 0), 1),
        ((0, 2, 1, 3), 0),
        ((3, 2, 0, 3), 0),
    )
    for index, want in spot_checks:
        assert mask[index] == want