Source code for hundred_hammers.metric_alias
"""
This module maps short metric aliases (e.g. ``"ACC"``, ``"RMSE"``) to the
scorer names used by scikit-learn's ``sklearn.metrics.get_scorer``.
"""
from __future__ import annotations

import functools
from typing import Callable, Tuple

from sklearn.metrics import get_scorer
# Short alias -> sklearn scorer name (as accepted by ``get_scorer``).
# sklearn negates error/loss scorers (``neg_*``) so that higher is always better.
metric_alias = dict(
    # Classification scorers
    ACC="accuracy",
    BACC="balanced_accuracy",
    PREC="precision",
    PRECW="precision_weighted",
    REC="recall",
    F1="f1",
    F1W="f1_weighted",
    ROC="roc_auc",
    LogLoss="neg_log_loss",
    # Regression scorers
    MAE="neg_mean_absolute_error",
    RMSE="neg_root_mean_squared_error",
    MAPE="neg_mean_absolute_percentage_error",
    MSE="neg_mean_squared_error",
    R2="r2",
)
[docs]
def _callable_name(fn: Callable) -> str:
    """
    Return a human-readable name for a callable.

    Objects exposing ``__name__`` (plain functions, lambdas -> ``"<lambda>"``)
    use it unchanged. ``functools.partial`` objects have no ``__name__``, so
    they are unwrapped to the function they wrap; any other object without a
    ``__name__`` (e.g. an instance of a class defining ``__call__``) falls back
    to its class name.
    """
    name = getattr(fn, "__name__", None)
    if name is not None:
        return name
    if isinstance(fn, functools.partial):
        return _callable_name(fn.func)
    return type(fn).__name__


def process_metric(
    metric: str | Callable, metric_params: dict | None = None
) -> Tuple[str, Callable, dict]:
    """
    Converts a metric into a tuple with the name, function call and its parameters.

    :param metric: either a metric name -- one of the aliases in
        ``metric_alias`` or any scorer name accepted by
        ``sklearn.metrics.get_scorer`` -- or a callable implementing the metric.
    :param metric_params: keyword arguments for the metric function. When
        ``None``, a string metric uses a copy of the keyword arguments stored in
        its sklearn scorer, and a callable metric uses an empty dict.
    :return: a ``(name, metric_fn, metric_params)`` tuple. For string metrics
        ``name`` is the string exactly as given (alias or scorer name); for
        callables it is the callable's name (see ``_callable_name``).
    :raises TypeError: if ``metric`` is neither a string nor a callable.
        Errors raised by ``get_scorer`` for unknown names propagate unchanged.
    """
    if isinstance(metric, str):
        # Metric given by its name: resolve a short alias (e.g. "ACC") to the
        # sklearn scorer name; non-alias strings are passed through as-is.
        scorer = get_scorer(metric_alias.get(metric, metric))
        # NOTE(review): ``_score_func`` / ``_kwargs`` are private attributes of
        # sklearn scorer objects and may change between sklearn versions. The
        # scorer's sign (``_sign``) is not applied, so ``neg_*`` scorers yield
        # the underlying, non-negated metric function.
        metric_fn = scorer._score_func
        if metric_params is None:
            # Copy so callers mutating the returned dict cannot alter the
            # scorer's stored keyword arguments.
            metric_params = dict(scorer._kwargs)
        return (metric, metric_fn, metric_params)

    if not callable(metric):
        raise TypeError(
            f"metric must be a metric name or a callable, got {type(metric).__name__!r}"
        )

    # Metric given as a function (or any other callable): used as-is.
    params = {} if metric_params is None else metric_params
    return (_callable_name(metric), metric, params)