How to filter on uniqueness by condition

Imagine I have a dataset like:

data = {
    "a": [1, 4, 2, 4, 7, 4],
    "b": [4, 2, 3, 3, 0, 2],
    "c": ["a", "b", "c", "d", "e", "f"],
}

and I want to keep only the rows whose sum a + b is produced by a single combination of a and b, i.e. no other distinct (a, b) pair gives the same sum. I managed to hack this together:

import polars as pl

df = (
    pl.DataFrame(data)
    .with_columns(sum_ab=pl.col("a") + pl.col("b"))
    .group_by("sum_ab")
    .agg(pl.col("a"), pl.col("b"), pl.col("c"))
    .filter(
        (pl.col("a").list.unique().list.len() == 1)
        & (pl.col("b").list.unique().list.len() == 1)
    )
    .explode(["a", "b", "c"])
    .select("a", "b", "c")
)

"""
shape: (2, 3)
┌─────┬─────┬─────┐
│ a   ┆ b   ┆ c   │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ str │
╞═════╪═════╪═════╡
│ 4   ┆ 2   ┆ b   │
│ 4   ┆ 2   ┆ f   │
└─────┴─────┴─────┘
"""

Can someone suggest a better way to achieve the same? I struggled a bit to figure this logic out, so I imagine there is a more direct/elegant way of getting the same result.

DJDuque Avatar asked Sep 13 '25 06:09


1 Answer

You can combine three pieces:

  • pl.struct() to combine a and b into one column, so each (a, b) pair can be compared as a single value.
  • n_unique() to count the distinct pairs.
  • over() to restrict that count to the rows that share the same a + b.

df.filter(
    pl.struct("a","b").n_unique().over(pl.col.a + pl.col.b) == 1
)

┌─────┬─────┬─────┐
│ a   ┆ b   ┆ c   │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ str │
╞═════╪═════╪═════╡
│ 4   ┆ 2   ┆ b   │
│ 4   ┆ 2   ┆ f   │
└─────┴─────┴─────┘

If you need to extend this to a larger number of columns, you can use sum_horizontal() to make it more generic:

columns = ["a","b"]

df.filter(
    pl.struct(columns).n_unique().over(pl.sum_horizontal(columns)) == 1
)
Roman Pekar Avatar answered Sep 15 '25 19:09