【问题标题】:How does slicing numpy arrays with other arrays work?用其他数组切片 numpy 数组是如何工作的?
【发布时间】:2022-01-19 09:23:41
【问题描述】:

我有一个形状为 [batch_size, timesteps_per_samples, width, height] 的 numpy 数组,其中宽度和高度指的是 2D 网格。此数组中的值可以解释为随时间变化的特定位置的高程。 我想知道这个数组中各种路径的海拔高度。因此,我有第二个形状数组[batch_size, paths_per_batch_sample, timesteps_per_path, coordinates](坐标 = 2,对于 2D 平面中的 x 和 y)。

结果数组的形状应为[batch_size, paths_per_batch_sample, timesteps_per_path],其中包含批次中每个样本随时间变化的高度。 以下两个示例有效。第一个非常慢,只是为了理解我想要做什么。我认为第二个可以满足我的要求,但我不知道为什么会这样,也不知道在某些情况下它是否会崩溃。

问题设置代码:

import numpy as np

batch_size=32
paths_per_batch_sample=10
timesteps_per_path=4
width=64
height=64

elevation = np.arange(0, batch_size*timesteps_per_path*width*height, 1)
elevation = elevation.reshape(batch_size, timesteps_per_path, width, height)

paths = np.random.randint(0, high=width-1, size=(batch_size, paths_per_batch_sample, timesteps_per_path, 2))

range_batch = range(batch_size)
range_paths = range(paths_per_batch_sample)
range_timesteps = range(timesteps_per_path)

以下代码可以运行,但速度很慢:

elevation_per_time = np.zeros((batch_size, paths_per_batch_sample, timesteps_per_path))
for s in range_batch:
        for k in range_paths:
            for t in range_timesteps:
                x_co, y_co = paths[s,k,t,:].astype(int)
                elevation_per_time[s,k,t] = elevation[s,t,x_co,y_co]

以下代码有效(甚至很快),但我不明白为什么以及如何 o.0

elevation_per_time_fast = elevation[
        :,
        range_timesteps,
        paths[:, :, range_timesteps, 0].astype(int),
        paths[:, :, range_timesteps, 1].astype(int),
    ][range_batch, range_batch, :, :]

证明结果相等

check = (elevation_per_time == elevation_per_time_fast)
print(np.all(check))

有人可以解释我如何将一个 nd 数组分割为多个其他数组吗? 特别是,我不明白 numpy 如何知道“range_timesteps”必须逐步运行(对于轴 1、2、3 中的索引)。

提前致谢!

【问题讨论】:

  • 因为它不仅仅是索引,它的切片。结果仍然是一个具有多个维度的数组,因此您之后仍然可以再次对其进行索引。
  • @Eumel 我相应地更新了主题。你能回答最后一个问题吗?

标签: python numpy indexing


【解决方案1】:

让我们先快速看一下切片 numpy 数组:

a = np.arange(0,9,1).reshape([3,3])
array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]])

Numpy 有 2 种切片数组的方法,完整部分 start:stop 和列表中的索引 [index1, index2 ...]。输出仍然是一个具有切片形状的数组:

a[0:2,:]
array([[0, 1, 2],
       [3, 4, 5]])

a[:,[0,2]]
array([[0, 2],
       [3, 5],
       [6, 8]])

第二部分是因为你得到了一个具有相同维度的返回数组,只要你不尝试直接访问数组外的索引,你就可以轻松堆叠任意数量的切片。

a[:][:][:][:][:][:][:][[0,2]][:,[0,2]]
array([[0, 2],
       [6, 8]])

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2019-08-30
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2017-06-24
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多