 

Polars - How to compute rolling ewm grouped by column?

What's the right way to perform a group_by + rolling aggregation in Polars? When I compute ewm_mean inside a rolling group_by, each row gets a list containing every EWM value in the window instead of a single number. For example, take the DataFrame below:

portfolios = pl.from_repr("""
┌─────────────────────┬────────┬───────────┐
│ ts                  ┆ symbol ┆ signal_0  │
│ ---                 ┆ ---    ┆ ---       │
│ datetime[μs]        ┆ str    ┆ f64       │
╞═════════════════════╪════════╪═══════════╡
│ 2022-02-14 09:20:00 ┆ A      ┆ -1.704301 │
│ 2022-02-14 09:20:00 ┆ AA     ┆ -1.181743 │
│ 2022-02-14 09:50:00 ┆ A      ┆ 1.040125  │
│ 2022-02-14 09:50:00 ┆ AA     ┆ 0.776798  │
│ 2022-02-14 10:20:00 ┆ A      ┆ 1.934686  │
│ 2022-02-14 10:20:00 ┆ AA     ┆ 1.480892  │
│ 2022-02-14 10:50:00 ┆ A      ┆ 2.073418  │
│ 2022-02-14 10:50:00 ┆ AA     ┆ 1.623698  │
│ 2022-02-14 11:20:00 ┆ A      ┆ 2.088835  │
│ 2022-02-14 11:20:00 ┆ AA     ┆ 1.741544  │
└─────────────────────┴────────┴───────────┘
""")

Here, I want to group by symbol and get the rolling exponentially weighted mean of signal_0 at every timestamp. Unfortunately, this doesn't work:

portfolios.rolling("ts", group_by="symbol", period="1d").agg(
    pl.col("signal_0").ewm_mean(half_life=0.1).alias("signal_0_mean")
)
shape: (10, 3)
┌────────┬─────────────────────┬─────────────────────────────────┐
│ symbol ┆ ts                  ┆ signal_0_mean                   │
│ ---    ┆ ---                 ┆ ---                             │
│ str    ┆ datetime[μs]        ┆ list[f64]                       │
╞════════╪═════════════════════╪═════════════════════════════════╡
│ A      ┆ 2022-02-14 09:20:00 ┆ [-1.704301]                     │
│ A      ┆ 2022-02-14 09:50:00 ┆ [-1.704301, 1.037448]           │
│ A      ┆ 2022-02-14 10:20:00 ┆ [-1.704301, 1.037448, 1.93381]  │
│ A      ┆ 2022-02-14 10:50:00 ┆ [-1.704301, 1.037448, … 2.0732… │
│ A      ┆ 2022-02-14 11:20:00 ┆ [-1.704301, 1.037448, … 2.0888… │
│ AA     ┆ 2022-02-14 09:20:00 ┆ [-1.181743]                     │
│ AA     ┆ 2022-02-14 09:50:00 ┆ [-1.181743, 0.774887]           │
│ AA     ┆ 2022-02-14 10:20:00 ┆ [-1.181743, 0.774887, 1.480203… │
│ AA     ┆ 2022-02-14 10:50:00 ┆ [-1.181743, 0.774887, … 1.6235… │
│ AA     ┆ 2022-02-14 11:20:00 ┆ [-1.181743, 0.774887, … 1.7414… │
└────────┴─────────────────────┴─────────────────────────────────┘
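To see why each cell is a list, it helps to spell out which rows land in each window. The sketch below is a plain-Python approximation (it does not call Polars) of how `rolling("ts", group_by="symbol", period="1d")` assigns rows, assuming Polars' default `closed="right"`, i.e. a window of `(ts - 1d, ts]` per symbol. The helper name `rolling_windows` is hypothetical. Because all sample rows fall within one day, every window starts at the symbol's first row, so an expression that returns one value per input row (such as `ewm_mean`) produces a list that grows by one element per timestamp.

```python
from datetime import datetime, timedelta

# (ts, symbol, signal_0) rows copied from the question's DataFrame
TIMES = [datetime(2022, 2, 14, h, m) for h, m in [(9, 20), (9, 50), (10, 20), (10, 50), (11, 20)]]
VALUES = {
    "A": [-1.704301, 1.040125, 1.934686, 2.073418, 2.088835],
    "AA": [-1.181743, 0.776798, 1.480892, 1.623698, 1.741544],
}
ROWS = [(ts, sym, VALUES[sym][i]) for i, ts in enumerate(TIMES) for sym in ("A", "AA")]


def rolling_windows(rows, period):
    """Map each (symbol, ts) to that symbol's signal values in (ts - period, ts].

    Hypothetical plain-Python approximation of a right-closed rolling window
    grouped by symbol; it is not how Polars implements rolling internally.
    """
    windows = {}
    for ts, sym, _ in rows:
        windows[(sym, ts)] = [v for t, s, v in rows if s == sym and ts - period < t <= ts]
    return windows


for (sym, ts), vals in sorted(rolling_windows(ROWS, timedelta(days=1)).items()):
    # A per-row expression such as ewm_mean yields len(vals) results for this
    # window, which is why each aggregated cell is a list.
    print(sym, ts.time(), vals)
```

With `timedelta(hours=1)` instead, the 10:20 window for A holds only the 09:50 and 10:20 values, so the last EWM value in that window would generally differ from an EWM computed over the symbol's full history.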

If I wanted to do this in pandas, I would write:

portfolios.to_pandas().set_index(["ts", "symbol"]).groupby(level=1)["signal_0"].transform(
    lambda x: x.ewm(halflife=10).mean()
)

This yields:

ts                   symbol
2022-02-14 09:20:00  A        -1.704301
                     AA       -1.181743
2022-02-14 09:50:00  A        -0.284550
                     AA       -0.168547
2022-02-14 10:20:00  A         0.507021
                     AA        0.419785
2022-02-14 10:50:00  A         0.940226
                     AA        0.752741
2022-02-14 11:20:00  A         1.202843
                     AA        0.978820
Name: signal_0, dtype: float64
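For reference, with its default `adjust=True`, pandas' `ewm(halflife=h).mean()` weights the observation i steps back by `decay**i`, where `decay = 0.5 ** (1 / h)`, and divides by the sum of those weights. A dependency-free sketch of that formula follows; the helper name `ewm_mean_adjusted` is hypothetical, and it ignores NaN handling and `min_periods`. Run on symbol A's values, it reproduces the pandas output above up to the rounding of the inputs.

```python
def ewm_mean_adjusted(values, half_life):
    """Adjusted exponentially weighted mean at every position (pandas adjust=True)."""
    decay = 0.5 ** (1.0 / half_life)  # weight multiplier per step back in time
    out = []
    num = 0.0  # running sum of decay**i * x[t - i]
    den = 0.0  # running sum of decay**i
    for x in values:
        num = x + decay * num
        den = 1.0 + decay * den
        out.append(num / den)
    return out


signal_a = [-1.704301, 1.040125, 1.934686, 2.073418, 2.088835]
print([round(v, 6) for v in ewm_mean_adjusted(signal_a, half_life=10)])
```

The recurrence keeps the sketch linear in the number of rows instead of re-summing all weights at every step.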
asked Oct 14 '25 08:10 by OneRaynyDay


1 Answer

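You were close. Inside a rolling aggregation, ewm_mean returns one estimate for each observation in the window, which is why you get a list. Add .last() to keep only the value for the window's final row, i.e. the current timestamp. To match the pandas output, also use half_life=10 as in the pandas call; half_life=0.1 applies almost no smoothing.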

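(
    portfolios
    # one window per (symbol, ts): that symbol's rows in (ts - 1d, ts]
    .rolling("ts", group_by="symbol", period="1d")
    .agg(
        # ewm_mean returns one value per row in the window;
        # .last() keeps the value at the window's current timestamp
        pl.col("signal_0").ewm_mean(half_life=10).last().alias("signal_0_mean")
    )
    .sort("ts", "symbol")
)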
shape: (10, 3)
┌────────┬─────────────────────┬───────────────┐
│ symbol ┆ ts                  ┆ signal_0_mean │
│ ---    ┆ ---                 ┆ ---           │
│ str    ┆ datetime[μs]        ┆ f64           │
╞════════╪═════════════════════╪═══════════════╡
│ A      ┆ 2022-02-14 09:20:00 ┆ -1.704301     │
│ AA     ┆ 2022-02-14 09:20:00 ┆ -1.181743     │
│ A      ┆ 2022-02-14 09:50:00 ┆ -0.28455      │
│ AA     ┆ 2022-02-14 09:50:00 ┆ -0.168547     │
│ A      ┆ 2022-02-14 10:20:00 ┆ 0.507021      │
│ AA     ┆ 2022-02-14 10:20:00 ┆ 0.419785      │
│ A      ┆ 2022-02-14 10:50:00 ┆ 0.940226      │
│ AA     ┆ 2022-02-14 10:50:00 ┆ 0.752741      │
│ A      ┆ 2022-02-14 11:20:00 ┆ 1.202844      │
│ AA     ┆ 2022-02-14 11:20:00 ┆ 0.97882       │
└────────┴─────────────────────┴───────────────┘