I have a dataframe with two columns, values and weights, of list[i64] dtype, and I'd like to compute the row-wise dot product of the two.
import polars as pl

df = pl.DataFrame({
    'values': [[0], [0, 2], [0, 2, 4], [2, 4, 0], [4, 0, 8]],
    'weights': [[3], [2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]],
})
One approach that works is to first put values and weights into a struct and then call .map_elements on each row:
import numpy as np

df.with_columns(
    pl.struct(['values', 'weights'])
    .map_elements(
        # each row arrives as a dict: {'values': [...], 'weights': [...]}
        lambda x: np.dot(x['values'], x['weights']),
        return_dtype=pl.Float64,
    )
    .alias('dot')
)
But, as the documentation points out, map_elements is generally much slower than native Polars expressions, so I tried to implement it with native expressions instead.
I tried the following:
df.with_columns(
    pl.concat_list('values', 'weights').alias('combined'),
    pl.concat_list('values', 'weights').list.eval(pl.element().slice(0, pl.len() // 2)).alias('values1'),
    pl.concat_list('values', 'weights').list.eval(pl.element().slice(pl.len() // 2, pl.len() // 2)).alias('values2'),
    pl.concat_list('values', 'weights').list.eval(
        pl.element().slice(0, pl.len() // 2).dot(pl.element().slice(pl.len() // 2, pl.len() // 2))
    ).list.first().alias('dot'),
    pl.concat_list('values', 'weights').list.eval(
        pl.element().slice(0, pl.len() // 2) + pl.element().slice(pl.len() // 2, pl.len() // 2)
    ).alias('sum'),
)
I was expecting the dot column to be [0, 6, 16, 10, 28], but it turns out to be the following:
shape: (5, 7)
┌───────────┬───────────┬─────────────┬───────────┬───────────┬─────┬────────────┐
│ values ┆ weights ┆ combined ┆ values1 ┆ values2 ┆ dot ┆ sum │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ list[i64] ┆ list[i64] ┆ list[i64] ┆ list[i64] ┆ list[i64] ┆ i64 ┆ list[i64] │
╞═══════════╪═══════════╪═════════════╪═══════════╪═══════════╪═════╪════════════╡
│ [0] ┆ [3] ┆ [0, 3] ┆ [0] ┆ [3] ┆ 0 ┆ [0] │
│ [0, 2] ┆ [2, 3] ┆ [0, 2, … 3] ┆ [0, 2] ┆ [2, 3] ┆ 4 ┆ [0, 4] │
│ [0, 2, 4] ┆ [1, 2, 3] ┆ [0, 2, … 3] ┆ [0, 2, 4] ┆ [1, 2, 3] ┆ 20 ┆ [0, 4, 8] │
│ [2, 4, 0] ┆ [1, 2, 3] ┆ [2, 4, … 3] ┆ [2, 4, 0] ┆ [1, 2, 3] ┆ 20 ┆ [4, 8, 0] │
│ [4, 0, 8] ┆ [1, 2, 3] ┆ [4, 0, … 3] ┆ [4, 0, 8] ┆ [1, 2, 3] ┆ 80 ┆ [8, 0, 16] │
└───────────┴───────────┴─────────────┴───────────┴───────────┴─────┴────────────┘
Note that even the sum isn't what I expected: the first slice seems to be added to itself instead of to the second slice.
Am I doing anything wrong? What's the best way to perform row-wise dot product in Polars?
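A plain-Python sketch (no Polars involved) that computes the expected dot products and, for comparison, what the observed dot and sum columns correspond to. The observed columns match values·values and values + values, which is consistent with both slices picking up the first half of the combined list. This only checks the arithmetic; it does not explain why list.eval behaves this way. The helper name dot is hypothetical.

```python
# Rows of the sample dataframe from the question.
values = [[0], [0, 2], [0, 2, 4], [2, 4, 0], [4, 0, 8]]
weights = [[3], [2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]


def dot(a, b):
    # Row-wise dot product: multiply paired elements, then sum.
    return sum(x * y for x, y in zip(a, b))


expected = [dot(v, w) for v, w in zip(values, weights)]

# What the list.eval attempt returned matches both slices being the values list:
self_dot = [dot(v, v) for v in values]
self_sum = [[x + x for x in v] for v in values]

print(expected)  # [0, 6, 16, 10, 28]
print(self_dot)  # [0, 4, 20, 20, 80]
print(self_sum)  # [[0], [0, 4], [0, 4, 8], [4, 8, 0], [8, 0, 16]]
```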
Update: with the current version (1.10.0), arithmetic operations between lists are supported, so you can multiply the lists element-wise and sum each resulting list:
df.with_columns(
    # element-wise product of the two lists, then a sum within each row
    dot=(pl.col.values * pl.col.weights).list.sum()
)
shape: (5, 3)
┌───────────┬───────────┬─────┐
│ values ┆ weights ┆ dot │
│ --- ┆ --- ┆ --- │
│ list[i64] ┆ list[i64] ┆ i64 │
╞═══════════╪═══════════╪═════╡
│ [0] ┆ [3] ┆ 0 │
│ [0, 2] ┆ [2, 3] ┆ 6 │
│ [0, 2, 4] ┆ [1, 2, 3] ┆ 16 │
│ [2, 4, 0] ┆ [1, 2, 3] ┆ 10 │
│ [4, 0, 8] ┆ [1, 2, 3] ┆ 28 │
└───────────┴───────────┴─────┘
Outdated (for older versions): Unfortunately, Polars doesn't have a dot-product method for lists yet, and list.eval() is somewhat limited.
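One possible solution that avoids explode() is to use pl.concat_list() to combine the two lists into one, then, inside list.eval(), use .shift() to line each value of the second list up with the matching value of the first list:
df.with_columns(
    dot=pl.concat_list(pl.all()).list.eval(
        # shift by half the combined length so values[i] lines up with weights[i];
        # the first half is multiplied by nulls, which sum() ignores
        (pl.element() * pl.element().shift(pl.element().len() // 2)).sum()
    ).list.first()
)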
shape: (5, 3)
┌───────────┬───────────┬─────┐
│ values ┆ weights ┆ dot │
│ --- ┆ --- ┆ --- │
│ list[i64] ┆ list[i64] ┆ i64 │
╞═══════════╪═══════════╪═════╡
│ [0] ┆ [3] ┆ 0 │
│ [0, 2] ┆ [2, 3] ┆ 6 │
│ [0, 2, 4] ┆ [1, 2, 3] ┆ 16 │
│ [2, 4, 0] ┆ [1, 2, 3] ┆ 10 │
│ [4, 0, 8] ┆ [1, 2, 3] ┆ 28 │
└───────────┴───────────┴─────┘
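To see why shifting by half the combined length lines the two lists up, here is a plain-Python emulation of the per-row computation. The helper names shift and shifted_dot are hypothetical, for illustration only, and the sketch assumes, as in the question, that both lists in a row have the same length.

```python
def shift(xs, n):
    # Emulates a positive shift(n): move values down by n, fill the top with None.
    return [None] * n + xs[: len(xs) - n]


def shifted_dot(values, weights):
    combined = values + weights      # like pl.concat_list: [v0..v(n-1), w0..w(n-1)]
    half = len(combined) // 2        # like pl.element().len() // 2
    shifted = shift(combined, half)  # first half now sits under the second half
    # In Polars, value * null is null and sum() skips nulls, so drop those pairs here.
    return sum(a * b for a, b in zip(combined, shifted) if b is not None)


values = [[0], [0, 2], [0, 2, 4], [2, 4, 0], [4, 0, 8]]
weights = [[3], [2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]
print([shifted_dot(v, w) for v, w in zip(values, weights)])  # [0, 6, 16, 10, 28]
```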
I suspect the performance of this list.eval() approach is not the best, though.
Another option is to use DuckDB's integration with Polars and its list_dot_product function:
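import duckdb

# DuckDB can query the Polars DataFrame df directly by its variable name
duckdb.sql("""
    select
        values, weights,
        cast(list_dot_product(values, weights) as int) as dot
    from df
""")
(The output below was produced on a different sample dataframe than the one in the question.)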
┌───────────┬───────────┬───────┐
│ values │ weights │ dot │
│ int32[] │ int32[] │ int32 │
├───────────┼───────────┼───────┤
│ [6, 7, 3] │ [8, 3, 9] │ 96 │
│ [5, 3, 3] │ [5, 3, 6] │ 52 │
│ [4, 0, 5] │ [9, 7, 0] │ 36 │
└───────────┴───────────┴───────┘