Given a Polars DataFrame:

```python
import polars as pl

data = pl.DataFrame(
    {
        "user_id": [1, 1, 1, 1, 1, 2, 2, 2, 2],
        "event": [False, True, True, False, True, True, True, False, False],
    }
)
```
I want to calculate a column `event_chain` that counts a user's streak of events, where an event only extends a streak if the user also had an event in any of the previous 4 rows. Every time a new event happens while the user already has an active streak, the streak counter is incremented; the counter should then be reset to zero once the user goes 4 rows without another event.
| user_id | event | event_chain | reason for value |
|---|---|---|---|
| 1 | False | 0 | No events yet |
| 1 | True | 0 | No events in last 4 rows (not inclusive of current row) |
| 1 | True | 1 | Event this row, and 1 event in last 4 rows |
| 1 | False | 1 | Does not reset to 0, as there is an event within the next 4 rows |
| 1 | True | 2 | Event this row and an event in last 4 rows: increment the streak |
| 2 | True | 0 | No previous events for user |
| 2 | True | 1 | Event this row and in last 4 rows for user |
| 2 | False | 0 | No event this row, and no events in next 4 rows for user: resets to 0 |
| 2 | False | 0 | No event this row, and no events in next 4 rows for user |
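To pin down the rules in the table, here is a plain-Python reference loop (not Polars) that reproduces the `event_chain` column above. It is a sketch of one literal reading of the "reason for value" column: an event increments the counter if there was an event in the previous 4 rows and otherwise restarts it at 0, while a non-event row keeps the counter only if the user has an event in the next 4 rows. The helper name `event_chain` and its `window` parameter are hypothetical, and it assumes rows are already grouped by `user_id` and in chronological order. Its output can serve as a check for any vectorised Polars version.

```python
from itertools import groupby


def event_chain(user_ids, events, window=4):
    """Plain-Python reference for the event_chain rules (hypothetical helper).

    Assumes rows are already sorted so that each user's rows are contiguous
    and in chronological order.
    """
    out = []
    start = 0
    for _, rows in groupby(user_ids):
        n = sum(1 for _ in rows)        # number of rows for this user
        ev = events[start:start + n]    # this user's events
        counter = 0
        for i, is_event in enumerate(ev):
            prev_hit = any(ev[max(0, i - window):i])  # previous `window` rows, excluding row i
            next_hit = any(ev[i + 1:i + 1 + window])  # next `window` rows, excluding row i
            if is_event:
                # continue an active streak, otherwise start a new one at 0
                counter = counter + 1 if prev_hit else 0
            elif not next_hit:
                # no event in the next `window` rows: the streak ends
                counter = 0
            out.append(counter)
        start += n
    return out


user_ids = [1, 1, 1, 1, 1, 2, 2, 2, 2]
events = [False, True, True, False, True, True, True, False, False]
print(event_chain(user_ids, events))  # [0, 0, 1, 1, 2, 0, 1, 0, 0]
```

Under this reading, a non-event row that is within 4 rows of the next event but more than 4 rows after the previous one keeps its counter, even though that next event then restarts at 0.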
I have the following working code, but I suspect there is a cleaner way to do this:
```python
import polars as pl

(
    data.with_columns(
        # rows since the user's most recent event (0 on an event row)
        rows_since_last_event=pl.int_range(pl.len()).over("user_id")
        - pl.when("event").then(pl.int_range(pl.len())).forward_fill().over("user_id"),
        # rows until the user's next event (0 on an event row)
        rows_till_next_event=pl.when("event").then(pl.int_range(pl.len())).backward_fill().over("user_id")
        - pl.int_range(pl.len()).over("user_id"),
    )
    .with_columns(
        # 1 if the rolling window, excluding the current row, contains an event
        chain_event=pl.when(
            pl.col("event").fill_null(0).rolling_sum(window_size=4, min_periods=1).over("user_id")
            - pl.col("event").fill_null(0)
            > 0
        )
        .then(1)
        .otherwise(0)
    )
    .with_columns(
        # flag rows where a chain starts or ends
        chain_event_change=pl.when(
            pl.col("chain_event").eq(1),
            pl.col("chain_event").shift().eq(0),
            pl.col("rows_since_last_event").fill_null(5) > 3,
        )
        .then(1)
        .when(
            pl.col("chain_event").eq(0),
            pl.col("chain_event").shift().eq(1),
            pl.col("rows_till_next_event").fill_null(5) > 3,
        )
        .then(1)
        .otherwise(0)
    )
    .with_columns(
        # running chain id per user
        chain_event_identifier=pl.col("chain_event_change").cum_sum().over("user_id")
    )
    .with_columns(
        # streak counter within each chain
        event_chain=pl.col("chain_event").cum_sum().over("user_id", "chain_event_identifier")
    )
)
```