【发布时间】:2021-07-05 09:30:03
【问题描述】:
我有这个用于创建掩码(布尔数组)的函数,我希望更快。
def get_validity_1(ts, times):
validity = numpy.zeros(len(ts))
indexes = []
for start, end in times:
index_start = numpy.argmax(ts >= start)
index_end = numpy.argmax(ts >= end)
indexes.append([index_start, index_end])
for start, end in indexes:
validity[start:end] = 1
return validity
res_1 = get_validity_1(numpy.linspace(0, 1, 100000000), numpy.array([[0.01, 0.1], [0.5, 0.8]]))
这个问题的问题是如何使用 numpy.where 条件来制作它。我试过这个:
def get_validity_2(ts, times):
return numpy.where(numpy.logical_or([t1<ts.all()<t2 for t1, t2 in times]))
但是python会提高:
ValueError: invalid number of arguments
这里有一些输入断言:
- ts[n-1]
- 次[n][0]
- 次[n-1][1]
这是一个脚本作为输入:
import time, numpy
def get_validity_1(ts, times):
validity = numpy.zeros(len(ts))
indexes = []
for start, end in times:
index_start = numpy.argmax(ts >= start)
index_end = numpy.argmax(ts >= end)
indexes.append([index_start, index_end])
for start, end in indexes:
validity[start:end] = 1
return validity
def get_validity_2(ts, times):
return numpy.where(numpy.logical_or([t1<ts.all()<t2 for t1, t2 in times]))
if __name__ == "__main__":
n = 100000000
ts = numpy.linspace(0, 1, n)
times = numpy.array([[0.01, 0.1], [0.5, 0.8]])
t0 = time.time()
res_1 = get_validity_1(ts, times)
t_1 = time.time() - t0
t0 = time.time()
res_2 = get_validity_2(ts, times)
t_2 = time.time() - t0
print("t_1: " + str(t_1))
print("t_2: " + str(t_2))
assert res_1 == res_2
assert t_1 > t_2
有谁知道如何完成函数“get_validity_2”并传递断言? 或者只是一个包的功能来解决这个问题?
【问题讨论】:
-
试试
numpy.where(numpy.logical_or(*[t1<ts.all()<t2 for t1, t2 in times]))。您的minimal reproducible example 应包含 ts 和时间的示例。似乎除了这些变量之外的所有内容,get_validity_2与问题无关。 -
numpy.where(numpy.logical_or(*[t1<ts.all()<t2 for t1, t2 in times]))return(array([], dtype=int64),)在脚本末尾有一个带有 ts 和 times 的示例 -
啊。我以为你在尝试解决
ValueError。 -
np.where只是在其参数中找到真值。这是合乎逻辑的,或者正在制作布尔数组。 -
我尝试获得相同的结果并且更快。我知道这可以使用 numpy.where
标签: python numpy optimization mask