In Julia I want to find the column index of the maximum value in each row of a matrix, with the result being a Vector{Int}. Here is how I am doing it currently (Samples has 7 columns and 10,000 rows):
mxindices = [ i[2] for i in findmax(Samples, dims = 2)[2]][:,1]
This works but feels rather clumsy and verbose. I wondered whether there was a better way.
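It can help to see why the one-liner above needs both the [2] and the [:,1]. The following minimal sketch uses a small made-up matrix in place of Samples to show what findmax(x, dims = 2) returns: a pair of n×1 matrices, the second holding CartesianIndex values. Using vec instead of [:,1] to flatten is an equivalent choice made here for illustration.

```julia
# A small made-up matrix standing in for Samples (hypothetical data).
x = [1 5 3;
     9 2 4]

# With dims = 2 both results keep the reduced dimension, so each is 2×1.
vals, idx = findmax(x, dims = 2)

size(vals)   # (2, 1): the row maxima as a 2×1 Matrix
size(idx)    # (2, 1): a 2×1 Matrix of CartesianIndex{2}
idx[1]       # CartesianIndex(1, 2): row 1, column 2

# i[2] is the column component of each CartesianIndex; the comprehension
# keeps the 2×1 shape, and vec flattens it to a Vector{Int}.
cols = vec([i[2] for i in idx])
```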
Even simpler: Julia has an argmax function, and Julia 1.1+ has an eachrow iterator. Thus:
map(argmax, eachrow(x))
Simple, readable, and fast: it matches the performance of Colin's f3 and f4 in my quick tests.
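A minimal usage sketch of the map/eachrow approach on a small made-up matrix (requires Julia 1.1 or later for eachrow). The third row contains a tie, which shows that argmax returns the first maximal index.

```julia
x = [1 5 3;
     9 2 4;
     4 4 1]   # row 3 has a tie between columns 1 and 2

# eachrow(x) yields each row as a view; argmax on a vector returns the
# index (here: the column) of its largest element, the first one on ties.
mxindices = map(argmax, eachrow(x))
```

The result is a plain Vector{Int}, so no reshaping step is needed.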
UPDATE: For the sake of completeness, I've added Matt B.'s excellent solution to the test suite (and I also forced the transpose in f4 to generate a new matrix rather than a lazy view).
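The difference between a lazy transpose and an instantiated one can be seen directly. This sketch uses a tiny hypothetical 2×2 matrix: transpose returns a Transpose wrapper that shares memory with the original, while Matrix{Float64}(transpose(x)), the form used in f4, copies the data into a new matrix (permutedims(x) is another way to get such a copy).

```julia
using LinearAlgebra   # the Transpose wrapper type lives in the LinearAlgebra stdlib

x  = [1.0 2.0;
      3.0 4.0]
t  = transpose(x)                    # lazy: a Transpose wrapper over x's memory
xt = Matrix{Float64}(transpose(x))   # eager: a fresh Matrix, as in f4

x[1, 2] = 99.0   # mutate the original...
t[2, 1]          # ...and the lazy wrapper sees the change: 99.0
xt[2, 1]         # the instantiated copy still holds 2.0
```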
Here are some different approaches (yours is the base case, f0):
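f0(x) = [ i[2] for i in findmax(x, dims = 2)[2]][:,1]       # column part of each row's CartesianIndex
f1(x) = getindex.(argmax(x, dims=2), 2)                      # note: returns an n×1 Matrix, not a Vector
f2(x) = [ argmax(vec(x[n,:])) for n = 1:size(x,1) ]          # x[n,:] copies each row
f3(x) = [ argmax(vec(view(x, n, :))) for n = 1:size(x,1) ]   # view avoids the copy
function f4(x)
    xt = Matrix{Float64}(transpose(x))                # new Float64 matrix (eltype must convert to Float64)
    [ argmax(view(xt, :, k)) for k = 1:size(xt,2) ]   # rows of x are now contiguous columns of xt
end
f5(x) = map(argmax, eachrow(x))                              # Matt B.'s answer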
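Before comparing speed, it is worth checking that the approaches actually agree. A minimal sketch, assuming a random matrix of the same shape as the benchmark one: because f1 returns an n×1 Matrix, every result is flattened with vec before comparison against f0.

```julia
using LinearAlgebra   # Transpose (used by f4) is defined in the LinearAlgebra stdlib

f0(x) = [ i[2] for i in findmax(x, dims = 2)[2]][:,1]
f1(x) = getindex.(argmax(x, dims=2), 2)
f2(x) = [ argmax(vec(x[n,:])) for n = 1:size(x,1) ]
f3(x) = [ argmax(vec(view(x, n, :))) for n = 1:size(x,1) ]
function f4(x)
    xt = Matrix{Float64}(transpose(x))
    [ argmax(view(xt, :, k)) for k = 1:size(xt,2) ]
end
f5(x) = map(argmax, eachrow(x))

x = rand(100, 200)
reference = f0(x)

# Flatten each result so f1's n×1 Matrix compares equal to the Vectors.
all_agree = all(vec(f(x)) == reference for f in (f1, f2, f3, f4, f5))
```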
Using BenchmarkTools, we can examine the efficiency of each (I've set x = rand(100, 200)):
julia> @btime f0($x);
76.846 μs (13 allocations: 4.64 KiB)
julia> @btime f1($x);
76.594 μs (11 allocations: 3.75 KiB)
julia> @btime f2($x);
53.433 μs (103 allocations: 177.48 KiB)
julia> @btime f3($x);
43.477 μs (3 allocations: 944 bytes)
julia> @btime f4($x);
73.435 μs (6 allocations: 157.27 KiB)
julia> @btime f5($x);
43.900 μs (4 allocations: 960 bytes)
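The allocation column above largely reflects whether each row is copied: x[n,:] in f2 allocates a new vector per row, while view in f3 does not. A rough sketch that reproduces the gap using only Base's @allocated macro (alloc_bytes is a hypothetical helper introduced here; exact byte counts will vary by Julia version):

```julia
f2(x) = [ argmax(vec(x[n,:])) for n = 1:size(x,1) ]          # copies each row
f3(x) = [ argmax(vec(view(x, n, :))) for n = 1:size(x,1) ]   # views each row

# Hypothetical helper: bytes allocated by one call to f(x), measured after
# a warm-up call so that compilation is not counted.
function alloc_bytes(f, x)
    f(x)
    return @allocated f(x)
end

x = rand(100, 200)
copy_bytes = alloc_bytes(f2, x)   # includes 100 row copies of 200 Float64s each
view_bytes = alloc_bytes(f3, x)   # essentially just the output vector
```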
So Matt's approach is the fairly obvious winner: it is essentially tied with my f3 on time and allocations (43.9 μs vs 43.5 μs here), and it appears to be just a syntactically cleaner version of f3 (the two probably compile to something very similar, but I think it would be overkill to check that).
I was hoping f4 might have an edge, despite the temporary created by instantiating the transpose, since it can operate on the columns of a matrix rather than the rows (Julia arrays are column-major, so iterating down a column tends to be faster because its elements are contiguous in memory). But that advantage doesn't appear to be enough to overcome the cost of the temporary.
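The column-major point can be checked with strides, which reports how far apart (in elements) consecutive entries are along each dimension. A minimal sketch on a matrix of the benchmark's shape; permutedims is used here as an eager transpose equivalent to the copy made in f4.

```julia
x = rand(100, 200)   # same shape as the benchmark matrix

# Walking down a column steps one element at a time; walking along a row
# jumps over a whole column (100 elements) each time.
strides(x)                # (1, 100)
strides(view(x, :, 1))    # (1,):   a column view is contiguous
strides(view(x, 1, :))    # (100,): a row view is strided

# Copying the transpose turns each row of x into a contiguous column of xt,
# at the cost of copying the whole matrix.
xt = permutedims(x)
strides(view(xt, :, 1))   # (1,)
```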
Note: if you ever want the full CartesianIndex, that is, both the row and the column index of the maximum in each row, then the appropriate solution is simply argmax(x, dims=2) (which returns an n×1 matrix; wrap it in vec if you need a vector).
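A short sketch of working with that result, on a small made-up matrix: the CartesianIndex values can index x directly to pull out the row maxima, or be split into their row and column components.

```julia
x = [1 5 3;
     9 2 4]

ci = argmax(x, dims = 2)   # 2×1 Matrix{CartesianIndex{2}}
ci_vec = vec(ci)           # flattened to a Vector{CartesianIndex{2}}

x[ci_vec]                          # the row maxima: [5, 9]
rows = [c[1] for c in ci_vec]      # row components:    [1, 2]
cols = [c[2] for c in ci_vec]      # column components: [2, 1]
```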