Source code for hundred_hammers.plots

import numpy as np
import seaborn as sns
from adjustText import adjust_text
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
from sklearn.dummy import DummyClassifier, DummyRegressor
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from .metric_alias import metric_alias


def _save_plot(filepath, display, dpi=200):
    if filepath is not None:
        plt.savefig(filepath, dpi=dpi)
        if not display:
            plt.close()
    else:
        if display:
            plt.show()


[docs] def plot_confusion_matrix(X, y, model, class_dict, title="", test_size=0.2, seed=0, filepath=None, display=True): """ Plot confusion matrix for a given model. :param X: input observations :param y: target values :param model: model to evaluate :param class_dict: dictionary with class names (ex.: {0: "No", 1: "Yes"}) :param title: title of the plot :param test_size: percentage of the dataset to use for testing :param seed: random seed :param filepath: path to save the plot :param display: whether to display the plot """ # Setting up labels = [class_dict[f] for f in np.unique(y)] X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, test_size=test_size, random_state=seed) # Plotting sns.set_style("ticks") fig, axs = plt.subplots(2, 2, figsize=(8, 8)) fig.patch.set_facecolor("lightgrey") for i, md in enumerate([model, DummyClassifier()]): md.fit(X_train, y_train) cm_train = confusion_matrix(y_train, md.predict(X_train)) cm_test = confusion_matrix(y_test, md.predict(X_test)) for j, (cm_name, cm, _X, _y) in enumerate([("Train", cm_train, X_train, y_train), ("Test", cm_test, X_test, y_test)]): ax = axs[i, j] percentages = cm.astype("float") / cm.sum() annots = [[f"{v}\n\n({(p*100):.2f})" for v, p in zip(vs, ps)] for vs, ps in zip(cm, percentages)] cmap = sns.color_palette(["light:#7A7", "light:#77B"][j], as_cmap=True) sns.heatmap( cm, annot=annots, fmt="", square=True, xticklabels=labels, ax=ax, yticklabels=labels, linewidths=1, cmap=cmap, cbar=False, linecolor="black", ) md_name = md.__class__.__name__ acc = md.score(_X, _y) ax_title = f"{md_name} // {cm_name} data\n(N: {len(_y)}, Acc: {acc:.3f})" ax.set_title(ax_title, fontweight="bold") ax.set_xlabel("Predicted", fontweight="bold") ax.set_ylabel("Actual", fontweight="bold") plt.suptitle(title, fontweight="bold", fontsize=14) plt.tight_layout() _save_plot(filepath, display)
[docs] def plot_regression_pred(X, y, models, y_label="", title="", test_size=0.2, metric=None, seed=0, filepath=None, display=True): """ Plot the predictions of the regression model :param X: input observations :param y: target values :param models: list of models to evaluate :param y_label: name of the target variable :param title: title of the plot :param test_size: percentage of the dataset to use for testing :param metric: metric to use for evaluation :param seed: random seed :param filepath: path to save the plot :param display: whether to display the plot """ # Preparing data X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=seed) _s = np.argsort(y_train) X_train, y_train = X_train[_s], y_train[_s] _s = np.argsort(y_test) X_test, y_test = X_test[_s], y_test[_s] _X = np.concatenate((X_train, X_test)) _y = np.concatenate((y_train, y_test)) # Plotting sns.set_style("ticks") fig = plt.figure(figsize=(10, 6)) fig.patch.set_facecolor("lightgrey") plt.plot(range(len(_y)), _y, "b", label="Data", linewidth=3) for md in models: md = md.fit(X_train, y_train) md_name = md.__class__.__name__ res = metric(y_test, md.predict(X_test)) if metric is not None else "" metric_name = metric.__name__ if metric is not None else "" label = f"{md_name}" + (f"\n({metric_name}: {res:.2f})" if metric is not None else "") plt.plot(range(len(_y)), md.predict(_X), label=label) plt.axvline(len(y_train) - 0.5, color="k", linestyle="--", linewidth=3) plt.text(len(y_train) / 2, _y.max() * 0.9, "Train data", ha="center", va="center", fontsize=20) plt.text(len(y_train) + len(y_test) / 2, _y.max() * 0.9, "Test data", ha="center", va="center", fontsize=20) plt.gca().add_patch(Rectangle((len(y_train) - 0.5, _y.min()), len(y_test), (_y.max() - _y.min()), fill=True, alpha=0.1, color="b")) plt.title(title, fontweight="bold") plt.xlabel("Samples") plt.ylabel(y_label) plt.ylim(min(_y), max(_y)) plt.xlim(0, len(_y)) plt.legend() plt.grid(True) plt.tight_layout() _save_plot(filepath, display)
[docs] def plot_batch_results(df, metric_name, title="", filepath=None, display=True): """ Plot the results of the batch evaluation :param df: results dataframe :param title: title of plot :param filepath: filepath to save plot :param display: whether to display the plot """ # sns.set_context("paper") plt.figure(figsize=(6, 6)) ax = plt.gca() y_label = f"{metric_name} / Validation Test / Mean" x_label = f"{metric_name} / Validation Train / Mean" baseline_result = df[df["Model"].str.startswith("Dummy")][y_label].values[0] sns.scatterplot(data=df, x=x_label, y=y_label, s=100, hue="Model", legend=False) ax.axhline(baseline_result, color="grey", linestyle="--", linewidth=2, zorder=-1) ax.set_facecolor("#eeeeee") texts = [] for name, row in df.iterrows(): text = plt.text(row[x_label] + 0.005, row[y_label] + 0.001, row["Model"], fontsize=12) texts.append(text) adjust_text(texts, arrowprops=dict(arrowstyle="-", color="k", lw=0.5), force_text=1.0, force_points=1.0) plt.title(title, fontweight="bold") plt.grid() plt.tight_layout() _save_plot(filepath, display)
[docs] def plot_multiple_datasets(df, metric_name, id_col="Code", title="", line_at_0=False, higher_is_better=True, filepath=None, display=True): """ Plot the results of the batch evaluation :param df: results dataframe :param metric_name: metric to plot :param id_col: column containing the ID of the dataset :param title: title of plot :param line_at_0: determines if a line is plotted at 0 :param higher_is_better: determines if higher values are better :param filepath: filepath to save plot :param display: whether to display the plot """ _df = df.sort_values(by=metric_name, ascending=(not higher_is_better)) plt.figure(figsize=(10, 8)) ax = plt.gca() sns.scatterplot(data=_df, x=id_col, y=metric_name, s=100, hue=id_col, legend=False) ax.set_xlabel("") ax.set_ylabel(metric_name) if line_at_0: ax.axhline(0, lw=4, color="k", zorder=-1) for tick in ax.get_xticklabels(): tick.set_rotation(90) ax.grid() ax.set_facecolor("#eeeeee") plt.suptitle(title, fontweight="bold") plt.tight_layout() _save_plot(filepath, display)