【问题标题】:Fast 1D linear np.NaN interpolation over large 3D array大型 3D 阵列上的快速 1D 线性 np.NaN 插值
【发布时间】:2015-09-03 19:57:02
【问题描述】:

我有一个 (z, y, x)shape=(92, 4800, 4800) 的 3D 数组,其中沿 axis 0 的每个值代表不同的时间点。在某些情况下,时域中的值获取失败,导致某些值成为np.NaN。在其他情况下,没有获取任何值,并且z 上的所有值都是np.NaN

无论所有值都是np.NaN 的情况,使用线性插值沿axis 0 填充np.NaN 的最有效方法是什么?

这是我正在做的一个工作示例,它使用pandas 包装器到scipy.interpolate.interp1d。在原始数据集上,每个切片大约需要 2 秒,这意味着整个数组在 2.6 小时内处理完毕。减小大小的示例数据集大约需要 9.5 秒。

import numpy as np
import pandas as pd

# create example data, original is (92, 4800, 4800)
test_arr = np.random.randint(low=-10000, high=10000, size=(92, 480, 480))
test_arr[1:90:7, :, :] = -32768  # NaN fill value in original data
test_arr[:, 1:90:6, 1:90:8] = -32768

def interpolate_nan(arr, method="linear", limit=3):
    """return array interpolated along time-axis to fill missing values"""
    result = np.zeros_like(arr, dtype=np.int16)

    for i in range(arr.shape[1]):
        # slice along y axis, interpolate with pandas wrapper to interp1d
        line_stack = pd.DataFrame(data=arr[:,i,:], dtype=np.float32)
        line_stack.replace(to_replace=-37268, value=np.NaN, inplace=True)
        line_stack.interpolate(method=method, axis=0, inplace=True, limit=limit)
        line_stack.replace(to_replace=np.NaN, value=-37268, inplace=True)
        result[:, i, :] = line_stack.values.astype(np.int16)
    return result

使用示例数据集在我的机器上的性能:

%timeit interpolate_nan(test_arr)
1 loops, best of 3: 9.51 s per loop

编辑:

我应该澄清一下代码正在产生我的预期结果。问题是 - 我怎样才能优化这个过程?

【问题讨论】:

  • 在我的机器上运行示例大约需要 9.5 秒,但 test_arr 的形状是 (92, 480, 480)。如果将其增加到真实数据集 (92, 4800, 4800) 的大小并使用更多 NaN 传播它,则此方法需要更长的时间。

标签: python numpy pandas scipy interpolation


【解决方案1】:

这取决于;如果您插值并仅将这些 NaN 填零,您将不得不拿出一张纸并计算您的整体统计数据会出现的错误。

除此之外,我认为您的插值超出了顶部。 只需找到每个 NaN,并线性插值到相邻的四个值(即,将 (y +- 1,x +- 1) 处的值相加) - 这将严重限制您的错误(自己计算!),并且您没有使用在您的案例中使用的任何复杂方法进行插值(您没有定义 method)。

您可以尝试只预先计算每个 z 值的“平均”4800x4800 矩阵 - 这应该不会花很长时间 - 通过在矩阵上应用一个十字形内核(这一切都非常类似于图像处理, 这里)。在 NaN 的情况下,一些平均值将是 NaN(NaN 位于附近的每个平均像素),但您不在乎 - 除非有两个相邻的 NaN,否则您要替换的 NaN 单元格原始矩阵都是实值的。

然后,您只需将所有 NaN 替换为平均矩阵中的值。

将其速度与“手动”计算您找到的每个 NaN 的邻域平均值的速度进行比较。

【讨论】:

  • 插值实际上只是根据时域中的信号进行分类的前兆,因此不能选择 0 填充。线性插值是我目前使用method="linear" 所做的。如果线性插值失败,将每个 NaN 替换为其 z 轴的平均值是我的后备选项。
  • 嗯,周围两个z值的平均值正好线性插值。
  • 抱歉,我指的是沿 z 轴的所有 92 个值的平均值。否则我陷入了最初的问题:沿 z 轴对缺失值进行插值的最快方法是什么。
  • 我不太明白——我在之前的评论中概述的线性插值应该相当快,而且是有效的插值
  • 如果我正确理解了您的答案,您建议使用 2D 线性插值。我要进行一维插值(对于从 z+-1 进行的每个 NaN 插值)。这也是 pandas 包装器到 scipy.signal.interp1d 所做的。我的回答更多的是代码优化,而不是插值选择。除非我理解你的答案错误并且它更有效 - 在这种情况下:愿意用代码示例来解释它吗?
【解决方案2】:

我最近在 numba 的帮助下为我的特定用例解决了这个问题,并且还解决了a little writeup on it

from numba import jit

@jit(nopython=True)
def interpolate_numba(arr, no_data=-32768):
    """return array interpolated along time-axis to fill missing values"""
    result = np.zeros_like(arr, dtype=np.int16)

    for x in range(arr.shape[2]):
        # slice along x axis
        for y in range(arr.shape[1]):
            # slice along y axis
            for z in range(arr.shape[0]):
                value = arr[z,y,x]
                if z == 0:  # don't interpolate first value
                    new_value = value
                elif z == len(arr[:,0,0])-1:  # don't interpolate last value
                    new_value = value

                elif value == no_data:  # interpolate

                    left = arr[z-1,y,x]
                    right = arr[z+1,y,x]
                    # look for valid neighbours
                    if left != no_data and right != no_data:  # left and right are valid
                        new_value = (left + right) / 2

                    elif left == no_data and z == 1:  # boundary condition left
                        new_value = value
                    elif right == no_data and z == len(arr[:,0,0])-2:  # boundary condition right
                        new_value = value

                    elif left == no_data and right != no_data:  # take second neighbour to the left
                        more_left = arr[z-2,y,x]
                        if more_left == no_data:
                            new_value = value
                        else:
                            new_value = (more_left + right) / 2

                    elif left != no_data and right == no_data:  # take second neighbour to the right
                        more_right = arr[z+2,y,x]
                        if more_right == no_data:
                            new_value = value
                        else:
                            new_value = (more_right + left) / 2

                    elif left == no_data and right == no_data:  # take second neighbour on both sides
                        more_left = arr[z-2,y,x]
                        more_right = arr[z+2,y,x]
                        if more_left != no_data and more_right != no_data:
                            new_value = (more_left + more_right) / 2
                        else:
                            new_value = value
                    else:
                        new_value = value
                else:
                    new_value = value
                result[z,y,x] = int(new_value)
    return result

这比我的初始代码快 20 倍

【讨论】:

  • 看看 cython 比较会很有趣。
  • 如果有人能够/愿意在 Cython 中编写该函数,我也很想看看速度对比。
  • 感谢您的提问和回答。我对线性插值的逻辑有一些新的想法。我在下面发布了一个新答案,感谢您的关注!
【解决方案3】:

提问者利用numba 给出了很好的答案。我真的很感激,但我不能完全同意interpolate_numba 函数中的内容。我不认为在特定点上进行线性插值的逻辑是找到其左右邻居的平均值。为了说明,假设我们有一个数组[1,nan,nan,4,nan,6],上面的interpolate_numba函数可能会返回[1,2.5,2.5,4,5,6](仅理论推导) ,而pandas wrapper 肯定会返回 [1,2,3,4,5,6]。相反,我认为对特定点进行线性插值的逻辑是找到它的左右邻居,用它们的值来确定一条线(即斜率和截距),最后计算插值。下面显示了我的代码。为了简单起见,我假设输入数据是一个包含 nan 值的 3-D 数组。我规定第一个和最后一个元素等效于它们的左右最近邻居(即pandas 中的limit_direction='both')。我没有指定连续插值的最大数量(即pandas 中没有limit)。

import numpy as np
from numba import jit
@jit(nopython=True)
def f(arr_3d):
    result=np.zeros_like(arr_3d)
    for i in range(arr_3d.shape[1]):
        for j in range(arr_3d.shape[2]):
            arr=arr_3d[:,i,j]
            # If all elements are nan then cannot conduct linear interpolation.
            if np.sum(np.isnan(arr))==arr.shape[0]:
                result[:,i,j]=arr
            else:
                # If the first elemet is nan, then assign the value of its right nearest neighbor to it.
                if np.isnan(arr[0]):
                    arr[0]=arr[~np.isnan(arr)][0]
                # If the last element is nan, then assign the value of its left nearest neighbor to it.
                if np.isnan(arr[-1]):
                    arr[-1]=arr[~np.isnan(arr)][-1]
                # If the element is in the middle and its value is nan, do linear interpolation using neighbor values.
                for k in range(arr.shape[0]):
                    if np.isnan(arr[k]):
                        x=k
                        x1=x-1
                        x2=x+1
                        # Find left neighbor whose value is not nan.
                        while x1>=0:
                            if np.isnan(arr[x1]):
                                x1=x1-1
                            else:
                                y1=arr[x1]
                                break
                        # Find right neighbor whose value is not nan.
                        while x2<arr.shape[0]:
                            if np.isnan(arr[x2]):
                                x2=x2+1
                            else:
                                y2=arr[x2]
                                break
                        # Calculate the slope and intercept determined by the left and right neighbors.
                        slope=(y2-y1)/(x2-x1)
                        intercept=y1-slope*x1
                        # Linear interpolation and assignment.
                        y=slope*x+intercept
                        arr[x]=y
                result[:,i,j]=arr
    return result

初始化一个包含一些 nan 的 3-D 数组,我检查了我的代码,它可以给出与 pandas 包装器相同的答案。通过pandas 包装代码会有点混乱,因为pandas 只能处理二维数据。

使用我的代码

y1=np.ones((2,2))
y2=y1+1
y3=y2+np.nan
y4=y2+2
y5=y1+np.nan
y6=y4+2
y1[1,1]=np.nan
y2[0,0]=np.nan
y4[1,1]=np.nan
y6[1,1]=np.nan
y=np.stack((y1,y2,y3,y4,y5,y6),axis=0)
print(y)
print("="*10)
f(y)

使用熊猫包装器

import pandas as pd
y1=np.ones((2,2)).flatten()
y2=y1+1
y3=y2+np.nan
y4=y2+2
y5=y1+np.nan
y6=y4+2
y1[3]=np.nan
y2[0]=np.nan
y4[3]=np.nan
y6[3]=np.nan
y=pd.DataFrame(np.stack([y1,y2,y3,y4,y5,y6],axis=0))
y=y.interpolate(method='linear', limit_direction='both', axis=0)
y_numpy=y.to_numpy()
y_numpy.shape=((6,2,2))
print(np.stack([y1,y2,y3,y4,y5,y6],axis=0).reshape(6,2,2))
print("="*10)
print(y_numpy)

输出将是相同的

[[[ 1.  1.]
  [ 1. nan]]

 [[nan  2.]
  [ 2.  2.]]

 [[nan nan]
  [nan nan]]

 [[ 4.  4.]
  [ 4. nan]]

 [[nan nan]
  [nan nan]]

 [[ 6.  6.]
  [ 6. nan]]]
==========
[[[1. 1.]
  [1. 2.]]

 [[2. 2.]
  [2. 2.]]

 [[3. 3.]
  [3. 2.]]

 [[4. 4.]
  [4. 2.]]

 [[5. 5.]
  [5. 2.]]

 [[6. 6.]
  [6. 2.]]]

使用 test_arr 将其大小增加到 (92,4800,4800) 的数据作为输入,我发现完成插值只需要大约 40 秒!

test_arr = np.random.randint(low=-10000, high=10000, size=(92, 4800, 4800))
test_arr[1:90:7, :, :] = np.nan  # NaN fill value in original data
test_arr[2,:,:] = np.nan
test_arr[:, 1:479:6, 1:479:8] = np.nan
%time f(test_arr)

输出

CPU times: user 32.5 s, sys: 9.13 s, total: 41.6 s
Wall time: 41.6 s

【讨论】:

  • 很好的答案!我最初的问题在给定的一系列数字中只包含 1 个缺失的观察结果。但是,您的解决方案是一个不错的解决方案,可以支持更多应用程序。
  • 有道理!
猜你喜欢
  • 2019-04-12
  • 2023-03-05
  • 1970-01-01
  • 2020-01-30
  • 1970-01-01
  • 2017-04-21
  • 2013-12-14
  • 1970-01-01
  • 2014-03-03
相关资源
最近更新 更多