Source code for recsys_metrics_polars.base
from abc import ABC, abstractmethod
from typing import Optional
import polars as pl
from .data_info import DataInfo
from .utils import join_true_recs_and_preprocess
[docs]
class BaseRecMetric(ABC):
    """Base class for metrics.

    Subclasses implement :meth:`compute_per_query` and
    :meth:`avergae_over_queries`. Call :meth:`fit` first: it stores the joined
    true/recommended data in ``self._joined_data``, which is ``None`` until then.
    """

    def __init__(self, data_info: DataInfo):
        """
        :param data_info: data description that ``fit`` forwards to
            ``join_true_recs_and_preprocess`` when joining the inputs
        """
        self.data_info = data_info
        # Filled in by ``fit``; remains ``None`` until ``fit`` has been called.
        self._joined_data: Optional[pl.DataFrame] = None

    def fit(
        self, true_interactions: pl.DataFrame, recommendations: pl.DataFrame
    ) -> "BaseRecMetric":
        """Prepare data for metric computing.

        :param true_interactions: true interactions
        :param recommendations: predicted interactions with scores for each
            pair of query and item
        :return: ``self``, so calls can be chained,
            e.g. ``metric.fit(...).compute_per_query(...)``
        """
        self._joined_data = join_true_recs_and_preprocess(
            true_interactions, recommendations, self.data_info
        )
        return self

    def _hit_at_k_metric_name(self, k: int) -> str:
        """Return the hit@k metric name, e.g. ``"hit@10"`` for ``k=10``."""
        return f"hit@{k}"

    @abstractmethod
    def compute_per_query(self, **kwargs) -> pl.DataFrame:
        """Compute metric per query."""

    @abstractmethod
    def avergae_over_queries(self, **kwargs) -> float:
        """Compute mean metric value over all queries.

        NOTE: the method name is misspelled ("avergae"). It is kept as the
        abstract hook so that existing subclasses overriding it keep working;
        callers should prefer :meth:`average_over_queries`.
        """

    def average_over_queries(self, **kwargs) -> float:
        """Compute mean metric value over all queries (correctly spelled alias).

        Delegates to :meth:`avergae_over_queries`, forwarding all keyword
        arguments, so subclasses only need to implement the original hook.
        """
        return self.avergae_over_queries(**kwargs)