【问题标题】:Vectorizing for loops in python with numpy multidimensional arrays使用 numpy 多维数组对 python 中的循环进行矢量化
【发布时间】:2014-10-24 00:33:11
【问题描述】:

我正在尝试改进下面这段代码的性能。最终它将使用更大的数组,但我想我会从一些简单的工作开始,然后看看哪里是慢的,优化它然后在全尺寸上尝试。这是原始代码:

#Minimum example with random variables
import numpy as np
import matplotlib.pyplot as plt

n=4
# Theoretical Travel Time to each station
ttable=np.array([1,2,3,4])
# Seismic traces,measured at each station
traces=np.random.random((n, 506))
dt=0.1
# Forward Problem add energy to each trace at the deserired time from a given origin time
given_origin_time=1
for i in range(n):
    # Energy will arrive at the sample equivelant to origin time + travel time
    arrival_sample=int(round((given_origin_time+ttable[i])/dt))
    traces[i,arrival_sample]=2

# The aim is to find the origin time by trying each possible origin time and adding the energy up. 
# Where this "Stack" is highest is likely to be the origin time

# Find the maximum travel time
tmax=ttable.max()


# We pad the traces to avoid when we shift by a travel time that the trace has no value
traces=np.lib.pad(traces,((0,0),(round(tmax/dt),round(tmax/dt))),'constant',constant_values=0)

#Available origin times to search for relative to the beginning of the trace
origin_times=np.linspace(-tmax,len(traces),len(traces)+round(tmax/dt))

# Create an empty array to fill with our stack
S=np.empty((origin_times.shape[0]))

# Loop over all the potential origin times
for l,otime in enumerate(origin_times):
    # Create some variables which we will sum up over all stations
    sum_point=0
    sqrr_sum_point=0
    # Loop over each station
    for m in range(n):
        # Find the appropriate travel time
        ttime=ttable[m] 
        # Grap the point on the trace that corresponds to this travel time + the origin time we are searching for 
        point=traces[m,int(round((tmax+otime+ttime)/dt))]
        # Sum up the points
        sum_point+=point
        # Sum of the square of the points
        sqrr_sum_point+=point**2
    # Create the stack by taking the square of the sums dived by sum of the squares normalised by the number of stations
    S[l]=sum_point#**2/(n*sqrr_sum_point)

# Plot the output the peak should be at given_origin_time
plt.plot(origin_times,S)
plt.show()

我认为我不理解多维数组的广播和索引的问题。在此之后,我将扩展维度以搜索 x、y、z,这将通过增加维度 ttable 来给出。我可能会尝试实现 pytables 或 np.memmap 来帮助处理大型数组。

【问题讨论】:

  • 你没有告诉我们ttable所以我们不能轻易地帮助你。另外,请尝试将示例代码缩减到最低限度,使用最少的维度和变量来演示问题。
  • 我已经编辑了问题以定义一些示例数组,希望对您有所帮助。如果您还有其他需要,请告诉我。
  • 你的例子还远远不够。我打赌你可以删掉一半的代码,并且仍然有一个有用的程序来解决你的循环问题。例如,您有S2,这似乎不会为问题添加任何内容,并且您有一个完整的额外维度,可以在不损害问题的情况下删除。
  • 直接的问题是origin_times_int+ttime 不起作用,因为一个有 250 个元素,另一个有四个。我不确定您要使用该代码做什么。显然,在我们尝试解决这个怪物之前,我们需要遍历更小的代码 sn-ps。 :)
  • 我同意@JohnZwinck。我发誓我很想了解您的代码在做什么。这似乎是某种卷积和/或统计。但是变量名没有多大帮助。你认为你可以改进它们吗?

标签: python arrays optimization numpy multidimensional-array


【解决方案1】:

通过一些快速分析,似乎该行

point=traces[m,int(round((tmax+otime+ttime)/dt))]

占用了整个程序运行时间的约 40%。让我们看看我们是否可以将它矢量化一下:

    ttime_inds = np.around((tmax + otime + ttable) / dt).astype(int)
    # Loop over each station
    for m in range(n):
        # Grap the point on the trace that corresponds to this travel time + the origin time we are searching for 
        point=traces[m,ttime_inds[m]]

我们注意到循环中唯一改变的东西(m 除外)是 ttime,因此我们将其取出并使用 numpy 函数对该部分进行矢量化。

那是最大的热点,但我们可以更进一步,完全移除内部循环:

# Loop over all the potential origin times
for l,otime in enumerate(origin_times):
    ttime_inds = np.around((tmax + otime + ttable) / dt).astype(int)
    points = traces[np.arange(n),ttime_inds]
    sum_point = points.sum()
    sqrr_sum_point = (points**2).sum()
    # Create the stack by taking the square of the sums dived by sum of the squares normalised by the number of stations
    S[l]=sum_point#**2/(n*sqrr_sum_point)

编辑:如果你也想去掉外循环,我们需要把otime拉出来:

ttime_inds = np.around((tmax + origin_times[:,None] + ttable) / dt).astype(int)

然后,我们像以前一样继续,在第二个轴上求和:

points = traces[np.arange(n),ttime_inds]
sum_points = points.sum(axis=1)
sqrr_sum_points = (points**2).sum(axis=1)
S = sum_points # **2/(n*sqrr_sum_points)

【讨论】:

  • 太好了,有没有办法去掉下一个循环。因为计划是让这个运行在相当大的数组上,并做更多维度的相同事情。
  • 当然,我已经更新了我的答案。这将使事情变得更快,但请记住,它也会使用更多内存。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 2012-07-01
  • 1970-01-01
  • 1970-01-01
  • 2013-04-15
  • 1970-01-01
  • 2023-03-10
  • 2020-01-11
相关资源
最近更新 更多