Source code for recsys_metrics_polars.avg_precision

import polars as pl

from .base import BaseMetricAtK
from .constants import RANK_COL


[docs] class AvgPrecisionAtK(BaseMetricAtK): """Average precision and MAP (see :py:meth:`.AvgPrecisionAtK.avergae_over_queries`) AP@k .. math:: & AP@k = \\dfrac{\\sum\\limits_{i=1}^{k} Precision@i \\cdot Rel(i)}{\\min(k, \\text{Total relevant items})}, \\\\ & Rel(i) = \\begin{cases} 1, \\text{item with rank } i \\text{ is relevant,} \\\\ 0, \\text{otherwise} \\end{cases} """
[docs] def compute_per_query(self, k: int, **kwargs) -> pl.DataFrame: lazy_data = self._joined_data.lazy() # avg_prec(k) = sum_{i=1}^{i=min(k, number_of_relevant_items)} hit_at_rank(i) / i * relevant(i) # is is denum precision_at = ( lazy_data.select(self.data_info.query_id_cols) .with_columns( (pl.col(self.data_info.query_id_cols[0]).cumcount().over(self.data_info.query_id_cols) + 1) .alias("prec_k") .cast(self._joined_data.schema[RANK_COL]) ) .filter(pl.col("prec_k") <= k) .collect() ) group_cols = self.data_info.query_id_cols + [RANK_COL] # hit_at_rank(i) * relevant(i) hit_at_rank = ( precision_at.lazy() .join( lazy_data.group_by(group_cols).agg(pl.count(RANK_COL).alias("hit_at_k")), left_on=self.data_info.query_id_cols + ["prec_k"], right_on=group_cols, how="left", ) .with_columns( pl.col("hit_at_k").fill_null(0), pl.when(pl.col("hit_at_k").is_null()) .then(pl.lit(0, dtype=pl.Int8)) .otherwise(pl.lit(1, dtype=pl.Int8)) .alias("is_relevant_at_k"), ) .sort("prec_k") .with_columns(pl.col("hit_at_k").cumsum().over(self.data_info.query_id_cols).cast(pl.Float32)) ) return ( hit_at_rank.group_by(self.data_info.query_id_cols) .agg( ((pl.col("hit_at_k") / pl.col("prec_k") * pl.col("is_relevant_at_k")).sum() / pl.count()).alias( f"avg_prec@{k}" ) ) .collect() )
[docs] def avergae_over_queries(self, k: int, **kwargs) -> float: """Compute MAP@k .. math:: MAP@k = \\dfrac{\\sum\\limits_{i=1}^N AP_i@k}{N}, :math:`AP_i@k`- average precision for query :math:`i`, :math:`i \\in {1,2,\\ldots,N}, N-` total number of queries. """ avg_per_query = self.compute_per_query(k=k) return avg_per_query.get_column(f"avg_prec@{k}").mean()