
Polars: Replace parts of dataframe with other parts of dataframe

I'm looking for an efficient way to copy / replace parts of a dataframe with other parts of the same dataframe in Polars.

For instance, given the following minimal example DataFrame:

import polars as pl

df = pl.DataFrame({
  "year": [2020,2021,2020,2021],
  "district_id": [1,2,1,2],
  "distribution_id": [1, 1, 2, 2],
  "var_1": [1,2,0.1,0.3],
  "var_N": [1,2,0.3,0.5],
  "unrelated_var": [0.2,0.5,0.3,0.7],
})

I'd like to replace all values of "var_1" and "var_N" in the rows where "distribution_id" = 2 with the corresponding values (same "year" and "district_id") from the rows where "distribution_id" = 1.

This is the desired result:

pl.DataFrame({
  "year": [2020,2021,2020,2021],
  "district_id": [1,2,1,2],
  "distribution_id": [1, 1, 2, 2],
  "var_1": [1,2,1,2],
  "var_N": [1,2,1,2],
  "unrelated_var": [0.2,0.5,0.3,0.7],
})

I tried to use a "when" expression, but it fails with "polars.exceptions.ShapeError: shapes of self, mask and other are not suitable for zip_with operation"

columns_to_copy = ["var_1", "var_N"]
df = df.with_columns([
  pl.when(pl.col("distribution_id") == 2)
    .then(df.filter(pl.col("distribution_id") == 1)[col])  # only 2 rows vs. 4 in df, hence the ShapeError
    .otherwise(pl.col(col)).alias(col)
  for col in columns_to_copy
])

Here's what I used to do with SQLAlchemy:

table_alias = table.alias("table_alias")
stmt = table.update().\
    where(table.c.year == table_alias.c.year).\
    where(table.c.d_id == table_alias.c.d_id).\
    where(table_alias.c.distribution_id == 1).\
    where(table.c.distribution_id == 2).\
    values(var_1=table_alias.c.var_1,
           var_n=table_alias.c.var_n)

Thanks a lot for your help!

asked Dec 12 '25 00:12 by Christoph Pahmeyer



1 Answer

You could filter the rows where "distribution_id" is 1, change their "distribution_id" to 2, and discard the columns you don't need.

df.filter(distribution_id = 1).select(
   "year", "district_id", "^var_.+$", distribution_id = pl.lit(2, pl.Int64)
)
shape: (2, 5)
┌──────┬─────────────┬───────┬───────┬─────────────────┐
│ year ┆ district_id ┆ var_1 ┆ var_N ┆ distribution_id │
│ ---  ┆ ---         ┆ ---   ┆ ---   ┆ ---             │
│ i64  ┆ i64         ┆ f64   ┆ f64   ┆ i64             │
╞══════╪═════════════╪═══════╪═══════╪═════════════════╡
│ 2020 ┆ 1           ┆ 1.0   ┆ 1.0   ┆ 2               │
│ 2021 ┆ 2           ┆ 2.0   ┆ 2.0   ┆ 2               │
└──────┴─────────────┴───────┴───────┴─────────────────┘
(Note: "^var_.+$" selects columns by regex, but selectors can be used instead if preferred.)

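With the data "aligned" to the target rows, you can pass it to .update(), which matches rows on the "on" columns and overwrites the other shared columns with the non-null values from the filtered frame: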

df.update(
   df.filter(distribution_id = 1)
     .select("year", "district_id", "^var_.+$", distribution_id = pl.lit(2, pl.Int64)),
   on=["year", "district_id", "distribution_id"]
)
shape: (4, 6)
┌──────┬─────────────┬─────────────────┬───────┬───────┬───────────────┐
│ year ┆ district_id ┆ distribution_id ┆ var_1 ┆ var_N ┆ unrelated_var │
│ ---  ┆ ---         ┆ ---             ┆ ---   ┆ ---   ┆ ---           │
│ i64  ┆ i64         ┆ i64             ┆ f64   ┆ f64   ┆ f64           │
╞══════╪═════════════╪═════════════════╪═══════╪═══════╪═══════════════╡
│ 2020 ┆ 1           ┆ 1               ┆ 1.0   ┆ 1.0   ┆ 0.2           │
│ 2021 ┆ 2           ┆ 1               ┆ 2.0   ┆ 2.0   ┆ 0.5           │
│ 2020 ┆ 1           ┆ 2               ┆ 1.0   ┆ 1.0   ┆ 0.3           │
│ 2021 ┆ 2           ┆ 2               ┆ 2.0   ┆ 2.0   ┆ 0.7           │
└──────┴─────────────┴─────────────────┴───────┴───────┴───────────────┘
answered Dec 13 '25 15:12 by jqurious



