How to get rows from a 2D NumPy array where a column value is the maximum within a group defined by another column?

Tags:

python

numpy

This is a pretty common SQL query:

Select the rows with the maximum value in column X, grouped by group_id.

The result contains, for every group_id, one row (the first) in which the value of column X is the maximum within that group.

I have a 2D NumPy array with many columns, but let's simplify it to (ID, X, Y):

import numpy as np
rows = np.array([[1, 22, 1236],
                 [1, 11, 1563],
                 [2, 13, 1234],
                 [2, 10, 1224],
                 [2, 23, 1111],
                 [2, 23, 1250]])

And I want to get:

[[1 22 1236]
 [2 23 1111]]

I am able to do it with a cumbersome loop, something like:

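  row_grouped_with_max = []

  # Assumes rows of the same group are adjacent (sorted by ID)
  max_row = rows[0]
  for row in rows[1:]:
    if row[0] != max_row[0]:
      # A new group starts: store the best row of the previous group
      row_grouped_with_max.append(max_row)
      max_row = row
    elif row[1] > max_row[1]:
      # Strictly greater, so the first row wins on ties
      max_row = row
  # Store the best row of the last group
  row_grouped_with_max.append(max_row)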

How can I do this in a clean NumPy way?

Miro asked Dec 13 '25 23:12
1 Answer

An alternative using the pandas library (IMO, grouped operations on ndarrays are easier there).

In [1]: import numpy as np
   ...: import pandas as pd

In [2]: rows = np.array([[1,22,1236],
   ...:                  [1,11,1563],
   ...:                  [2,13,1234],
   ...:                  [2,10,1224],
   ...:                  [2,23,1111],
   ...:                  [2,23,1250]])
   ...: print(rows)
[[   1   22 1236]
 [   1   11 1563]
 [   2   13 1234]
 [   2   10 1224]
 [   2   23 1111]
 [   2   23 1250]]

In [3]: df = pd.DataFrame(rows)
   ...: print(df)
   0   1     2
0  1  22  1236
1  1  11  1563
2  2  13  1234
3  2  10  1224
4  2  23  1111
5  2  23  1250

In [4]: g = df.groupby([0])[1].transform("max")
   ...: print(g)
0    22
1    22
2    23
3    23
4    23
5    23
dtype: int32

In [5]: df2 = df[df[1] == g]
   ...: print(df2)
   0   1     2
0  1  22  1236
4  2  23  1111
5  2  23  1250

In [6]: df3 = df2.drop_duplicates([0])  # keep the first max row per group (column 0)
   ...: print(df3)
   0   1     2
0  1  22  1236
4  2  23  1111

In [7]: mtx = df3.to_numpy()  # as_matrix() was removed in pandas 1.0
   ...: print(mtx)
[[   1   22 1236]
 [   2   23 1111]]
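
Since the question asks for a NumPy-only way, here is a minimal sketch of one possible approach, not part of the original answer. It sorts by group and then by descending value with np.lexsort, and takes the first row of each group with np.unique. The helper name first_max_per_group and its parameters group_col and value_col are made up for illustration. It assumes the value column is a signed numeric type (negating unsigned integers would wrap around), and it returns rows ordered by group ID rather than in their original order.

```python
import numpy as np

rows = np.array([[1, 22, 1236],
                 [1, 11, 1563],
                 [2, 13, 1234],
                 [2, 10, 1224],
                 [2, 23, 1111],
                 [2, 23, 1250]])


def first_max_per_group(a, group_col=0, value_col=1):
    """Return, for each group, the first row holding the group's maximum value.

    Hypothetical helper for illustration; assumes a signed numeric value column.
    """
    # np.lexsort uses the LAST key as the primary key, so this sorts by
    # group first and by descending value second. The sort is stable, so
    # tied rows keep their original relative order.
    order = np.lexsort((-a[:, value_col], a[:, group_col]))
    sorted_groups = a[order, group_col]
    # return_index gives the position of the first occurrence of each group
    # in the sorted array, which is that group's first maximum row.
    _, first = np.unique(sorted_groups, return_index=True)
    return a[order[first]]


result = first_max_per_group(rows)
print(result)
```

For the sample array this should select the rows [1 22 1236] and [2 23 1111], matching the pandas result above. Unlike the loop in the question, the rows do not need to be pre-sorted by ID.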
NullDev answered Dec 15 '25 14:12