Source code for recsys_metrics_polars.base

from abc import ABC, abstractmethod
from typing import Optional

import polars as pl

from .data_info import DataInfo
from .utils import join_true_recs_and_preprocess


class BaseRecMetric(ABC):
    """Base class for recommendation metrics.

    Subclasses implement :meth:`compute_per_query` and
    :meth:`avergae_over_queries`. Call :meth:`fit` first: it joins the true
    interactions with the recommendations and caches the result in
    ``self._joined_data``.
    """

    def __init__(self, data_info: DataInfo):
        """
        :param data_info: dataset description; it is passed to
            ``join_true_recs_and_preprocess`` when :meth:`fit` is called
        """
        self.data_info = data_info
        # Set by ``fit``; stays ``None`` until ``fit`` has been called.
        self._joined_data: Optional[pl.DataFrame] = None

    def fit(self, true_interactions: pl.DataFrame, recommendations: pl.DataFrame) -> "BaseRecMetric":
        """Join true interactions with recommendations and cache the result.

        :param true_interactions: ground-truth interactions
        :param recommendations: predicted interactions, with a score for each
            (query, item) pair
        :return: ``self``, so calls can be chained
            (e.g. ``metric.fit(...).compute_per_query(...)``)
        """
        self._joined_data = join_true_recs_and_preprocess(
            true_interactions, recommendations, self.data_info
        )
        return self

    def _hit_at_k_metric_name(self, k: int) -> str:
        """Return the metric name ``"hit@{k}"`` (e.g. ``"hit@10"`` for ``k=10``)."""
        return f"hit@{k}"

    @abstractmethod
    def compute_per_query(self, **kwargs) -> pl.DataFrame:
        """Compute the metric value for each query."""
        pass

    @abstractmethod
    def avergae_over_queries(self, **kwargs) -> float:
        """Compute the mean metric value over all queries.

        NOTE: the name is misspelled ("avergae"). It is kept as the abstract
        hook so that existing subclasses still satisfy the interface. Callers
        should prefer :meth:`average_over_queries`.
        """
        pass

    def average_over_queries(self, *args, **kwargs) -> float:
        """Correctly spelled alias for :meth:`avergae_over_queries`.

        Forwards all positional and keyword arguments unchanged to the
        subclass implementation and returns its result.
        """
        # *args is accepted so subclasses whose hook takes positional
        # parameters (e.g. ``k``) can be called positionally through the alias.
        return self.avergae_over_queries(*args, **kwargs)
class BaseMetricAtK(BaseRecMetric):
    """Base class for metrics parameterized by ``k``.

    It narrows the abstract hooks of :class:`BaseRecMetric` so that ``k`` is an
    explicit parameter. ``k`` is presumably the top-k cutoff (as in
    ``hit@k``); confirm against the concrete subclasses.
    """

    @abstractmethod
    def compute_per_query(self, k: int, **kwargs) -> pl.DataFrame:
        """Compute the metric at ``k`` for each query."""
        pass

    @abstractmethod
    def avergae_over_queries(self, k: int, **kwargs) -> float:
        """Compute the mean metric value at ``k`` over all queries.

        NOTE: the misspelled name is kept so that existing subclasses still
        satisfy the abstract interface. Callers should prefer
        :meth:`average_over_queries`.
        """
        pass

    def average_over_queries(self, k: int, **kwargs) -> float:
        """Correctly spelled alias that forwards to :meth:`avergae_over_queries`.

        :param k: forwarded unchanged to the subclass implementation
        :return: whatever the subclass implementation returns
        """
        return self.avergae_over_queries(k, **kwargs)