【问题标题】:using Julia 1.0 findmax equivalent of numpy.argmax使用 Julia 1.0 findmax 相当于 numpy.argmax
【发布时间】:2019-04-24 09:17:30
【问题描述】:

在 Julia 中,我想为每行中的最大值找到矩阵的列索引,结果是 Vector{Int}。这是我目前的做法(Samples 有 7 列和 10,000 行):

mxindices = [ i[2] for i in findmax(Samples, dims = 2)[2]][:,1]

这可行,但感觉相当笨拙和冗长。想知道是否有更好的方法。

【问题讨论】:

  • 对不起,我更新了我的答案大约七千次。我想我现在完成了:-)
  • 如果我的回答有用且完整,请考虑投票并标记已回答的问题(单击我的回答旁边的勾号)。
  • 我可能会选择mapslicesmapslices(row -> findmax(row)[2], Samples, dims=[2])[:,1]

标签: julia


【解决方案1】:

更简单:Julia 有一个 argmax 函数,而 Julia 1.1+ 有一个 eachrow 迭代器。因此:

map(argmax, eachrow(x))

简单、易读、快速——它在我的快速测试中与 Colin 的 f3f4 的性能相匹配。

【讨论】:

  • 或:argmax.(eachrow(Samples))
  • Broadcast 在这里效率不会那么高,因为 eachrow 只是一个迭代器,因此 broadcast 当前在使用它之前将它收集到一个数组中——这是我未来想要改进的东西,但现在它的表现会更差。
  • 另外,如果你被 Julia 1.0 困住,请注意我们在这里使用的 eachrow 定义非常简单!您可以临时将其添加到您的项目中:github.com/JuliaLang/julia/blob/…
【解决方案2】:

更新:为了完整起见,我在测试套件中添加了 Matt B. 的出色解决方案(我还强制 f4 中的 transpose 生成一个新矩阵而不是惰性视图)。

这里有一些不同的方法(你的基本情况是f0):

f0(x) = [ i[2] for i in findmax(x, dims = 2)[2]][:,1]
f1(x) = getindex.(argmax(x, dims=2), 2)
f2(x) = [ argmax(vec(x[n,:])) for n = 1:size(x,1) ]
f3(x) = [ argmax(vec(view(x, n, :))) for n = 1:size(x,1) ]
f4(x) = begin ; xt = Matrix{Float64}(transpose(x)) ; [ argmax(view(xt, :, k)) for k = 1:size(xt,2) ] ; end
f5(x) = map(argmax, eachrow(x))

使用BenchmarkTools我们可以检查每个的效率(我设置了x = rand(100, 200)):

julia> @btime f0($x);
  76.846 μs (13 allocations: 4.64 KiB)

julia> @btime f1($x);
  76.594 μs (11 allocations: 3.75 KiB)

julia> @btime f2($x);
  53.433 μs (103 allocations: 177.48 KiB)

julia> @btime f3($x);
  43.477 μs (3 allocations: 944 bytes)

julia> @btime f4($x);
  73.435 μs (6 allocations: 157.27 KiB)

julia> @btime f5($x);
  43.900 μs (4 allocations: 960 bytes)

所以 Matt 的方法是相当明显的赢家,因为它似乎只是我的 f3 的语法更简洁的版本(两者可能编译成非常相似的东西,但我认为检查它会有点过分)。

我希望f4 可能有优势,尽管通过实例化transpose 创建了临时的,因为它可以对矩阵的列而不是行进行操作(Julia 是一种列主要语言,因此操作on columns 总是会更快,因为元素在内存中是同步的)。但这似乎不足以克服暂时的劣势。

注意,如果你想要完整的CartesianIndex,即每行中最大值的行和列索引,那么显然合适的解决方案就是argmax(x, dims=2)

【讨论】:

  • @LyndonWhite 我在 v1.1 上的 lastindex(::CartesianIndex{2}) 上收到一个方法错误。诚然,我对此感到惊讶。 CartesianIndex 是一个可迭代的权限,所以它应该有一个 lastindex 方法...
  • @LyndonWhite 好吧,我只是尝试迭代 CartesianIndex 并得到:ERROR: iteration is deliberately unsupported for CartesianIndex. Use I rather than I..., or use Tuple(I)...,所以显然当前行为是有充分理由的。我确实喜欢你的last 想法,但如果你也必须在其中添加Tuple,那么我认为它在语法上不会更清晰:-)
  • 迭代是有意的,lastindex 我不认为是。 idx[end] 是一件有意义的事情。您愿意提出问题吗?
  • @LyndonWhite 实际上我确实从 GitHub 问题页面开始,但被默认横幅推迟了:If you have a question or are unsure if the behavior you're experiencing is a bug, please search or post to our Discourse site。我现在打开了一个问题#31815
【解决方案3】:

Mapslices 函数也是解决这个问题的一个很好的选择:

julia> Samples = rand(10000, 7);

julia> res = mapslices(row -> findmax(row)[2], Samples, dims=[2])[:,1];

julia> res[1:10]
10-element Array{Int64,1}:
 3
 1
 3
 5
 4
 4
 1
 4
 5
 3

虽然这比 Colin 上面建议的要慢很多,但对某些人来说可能更具可读性。这基本上与您开始使用的代码完全相同,但使用 mapslices 而不是列表推导。

【讨论】:

    猜你喜欢
    • 2023-02-06
    • 1970-01-01
    • 2018-06-24
    • 1970-01-01
    • 2021-09-14
    • 2019-02-05
    • 2015-05-17
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多