Source code for recsys_metrics_polars.precision
import polars as pl
from .base import BaseMetricAtK
from .hit_at_k import compute_hit_at_k
[docs]
class PrecisionAtK(BaseMetricAtK):
    """
    Precision@k metric.

    :math:`Precision@k = \\dfrac{\\text{Number of relevant items with rank} \\leq k}{k}`

    Both public methods first pass the joined data through
    ``compute_hit_at_k`` and then aggregate the column whose name is given
    by ``self._hit_at_k_metric_name(k)``.
    """

    def _compute_hit_at_k(self, k: int):
        """Replace ``self._joined_data`` with ``compute_hit_at_k(self._joined_data, k)``.

        Fails with an ``AssertionError`` ("fit(...) first") while
        ``self._joined_data`` is still ``None``.
        """
        joined = self._joined_data
        assert joined is not None, "fit(...) first"
        # presumably this attaches the hit@k column read by the public
        # methods below; compute_hit_at_k is defined in another module.
        self._joined_data = compute_hit_at_k(joined, k)

    def compute_per_query(self, k: int, **kwargs):
        """Return precision@k for every query.

        The hit column is summed within each group of
        ``self.data_info.query_id_cols``, cast to ``Float32`` and divided
        by ``k``. The resulting column is named ``prec@{k}``.

        ``**kwargs`` is accepted but not used by this method.
        """
        self._compute_hit_at_k(k)
        hit_col = self._hit_at_k_metric_name(k)

        # Build the aggregation expression separately so the query plan
        # below reads as "group, aggregate, collect".
        precision_expr = (pl.sum(hit_col).cast(pl.Float32) / k).alias(f"prec@{k}")
        grouped = self._joined_data.lazy().group_by(self.data_info.query_id_cols)
        return grouped.agg(precision_expr).collect()

    # NOTE(review): the name looks like a misspelling of "average"; it is
    # kept verbatim because the base class / callers are not visible here.
    def avergae_over_queries(self, k: int, **kwargs) -> float:
        """Return precision@k averaged over all queries.

        Computed as (total hits / ``k``) divided by the number of distinct
        ``self.data_info.query_id_cols`` combinations in the joined data.

        ``**kwargs`` is accepted but not used by this method.
        """
        self._compute_hit_at_k(k)
        hit_col = self._hit_at_k_metric_name(k)

        frame = self._joined_data
        # Single-cell frame: overall hit count (as Float32) divided by k.
        summed_precision = frame.select(
            pl.col(hit_col).sum().cast(pl.Float32) / k
        )[0, 0]
        n_queries = frame.n_unique(self.data_info.query_id_cols)
        return summed_precision / n_queries