Source code for recsys_metrics_polars.hit_at_k
import polars as pl
from .constants import RANK_COL
[docs]
def compute_hit_at_k(joined_true_with_recs: pl.DataFrame, k: int) -> pl.DataFrame:
    """Add a per-row boolean ``hit@{k}`` column flagging ranks within the top ``k``.

    Each row gets ``True`` when its value in the ``RANK_COL`` column is ``<= k``.
    It gets ``False`` when the rank is greater than ``k`` or is null; null means the
    comparison yields null, which is then filled with ``False``. No aggregation is
    done here. The column is computed per row and any averaging is left to the caller.

    If a column named ``hit@{k}`` already exists, the input frame is returned
    unchanged. This lets repeated calls reuse a previously computed column.

    :param joined_true_with_recs: Frame that must contain ``RANK_COL``. Presumably
        this is ground-truth items joined with the recommendation list, where
        ``RANK_COL`` is null for items that were not recommended (TODO confirm
        against the caller).
    :param k: Cut-off rank; must be strictly positive.
    :return: The frame with the ``hit@{k}`` boolean column present.
    :raises AssertionError: If ``k <= 0`` or ``RANK_COL`` is missing. These are
        ``assert`` statements, so they are skipped when Python runs with ``-O``.
    """
    assert k > 0, "k must be positive"
    assert RANK_COL in joined_true_with_recs.columns, f"Cannot find '{RANK_COL}' in the list of columns"

    metric_name = f"hit@{k}"
    # Reuse an existing column instead of recomputing or overwriting it.
    if metric_name not in joined_true_with_recs.columns:
        joined_true_with_recs = joined_true_with_recs.with_columns(
            # NOTE(review): `<= k` assumes ranks are 1-based; confirm how RANK_COL is produced.
            # Null ranks (item not found among recommendations) count as a miss.
            (pl.col(RANK_COL) <= k).fill_null(False).alias(metric_name)
        )
    return joined_true_with_recs