Polars DataFrame: How do I drop alternate rows by group?

I have a sorted DataFrame with a column that identifies a group. How do I filter it to drop every other row within each group? The DataFrame length is guaranteed to be an even number, if that matters.

Sample Input: 
┌───────────┬───────────┐
│ group_col ┆ value_col │
│ ---       ┆ ---       │
│ i64       ┆ i64       │
╞═══════════╪═══════════╡
│ 1         ┆ 10        │
│ 1         ┆ 20        │
│ 1         ┆ 30        │
│ 1         ┆ 40        │
│ 2         ┆ 50        │
│ 2         ┆ 60        │
│ 3         ┆ 70        │
│ 3         ┆ 80        │
└───────────┴───────────┘

Output:
┌───────────┬───────────┐
│ group_col ┆ value_col │
│ ---       ┆ ---       │
│ i64       ┆ i64       │
╞═══════════╪═══════════╡
│ 1         ┆ 10        │
│ 1         ┆ 30        │
│ 2         ┆ 50        │
│ 3         ┆ 70        │
└───────────┴───────────┘

import polars as pl

df = pl.DataFrame({
    'group_col': [1, 1, 1, 1, 2, 2, 3, 3],
    'value_col': [10, 20, 30, 40, 50, 60, 70, 80]
})

So I would like to keep the odd-numbered rows (1st, 3rd, ...) of every group_col group.

Polars version is 0.19.

MantleMan asked Oct 20 '25 04:10


2 Answers

You could use filter with cum_count + over and mod, keeping the rows whose count within their group is odd:

out = df.filter(pl.col('group_col').cum_count().over('group_col').mod(2)==1)

Output:

shape: (4, 2)
┌───────────┬───────────┐
│ group_col ┆ value_col │
│ ---       ┆ ---       │
│ i64       ┆ i64       │
╞═══════════╪═══════════╡
│ 1         ┆ 10        │
│ 1         ┆ 30        │
│ 2         ┆ 50        │
│ 3         ┆ 70        │
└───────────┴───────────┘

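If group_col contains nulls and you want them treated like any other group, number the rows with int_range instead, since cum_count only counts non-null values: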

df = pl.DataFrame({
    'group_col': [1, 1, 1, 1, 2, 2, None, None],
    'value_col': [10, 20, 30, 40, 50, 60, 70, 80]
})

out = df.filter(pl.int_range(pl.len()).over('group_col').mod(2)==0)

Output:

┌───────────┬───────────┐
│ group_col ┆ value_col │
│ ---       ┆ ---       │
│ i64       ┆ i64       │
╞═══════════╪═══════════╡
│ 1         ┆ 10        │
│ 1         ┆ 30        │
│ 2         ┆ 50        │
│ null      ┆ 70        │
└───────────┴───────────┘

Of course, if every group has an even number of rows and the groups are contiguous (sorted, as in your example), you can skip the over:

out = df.filter(pl.int_range(pl.len()).mod(2)==0)
mozway answered Oct 21 '25 18:10


There is also .gather_every().

To use it with .over(), change mapping_strategy= from its default (group_to_rows) to explode:

df.select(
   pl.all().gather_every(2).over("group_col", mapping_strategy="explode")
)
shape: (4, 2)
┌───────────┬───────────┐
│ group_col ┆ value_col │
│ ---       ┆ ---       │
│ i64       ┆ i64       │
╞═══════════╪═══════════╡
│ 1         ┆ 10        │
│ 1         ┆ 30        │
│ 2         ┆ 50        │
│ 3         ┆ 70        │
└───────────┴───────────┘

It's essentially the same as:

(df.group_by("group_col", maintain_order=True)
   .agg(pl.all().gather_every(2))
   .explode(pl.exclude("group_col"))
)

For the simplified case, where every group is guaranteed to have an even number of rows, you could use the frame-level method:

df.gather_every(2)
shape: (4, 2)
┌───────────┬───────────┐
│ group_col ┆ value_col │
│ ---       ┆ ---       │
│ i64       ┆ i64       │
╞═══════════╪═══════════╡
│ 1         ┆ 10        │
│ 1         ┆ 30        │
│ 2         ┆ 50        │
│ 3         ┆ 70        │
└───────────┴───────────┘
jqurious answered Oct 21 '25 17:10