【发布时间】:2014-10-24 00:33:11
【问题描述】:
我正在尝试改进下面这段代码的性能。最终它将使用更大的数组,但我想我会从一些简单的工作开始,然后看看哪里是慢的,优化它然后在全尺寸上尝试。这是原始代码:
#Minimum example with random variables
import numpy as np
import matplotlib.pyplot as plt
n=4
# Theoretical Travel Time to each station
ttable=np.array([1,2,3,4])
# Seismic traces,measured at each station
traces=np.random.random((n, 506))
dt=0.1
# Forward Problem add energy to each trace at the deserired time from a given origin time
given_origin_time=1
for i in range(n):
# Energy will arrive at the sample equivelant to origin time + travel time
arrival_sample=int(round((given_origin_time+ttable[i])/dt))
traces[i,arrival_sample]=2
# The aim is to find the origin time by trying each possible origin time and adding the energy up.
# Where this "Stack" is highest is likely to be the origin time
# Find the maximum travel time
tmax=ttable.max()
# We pad the traces to avoid when we shift by a travel time that the trace has no value
traces=np.lib.pad(traces,((0,0),(round(tmax/dt),round(tmax/dt))),'constant',constant_values=0)
#Available origin times to search for relative to the beginning of the trace
origin_times=np.linspace(-tmax,len(traces),len(traces)+round(tmax/dt))
# Create an empty array to fill with our stack
S=np.empty((origin_times.shape[0]))
# Loop over all the potential origin times
for l,otime in enumerate(origin_times):
# Create some variables which we will sum up over all stations
sum_point=0
sqrr_sum_point=0
# Loop over each station
for m in range(n):
# Find the appropriate travel time
ttime=ttable[m]
# Grap the point on the trace that corresponds to this travel time + the origin time we are searching for
point=traces[m,int(round((tmax+otime+ttime)/dt))]
# Sum up the points
sum_point+=point
# Sum of the square of the points
sqrr_sum_point+=point**2
# Create the stack by taking the square of the sums dived by sum of the squares normalised by the number of stations
S[l]=sum_point#**2/(n*sqrr_sum_point)
# Plot the output the peak should be at given_origin_time
plt.plot(origin_times,S)
plt.show()
我认为我不理解多维数组的广播和索引的问题。在此之后,我将扩展维度以搜索 x、y、z,这将通过增加维度 ttable 来给出。我可能会尝试实现 pytables 或 np.memmap 来帮助处理大型数组。
【问题讨论】:
-
你没有告诉我们
ttable所以我们不能轻易地帮助你。另外,请尝试将示例代码缩减到最低限度,使用最少的维度和变量来演示问题。 -
我已经编辑了问题以定义一些示例数组,希望对您有所帮助。如果您还有其他需要,请告诉我。
-
你的例子还远远不够。我打赌你可以删掉一半的代码,并且仍然有一个有用的程序来解决你的循环问题。例如,您有
S2,这似乎不会为问题添加任何内容,并且您有一个完整的额外维度,可以在不损害问题的情况下删除。 -
直接的问题是
origin_times_int+ttime不起作用,因为一个有 250 个元素,另一个有四个。我不确定您要使用该代码做什么。显然,在我们尝试解决这个怪物之前,我们需要遍历更小的代码 sn-ps。 :) -
我同意@JohnZwinck。我发誓我很想了解您的代码在做什么。这似乎是某种卷积和/或统计。但是变量名没有多大帮助。你认为你可以改进它们吗?
标签: python arrays optimization numpy multidimensional-array