【发布时间】:2022-01-19 09:23:41
【问题描述】:
我有一个形状为 [batch_size, timesteps_per_samples, width, height] 的 numpy 数组,其中宽度和高度指的是 2D 网格。此数组中的值可以解释为随时间变化的特定位置的高程。
我想知道这个数组中各种路径的海拔高度。因此,我有第二个形状数组[batch_size, paths_per_batch_sample, timesteps_per_path, coordinates](坐标 = 2,对于 2D 平面中的 x 和 y)。
结果数组的形状应为[batch_size, paths_per_batch_sample, timesteps_per_path],其中包含批次中每个样本随时间变化的高度。
以下两个示例有效。第一个非常慢,只是为了理解我想要做什么。我认为第二个可以满足我的要求,但我不知道为什么会这样,也不知道在某些情况下它是否会崩溃。
问题设置代码:
import numpy as np
batch_size=32
paths_per_batch_sample=10
timesteps_per_path=4
width=64
height=64
elevation = np.arange(0, batch_size*timesteps_per_path*width*height, 1)
elevation = elevation.reshape(batch_size, timesteps_per_path, width, height)
paths = np.random.randint(0, high=width-1, size=(batch_size, paths_per_batch_sample, timesteps_per_path, 2))
range_batch = range(batch_size)
range_paths = range(paths_per_batch_sample)
range_timesteps = range(timesteps_per_path)
以下代码可以运行,但速度很慢:
elevation_per_time = np.zeros((batch_size, paths_per_batch_sample, timesteps_per_path))
for s in range_batch:
for k in range_paths:
for t in range_timesteps:
x_co, y_co = paths[s,k,t,:].astype(int)
elevation_per_time[s,k,t] = elevation[s,t,x_co,y_co]
以下代码有效(甚至很快),但我不明白为什么以及如何 o.0
elevation_per_time_fast = elevation[
:,
range_timesteps,
paths[:, :, range_timesteps, 0].astype(int),
paths[:, :, range_timesteps, 1].astype(int),
][range_batch, range_batch, :, :]
证明结果相等
check = (elevation_per_time == elevation_per_time_fast)
print(np.all(check))
有人可以解释我如何将一个 nd 数组分割为多个其他数组吗? 特别是,我不明白 numpy 如何知道“range_timesteps”必须逐步运行(对于轴 1、2、3 中的索引)。
提前致谢!
【问题讨论】:
-
因为它不仅仅是索引,它的切片。结果仍然是一个具有多个维度的数组,因此您之后仍然可以再次对其进行索引。
-
@Eumel 我相应地更新了主题。你能回答最后一个问题吗?