# Source code for scanpex.pl.ml_evaluation._pr

import numpy as np
from sklearn.metrics import average_precision_score, precision_recall_curve


class MulticlassPR:
    """
    Compute Precision-Recall curve components and AP scores for multi-class models.

    This class calculates the Recall (x-axis), Precision (y-axis), thresholds,
    and Average Precision (AP) for each class individually. It also calculates
    the baseline precision (prevalence) for each class.

    Note that the attributes ``x`` and ``y`` hold curve coordinates (recall and
    precision); they are not the feature matrix and labels passed to the
    constructor.

    Attributes
    ----------
    x : list of np.ndarray
        A list containing the Recall values for each class.
    y : list of np.ndarray
        A list containing the Precision values for each class.
    thresh : list of list of float
        A list containing the thresholds used to compute the PR curve.
    ap : list of float
        A list containing the Average Precision (AP) score for each class.
    base : list of float
        A list containing the prevalence (fraction of positives) for each class.
        This represents the performance of a random classifier ("no-skill" line).
    """

    def __init__(self, model, x, y):
        """
        Initialize the MulticlassPR calculator.

        Parameters
        ----------
        model : object
            The trained model. Its `predict` method should return scores or
            probabilities suitable for curve calculation, with one column per
            class.
        x : np.ndarray
            The feature matrix.
        y : np.ndarray
            The ground truth labels (one-hot encoded) of shape (n_samples, n_classes).
        """
        y = np.asarray(y)
        # Predict exactly once: the scores do not depend on the class index, so
        # re-running inference for every class and every metric is wasted work.
        scores = np.asarray(model.predict(x))
        n_samples = len(y)

        self.x = []
        self.y = []
        self.thresh = []
        self.ap = []
        self.base = []
        for i in range(y.shape[1]):
            y_true = y[:, i]
            y_score = scores[:, i]

            # One curve computation yields precision, recall and thresholds.
            precision, recall, thresholds = precision_recall_curve(y_true, y_score)
            self.x.append(recall)
            self.y.append(precision)
            self.thresh.append(thresholds.tolist())
            self.ap.append(average_precision_score(y_true, y_score))
            # Prevalence: share of samples labelled 1 for this class.
            self.base.append(np.count_nonzero(y_true == 1) / n_samples)


def multiclass_pr(model, x, y, ax, cmap, label_dict, minimalist: bool = False):
    """
    Plot One-vs-Rest Precision-Recall curves and the micro-average curve.

    This function visualizes the trade-off between Precision and Recall.
    It plots:

    1. Individual curves for each class.
    2. A "micro-average" curve calculated by aggregating all class predictions.
    3. Baseline lines representing the prevalence of each class, plus one line
       at the mean prevalence.
    4. An optional "ideal" line (Precision=1.0).

    Parameters
    ----------
    model : object
        The trained classification model.
    x : np.ndarray
        The input features.
    y : np.ndarray
        The ground truth labels (one-hot encoded).
    ax : matplotlib.axes.Axes
        The axis on which to draw the plot.
    cmap : str or list
        The color mapping strategy.

        - If str: Name of a matplotlib colormap.
        - If list: Specific colors for each class.
    label_dict : list of str
        The names of the classes corresponding to the columns of `y`.
    minimalist : bool, optional
        If True, hides the "ideal" line (perfect precision) to reduce clutter;
        the mean-baseline line is then drawn without a legend entry.
        By default False.

    Returns
    -------
    None
        The plot is drawn directly onto the provided `ax` object.
    """
    pr = MulticlassPR(model=model, x=x, y=y)
    n_classes = len(pr.ap)

    if isinstance(cmap, str):
        # Imported locally because the module does not import pyplot at the top.
        import matplotlib.pyplot as plt

        # Look the colormap up by name instead of eval()-ing a built string.
        colormap = plt.get_cmap(cmap)
        # Spread classes over [0, 1]; with a single class, use position 0
        # rather than dividing by zero.
        span = max(n_classes - 1, 1)
        colors = [colormap(i / span) for i in range(n_classes)]
    else:
        colors = cmap

    dotted = (0, (1, 2))
    per_class = zip(pr.x, pr.y, pr.ap, label_dict, pr.base)
    for i, (recall, precision, ap, label, base) in enumerate(per_class):
        ax.plot(
            recall,
            precision,
            # round(float(...)) works for both NumPy scalars and plain floats.
            label=f"{label} (AP:{round(float(ap), 3)})",
            c=colors[i],
        )
        # Per-class "no-skill" line at the class prevalence.
        ax.plot([0, 1], [base, base], c="gray", linestyle=dotted, zorder=1, alpha=0.5)

    # Micro-average: pool every (sample, class) decision into one binary problem.
    y_flat = np.asarray(y).ravel()
    scores_flat = np.asarray(model.predict(x)).ravel()
    p_avg, r_avg, _ = precision_recall_curve(y_flat, scores_flat)
    # Score the same flattened arrays so the legend value belongs to the
    # plotted micro curve (the default multilabel averaging would be "macro").
    micro_ap = average_precision_score(y_flat, scores_flat)
    ax.plot(r_avg, p_avg, c=".2", label=f"micro (AP:{round(float(micro_ap), 2)})")

    mean_base = np.array(pr.base).mean()
    baseline_name = None
    if not minimalist:
        ax.plot([0, 1], [1, 1], linestyle=dotted, c=".2", label="ideal", zorder=2)
        baseline_name = "baselines"
    # Mean-prevalence line; it carries the single legend entry for all baselines.
    ax.plot(
        [0, 1],
        [mean_base, mean_base],
        c="gray",
        linestyle=dotted,
        label=baseline_name,
        zorder=1,
        alpha=0.5,
    )

    ax.set(xlabel="recall", ylabel="precision", title="PR curve (OvR)")
    ax.legend(fontsize="small", frameon=False)