How to filter polars dataframe by first maximum value while using over?

I am trying to filter a dataframe to find the first occurrence of the maximum value within each category (the cat column). In my data there is no guarantee that the maximum is unique; several rows can share it, but I only need the first occurrence.

However, I can't find a way to restrict the max part of the filter to a single row, so I currently add a second filter on another column (usually a time-based one) and take its minimum value:

import polars as pl

df = pl.DataFrame(
    {
        "cat": [1, 1, 1, 2, 2, 2, 2, 3, 3, 3],
        "max_col": [12, 24, 36, 15, 50, 50, 45, 20, 40, 60],
        "other_col": [25, 50, 75, 125, 150, 175, 200, 225, 250, 275],
    }
)

df = df.filter(pl.col("max_col") == pl.col("max_col").max().over("cat")).filter(
    pl.col("other_col") == pl.col("other_col").min().over("cat")
)
shape: (3, 3)
┌─────┬─────────┬───────────┐
│ cat ┆ max_col ┆ other_col │
│ --- ┆ ---     ┆ ---       │
│ i64 ┆ i64     ┆ i64       │
╞═════╪═════════╪═══════════╡
│ 1   ┆ 36      ┆ 75        │
│ 2   ┆ 50      ┆ 150       │
│ 3   ┆ 60      ┆ 275       │
└─────┴─────────┴───────────┘

I'd prefer to simplify the above so that it only requires passing in references to the max and category columns.

Am I missing something obvious here?

asked Nov 15 '25 at 10:11 by niko86


2 Answers

You can add .is_first_distinct() to the filter, evaluated over the same window, so that only the first row holding the maximum is kept:

df.filter(
    pl.all_horizontal(
        pl.col("max_col") == pl.col("max_col").max(),
        pl.col("max_col").is_first_distinct()
    )
    .over("cat")
)
shape: (3, 3)
┌─────┬─────────┬───────────┐
│ cat ┆ max_col ┆ other_col │
│ --- ┆ ---     ┆ ---       │
│ i64 ┆ i64     ┆ i64       │
╞═════╪═════════╪═══════════╡
│ 1   ┆ 36      ┆ 75        │
│ 2   ┆ 50      ┆ 150       │
│ 3   ┆ 60      ┆ 275       │
└─────┴─────────┴───────────┘
answered Nov 17 '25 at 08:11 by jqurious


I'd suggest the following approach:

  • group by the cat column with DataFrame.group_by()
  • within each group, sort by max_col (descending) with Expr.sort_by()
  • take the first row with Expr.first().

There are two options:

If you want the first occurrence of max(max_col)

In this case, use maintain_order=True when sorting, so that rows with equal max_col keep their original order:

(
    df
    .group_by('cat', maintain_order=True)
    .agg(
        pl.all()
        .sort_by('max_col', descending=True, maintain_order=True).first()
    )
)

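Or, using DataFrame.sort() followed by GroupBy.first(). Both variants select the same rows, but this one lists the categories in the order of their maxima (3, 2, 1); the output below is from the group_by/agg variant above: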

(
    df
    .sort('max_col', descending=True, maintain_order=True)
    .group_by('cat', maintain_order=True)
    .first()
)

┌─────┬─────────┬───────────┐
│ cat ┆ max_col ┆ other_col │
│ --- ┆ ---     ┆ ---       │
│ i64 ┆ i64     ┆ i64       │
╞═════╪═════════╪═══════════╡
│ 1   ┆ 36      ┆ 75        │
│ 2   ┆ 50      ┆ 150       │
│ 3   ┆ 60      ┆ 275       │
└─────┴─────────┴───────────┘

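If ties should be broken by other_col

In this case, additionally sort by other_col (ascending), so that among tied rows the one with the smallest other_col is taken. The output below is from the second, sort-first variant, which orders the categories by their maxima; the group_by/agg variant returns them as 1, 2, 3: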

(
    df
    .group_by('cat', maintain_order=True)
    .agg(
        pl.all()
        .sort_by('max_col','other_col', descending=[True, False]).first()
    )
)

# or

(
    df
    .sort('max_col','other_col', descending=[True, False])
    .group_by('cat', maintain_order=True)
    .first()
)

┌─────┬─────────┬───────────┐
│ cat ┆ max_col ┆ other_col │
│ --- ┆ ---     ┆ ---       │
│ i64 ┆ i64     ┆ i64       │
╞═════╪═════════╪═══════════╡
│ 3   ┆ 60      ┆ 275       │
│ 2   ┆ 50      ┆ 150       │
│ 1   ┆ 36      ┆ 75        │
└─────┴─────────┴───────────┘
answered Nov 17 '25 at 08:11 by Roman Pekar


