我会试着给你一个例子。但这将基于一些假设:
- 您知道,0 跨越一个连续的矩形块。
-
a 中没有其他零。
如果您想填充一个不连续的零块,或者在您有一些其他非零值的列/行上有零,您将不得不考虑更复杂的解决方案。
解决方案:将数组b随机插入数组a,其中a==0
假设:我们知道a 为零的位置是一组具有矩形形状的连续位置。
进口
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
%config InlineBackend.figure_format = 'svg' # 'svg', 'retina'
plt.style.use('seaborn-white')
制作数据
# Make a
shape = (5,5)
a = np.zeros(shape)
a[:,-1] = np.arange(shape[0]) + 10
a[-1,:] = np.arange(shape[1]) + 10
# Make b
b = np.ones((2,2))*2
预处理
在这里我们确定a 上b 左上角元素的可能槽位置。
# Get range of positions (rows and cols) where we have zeros
target_indices = np.argwhere(a==0)
minmax = np.array([target_indices.min(axis=0), target_indices.max(axis=0)])
# Define max position (index) of topleft element of b on a
maxpos = np.dot(np.array([-1,1]), minmax) + minmax[0,:] - (np.array(b.shape) -1)
# Define min position (index) of topleft element of b on a
minpos = minmax[0,:]
列出b 在a 上的左上角位置
函数get_rand_topleftpos() 接受minpos 和maxpos 上定义可能槽位置的a 上的行和列,并为@ 返回一个随机选择的有效槽位置 987654341@。我使用size=20 创建了很多有效的随机插槽位置,然后只选择唯一的位置,这样我们就可以将它们视为图像。如果您一次只需要一个插槽位置,请选择size=1。
def get_rand_topleftpos(minpos, maxpos, size=1):
rowpos = np.random.randint(minpos[0], high=maxpos[0] + 1, size=size)
colpos = np.random.randint(minpos[1], high=maxpos[1] + 1, size=size)
pos = np.vstack([rowpos, colpos]).T
return (rowpos, colpos, pos)
# Make a few valid positions where the array b could be placed
rowpos, colpos, pos = get_rand_topleftpos(minpos, maxpos, size=20)
# Select the Unique combinations so we could visualize them only
pos = np.unique(pos, axis=0)
将b 放在a 上并制作数字
我们创建了一个自定义函数fill_a_with_b(),在a 的某个位置用b 填充a。此位置将接受b 的左上角单元格。
def fill_a_with_b(a, b, pos = [0,0]):
aa = a.copy()
aa[slice(pos[0], pos[0] + b.shape[0]),
slice(pos[1], pos[1] + b.shape[1])] = b.copy()
return aa
# Make a few figures with randomly picked position
# for topleft position of b on a
if pos.shape[0]>6:
nrows, ncols = int(np.ceil(pos.shape[0]/6)), 6
else:
nrows, ncols = 1, pos.shape[0]
fig, axs = plt.subplots(nrows = nrows,
ncols = ncols,
figsize=(2.5*ncols,2.5*nrows))
for i, ax in enumerate(axs.flatten()):
if i<pos.shape[0]:
aa = fill_a_with_b(a, b, pos[i,:])
sns.heatmap(aa,
vmin=np.min(aa),
vmax=np.max(aa),
annot=True,
cbar=False,
square=True,
cmap = 'YlGnBu_r',
ax = ax
);
ax.set_title('TopLeftPos: {}'.format(tuple(pos[i,:])),
fontsize=9);
else:
ax.axis('off')
plt.tight_layout()
plt.show()
结果
将数组a定义为:
shape = (5,5)
a = np.zeros(shape)
a[:,-1] = np.arange(shape[0]) + 10
a[-1,:] = np.arange(shape[1]) + 10
将数组a定义为:
shape = (6,5)
a = np.zeros(shape)
a[:,0] = np.arange(shape[0]) + 10
a[:,-1] = np.arange(shape[0]) + 10
a[-1,:] = np.arange(shape[1]) + 10