# Source code for recsys_metrics_polars.recall
from typing import Optional
import polars as pl
from .base import BaseMetricAtK, BaseRecMetric
from .data_info import DataInfo
from .hit_at_k import compute_hit_at_k
# [docs]  (Sphinx viewcode link marker left over from the HTML rendering; not part of the module)
class RecallAtK(BaseMetricAtK):
    """
    Recall@k
    :math:`Recall@k = \\dfrac{\\text{Number of relevant items with rank} \\leq k}{\\min(k, \\text{Total relevant items for query})}`

    The per-query denominator is derived in :meth:`fit` by counting the rows
    of the joined data for each query, and is capped at ``k`` when the metric
    is computed.
    """

    # Name of the helper column that holds the per-query row count.
    _TOTAL_ITEMS_PER_GROUP_COL = "total_items_per_query"

    def __init__(self, data_info: DataInfo):
        """Store ``data_info`` via the base class and reset the fitted state."""
        super().__init__(data_info)
        # Populated by fit(): one row per query with its row count.
        self._total_items_per_query: Optional[pl.DataFrame] = None

    def fit(self, true_interactions: pl.DataFrame, recommendations: pl.DataFrame) -> BaseRecMetric:
        """
        Fit the metric and cache the number of rows per query.

        Delegates to the base class ``fit`` and then groups
        ``self._joined_data`` by ``data_info.query_id_cols`` to count the rows
        of every query.

        NOTE(review): ``self._joined_data`` is presumably populated by the
        base-class ``fit``; whether its per-query row count really equals the
        number of relevant items depends on how that join is built -- confirm
        against ``BaseMetricAtK``.

        :param true_interactions: ground-truth interactions, passed to the base ``fit``.
        :param recommendations: recommendations, passed to the base ``fit``.
        :return: ``self``, to allow call chaining.
        """
        super().fit(true_interactions, recommendations)
        # NOTE(review): GroupBy.count() appears to be deprecated in recent
        # polars releases in favour of GroupBy.len() (which names its output
        # column "len"); kept here so the supported polars range is unchanged.
        self._total_items_per_query = (
            self._joined_data.group_by(self.data_info.query_id_cols)
            .count()
            # Rename instead of aliasing a copy so no stray "count" column is
            # carried into the join performed by compute_per_query().
            .rename({"count": self._TOTAL_ITEMS_PER_GROUP_COL})
        )
        return self

    def _compute_hit_at_k(self, k: int) -> None:
        """
        Replace ``self._joined_data`` with the result of ``compute_hit_at_k``.

        :param k: cut-off rank forwarded to ``compute_hit_at_k``.
        :raises AssertionError: if ``self._joined_data`` is ``None``.
        """
        if self._joined_data is None:
            # Raised explicitly (rather than via an ``assert`` statement) so
            # the guard is still active under ``python -O``; the exception
            # type and message are unchanged for existing callers.
            raise AssertionError("fit(...) first")
        self._joined_data = compute_hit_at_k(self._joined_data, k)

    def compute_per_query(self, k: int, **kwargs) -> pl.DataFrame:
        """
        Compute Recall@k separately for every query.

        Side effect: ``self._joined_data`` is replaced by the output of
        ``compute_hit_at_k`` for this ``k``.

        :param k: cut-off rank.
        :param kwargs: accepted for interface compatibility; not used here.
        :return: a data frame with the query id columns and a ``recall@{k}`` column.
        :raises AssertionError: if ``self._joined_data`` is ``None``.
        """
        self._compute_hit_at_k(k)
        hit_metric_name = self._hit_at_k_metric_name(k)
        recall_col = f"recall@{k}"
        total_col = self._TOTAL_ITEMS_PER_GROUP_COL
        query_id_cols = self.data_info.query_id_cols

        # The total is constant within a query after the join, so max() just
        # picks that single value.
        total_per_query = pl.max(total_col)
        # Denominator is min(k, total rows for the query).
        denominator = (
            pl.when(total_per_query < k)
            .then(total_per_query)
            .otherwise(k)
            .alias(total_col)
        )

        return (
            self._joined_data.lazy()
            .join(self._total_items_per_query.lazy(), on=query_id_cols)
            .group_by(query_id_cols)
            .agg(
                # Numerator: sum of the hit column for the query.
                pl.sum(hit_metric_name).cast(pl.Float32).alias(recall_col),
                denominator,
            )
            # The result keeps the name of the left operand, i.e. recall@k.
            .with_columns(pl.col(recall_col) / pl.col(total_col))
            # Drop the helper denominator column from the output.
            .select(pl.col("*").exclude(total_col))
            .collect()
        )

    def avergae_over_queries(self, k: int, **kwargs) -> Optional[float]:
        """
        Return the mean of the per-query Recall@k values.

        NOTE(review): the method name is misspelled ("avergae"); it is kept
        as-is because it is the public interface -- check the spelling used by
        the base class before renaming.

        :param k: cut-off rank.
        :param kwargs: forwarded to :meth:`compute_per_query`.
        :return: mean of the ``recall@{k}`` column.
        """
        metric_per_query = self.compute_per_query(k=k, **kwargs)
        return metric_per_query.get_column(f"recall@{k}").mean()