I am trying to filter a DataFrame to find the first occurrence of the maximum value within each category. There is no guarantee that the maximum is unique within a category; there may be several tied values, but I only need the first occurrence.
I can't find a way to limit the max part of the filter to a single row, so currently I add a further filter on another column (generally a time-based one) and take its minimum value.
import polars as pl
df = pl.DataFrame(
{
"cat": [1, 1, 1, 2, 2, 2, 2, 3, 3, 3],
"max_col": [12, 24, 36, 15, 50, 50, 45, 20, 40, 60],
"other_col": [25, 50, 75, 125, 150, 175, 200, 225, 250, 275],
}
)
df = df.filter(pl.col("max_col") == pl.col("max_col").max().over("cat")).filter(
pl.col("other_col") == pl.col("other_col").min().over("cat")
)
shape: (3, 3)
┌─────┬─────────┬───────────┐
│ cat ┆ max_col ┆ other_col │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪═════════╪═══════════╡
│ 1 ┆ 36 ┆ 75 │
│ 2 ┆ 50 ┆ 150 │
│ 3 ┆ 60 ┆ 275 │
└─────┴─────────┴───────────┘
However, I'd prefer to simplify the above to only require passing in references to the max and category columns.
Am I missing something obvious here?
You can add .is_first_distinct() to the filter to keep only the first max.
df.filter(
pl.all_horizontal(
pl.col("max_col") == pl.col("max_col").max(),
pl.col("max_col").is_first_distinct()
)
.over("cat")
)
shape: (3, 3)
┌─────┬─────────┬───────────┐
│ cat ┆ max_col ┆ other_col │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪═════════╪═══════════╡
│ 1 ┆ 36 ┆ 75 │
│ 2 ┆ 50 ┆ 150 │
│ 3 ┆ 60 ┆ 275 │
└─────┴─────────┴───────────┘
I'd suggest the following approach:
First, DataFrame.group_by() by the cat column.
Then, Expr.sort_by() within each group by max_col.
Finally, Expr.first() to take the first row.
There are two options:
If you want the first occurrence of max(max_col)
In this case, pass maintain_order=True when sorting so that tied rows keep their original order:
(
df
.group_by('cat', maintain_order=True)
.agg(
pl.all()
.sort_by('max_col', descending=True, maintain_order=True).first()
)
)
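or, using DataFrame.sort() and group_by.first() (this selects the same rows, but the groups come out in sorted order, 3, 2, 1, rather than in the order shown below):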
(
df
.sort('max_col', descending=True, maintain_order=True)
.group_by('cat', maintain_order=True)
.first()
)
┌─────┬─────────┬───────────┐
│ cat ┆ max_col ┆ other_col │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪═════════╪═══════════╡
│ 1 ┆ 36 ┆ 75 │
│ 2 ┆ 50 ┆ 150 │
│ 3 ┆ 60 ┆ 275 │
└─────┴─────────┴───────────┘
If you need to rely on the order of other_col
In this case you can additionally sort by other_col:
(
df
.group_by('cat', maintain_order=True)
.agg(
pl.all()
.sort_by('max_col','other_col', descending=[True, False]).first()
)
)
# or, sorting the whole frame first (this is the variant that yields the 3, 2, 1 order shown below)
(
df
.sort('max_col','other_col', descending=[True, False])
.group_by('cat', maintain_order=True)
.first()
)
┌─────┬─────────┬───────────┐
│ cat ┆ max_col ┆ other_col │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪═════════╪═══════════╡
│ 3 ┆ 60 ┆ 275 │
│ 2 ┆ 50 ┆ 150 │
│ 1 ┆ 36 ┆ 75 │
└─────┴─────────┴───────────┘