Desarrollo con la API Plotting

Scikit-learn define una API sencilla para crear visualizaciones para el aprendizaje automático. Las características clave de esta API son ejecutar los cálculos una vez y tener la flexibilidad de ajustar las visualizaciones después del hecho. Esta sección está destinada a los desarrolladores que deseen desarrollar o mantener herramientas de trazado. Para su uso, los usuarios deben consultar el Manual de Usuario.

Visión general de la API Plotting

Esta lógica se encapsula en un objeto de visualización donde se almacenan los datos calculados y el trazado se realiza en un método plot. El método __init__ del objeto de visualización contiene sólo los datos necesarios para crear la visualización. El método plot toma parámetros que sólo tienen que ver con la visualización, como los ejes de matplotlib. El método plot almacenará los artistas de matplotlib como atributos que permiten ajustes de estilo a través del objeto de visualización. Una función de ayuda (helper) plot_* acepta parámetros para hacer el cálculo y los parámetros utilizados para el trazado. Después de que la función de ayuda crea el objeto de visualización con los valores calculados, invoca al método plot de la visualización. Ten en cuenta que el método plot define atributos relacionados con matplotlib, como el artista de la línea. Esto permite personalizaciones después de invocar al método plot.

Por ejemplo, RocCurveDisplay define los siguientes métodos y atributos:

class RocCurveDisplay:
    def __init__(self, fpr, tpr, roc_auc, estimator_name):
        ...
        self.fpr = fpr
        self.tpr = tpr
        self.roc_auc = roc_auc
        self.estimator_name = estimator_name

    def plot(self, ax=None, name=None, **kwargs):
        ...
        self.line_ = ...
        self.ax_ = ax
        self.figure_ = ax.figure_

def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None,
                   drop_intermediate=True, response_method="auto",
                   name=None, ax=None, **kwargs):
    # do computation
    viz = RocCurveDisplay(fpr, tpr, roc_auc,
                             estimator.__class__.__name__)
    return viz.plot(ax=ax, name=name, **kwargs)

Lee más en Graficación Avanzada Con Dependencia Parcial y en el Manual de Usuario.

Graficar con Múltiples Ejes

Algunas de las herramientas de graficación como plot_partial_dependence y PartialDependenceDisplay admiten el trazado en múltiples ejes. Se admiten dos escenarios diferentes:

Primero, si se pasa una lista de ejes, plot comprobará si el número de ejes es coherente con el número de ejes que espera y entonces dibujará en esos ejes. Segundo, si se pasa un solo eje, ese eje define un espacio para colocar varios ejes. En este caso, te sugerimos que utilices la función ~matplotlib.gridspec.GridSpecFromSubplotSpec de matplotlib para dividir el espacio:

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpecFromSubplotSpec

fig, ax = plt.subplots()
gs = GridSpecFromSubplotSpec(2, 2, subplot_spec=ax.get_subplotspec())

ax_top_left = fig.add_subplot(gs[0, 0])
ax_top_right = fig.add_subplot(gs[0, 1])
ax_bottom = fig.add_subplot(gs[1, :])

Por defecto, la palabra clave ax en plot es None. En este caso, se crea el eje único y se utiliza la api gridspec para crear las regiones a trazar.

Mira por ejemplo, plot_partial_dependence que traza múltiples líneas y contornos utilizando esta API. Los ejes que definen la caja delimitadora se guardan en un atributo bounding_ax_. Los ejes individuales creados se guardan en un ndarray axes_ , correspondiente a la posición de los ejes en la cuadrícula. Las posiciones que no se utilizan se establecen en None. Además, los Artistas de matplotlib se almacenan en lines_ y contours_ donde la clave es la posición en la cuadrícula. Cuando se pasa una lista de ejes, el axes_, lines_ y contours_ es un ndarray 1d correspondiente a la lista de ejes pasados.