【发布时间】:2021-11-25 08:44:39
【问题描述】:
我有一个形状为 (..., H,W) 的高维张量 a,即最后两个暗点是内核的高度和宽度:
import numpy as np
from numpy.lib.stride_tricks import as_strided
original=np.array([[[7,9,19,18],[10,11,20,16]],[[24,5,18,11],[6,10,45,12]]],dtype=np.float64)
a=as_strided(original, shape=(2,1,2,2,2),strides=(64,32*2,8*2,32,8),writeable=True)
>>> print(a)
[[[[[ 7. 9.]
[10. 11.]]
[[19. 18.]
[20. 16.]]]]
[[[[24. 5.]
[ 6. 10.]]
[[18. 11.]
[45. 12.]]]]]
注意a 不是连续的,它是original 的视图。我愿意:
- 找到每个 2x2 内核中最大值的索引,
- 然后通过索引获得这些最大值,
- 最后将最大值更改为 1000
要求
要求是,没有for循环,使用最少的max或argmax,值的变化也要反映在original张量中,即original中对应的值也要改为1000。
我的尝试
(在我从之前的问题中得到的帮助下)我尝试将最后两个暗淡变平,然后使用单个 argmax 来获得最大索引:
flatten=a.reshape(2,1,2,2*2)
multi_inds=flatten.argmax(-1)
i,j,k=np.indices(flatten.shape[:-1])
flatten[i, j, k,multi_inds]=1000
但是,由于a 不是连续的,flatten 不再与a 共享相同的数据,因为reshape,所以a 或original 中的最大值没有改变,您可以验证通过使用np.info。 a 和 original 具有相同的数据指针,但 flatten 不同:
>>> print(np.info(original))
...
contiguous: True
data pointer: 0x559d75071230
>>> print(np.info(a))
...
contiguous: False
data pointer: 0x559d75071230
>>> print(np.info(flatten))
...
contiguous: True
data pointer: 0x559d75280900
【问题讨论】: