【问题标题】:Python/Numpy - Vectorized implementation of this for loop?Python/Numpy - 这个 for 循环的矢量化实现?
【发布时间】:2021-02-02 16:11:07
【问题描述】:

这是一个基于跨卫星图像时间通道插值的云遮罩的昏昏欲睡实现。图像数组的形状为 (n_samples x n_months x 宽度 x 高度 x 通道)。通道不仅是 RGB,还来自不可见光谱,例如 SWIR、NIR 等。其中一个通道(或波段,在卫星图像世界中)是云遮罩,它告诉我 0 表示“没有云”,而 1024或 2048 表示该像素中的“云”。

我正在使用这个以像素为单位的云遮罩通道,通过在上个月/下个月之间的插值来更改所有剩余通道上的值。这个实现非常慢,我很难想出矢量化实现。

  1. 是否可以矢量化此实现?它是什么?
  2. 关于如何推断复杂数组操作的矢量化实现逻辑的任何建议?换句话说,我如何学习矢量化的艺术?

我是新手,请原谅我的无知。

n_samples = 1055
n_months = 12
width = 40
height = 40
channels = 13 # channel 13 is the cloud mask, based on which the first 12 channel pixels are interpolated)

# This function fills nan values in a list by interpolation
def fill_nan(y):
    nans = np.isnan(y)
    x = lambda z: z.nonzero()[0]
    y[nans]= np.interp(x(nans), x(~nans), y[~nans])
    return y

#for loop to first fill cloudy pixels with nan
for sample in range(1055):
    for temp in range(12):
        for w in range(40):
            for h in range(40):
                if Xtest[sample,temp,w,h,13] > 0:
                    Xtest[sample,temp,w,h,:12] = np.nan

#for loop to fill nan with interpolated values
for sample in range(1055):
    for w in range(40):
        for h in range(40):
            for ch in range(12):
                Xtest[sample,: , w, h, ch] = fill_nan(Xtest[sample,: , w, h, ch])

【问题讨论】:

  • 我不确定您在第二个循环和函数中要做什么,它会为随机 numpy 数组引发错误。您可能还需要共享您想要解决的示例数组(较小的数组,而不是完整大小的数组)。
  • 实际上 fill_nan 函数在 IMO 中没有意义。 np.interp() 要求在第一个和第二个参数中传递相同大小的数组。 x(nans)x(~nans) 很多时候不会出现这种情况。
  • 阿南达,我正在尝试用在后续和先前时间值中的相同像素之间插值的值填充 nan。例如,nan 是 4 月通道 5 中特定像素的值。我想从 3 月和 5 月的同一通道 5 中插入同一像素,以填充 4 月。有关此问题的简化版本,请参阅此帖子 stackoverflow.com/questions/6518811/…
  • 我的意思是我不认为你的实现是正确的。 np.interp(x(nans), x(~nans), y[~nans]) 大部分时间都会抛出错误。唯一有效的情况是y 具有相同数量的 nan 和非 nan(因为该函数在第一个和第二个 arg 中需要相同长度的 args)

标签: python numpy


【解决方案1】:

对于第一个循环,

import numpy as np

Xtest = np.random.rand(10, 3, 2, 4, 14)
Xtest_v = Xtest.copy()

for sample in range(10):
    for temp in range(3):
        for w in range(2):
            for h in range(4):
                if Xtest[sample,temp,w,h,13] > 0:
                    Xtest[sample,temp,w,h,:12] = np.nan

Xtest_v[..., :12][Xtest_v[..., 13]>0] = np.nan

print(np.nansum(Xtest))
print(np.nansum(Xtest_v))

您可以通过打印忽略 nans 的总和来验证两个数组是否相同。

【讨论】:

  • 非常感谢。它就像一个魅力。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2011-02-09
  • 2016-05-26
  • 2021-04-06
  • 2012-11-03
  • 2018-10-04
  • 1970-01-01
相关资源
最近更新 更多