Visualización del comportamiento de la validación cruzada en scikit-learn

La elección del objeto de validación cruzada adecuado es una parte crucial para ajustar un modelo correctamente. Hay muchas formas de dividir los datos en conjuntos de entrenamiento y de prueba para evitar el sobreajuste del modelo, para estandarizar el número de grupos en los conjuntos de prueba, etc.

Este ejemplo visualiza el comportamiento de varios objetos comunes de scikit-learn para su comparación.

from sklearn.model_selection import (TimeSeriesSplit, KFold, ShuffleSplit,
                                     StratifiedKFold, GroupShuffleSplit,
                                     GroupKFold, StratifiedShuffleSplit)
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
np.random.seed(1338)
cmap_data = plt.cm.Paired
cmap_cv = plt.cm.coolwarm
n_splits = 4

Visualizar nuestros datos

En primer lugar, debemos entender la estructura de nuestros datos. Tiene 100 puntos de datos de entrada generados aleatoriamente, 3 clases divididas de forma desigual entre los puntos de datos y 10 «grupos» divididos de forma uniforme entre los puntos de datos.

Como veremos, algunos objetos de validación cruzada hacen cosas específicas con los datos etiquetados, otros se comportan de forma diferente con los datos agrupados y otros no utilizan esta información.

Para empezar, visualizaremos nuestros datos.

# Generate the class/group data
n_points = 100
X = np.random.randn(100, 10)

percentiles_classes = [.1, .3, .6]
y = np.hstack([[ii] * int(100 * perc)
               for ii, perc in enumerate(percentiles_classes)])

# Evenly spaced groups repeated once
groups = np.hstack([[ii] * 10 for ii in range(10)])


def visualize_groups(classes, groups, name):
    # Visualize dataset groups
    fig, ax = plt.subplots()
    ax.scatter(range(len(groups)),  [.5] * len(groups), c=groups, marker='_',
               lw=50, cmap=cmap_data)
    ax.scatter(range(len(groups)),  [3.5] * len(groups), c=classes, marker='_',
               lw=50, cmap=cmap_data)
    ax.set(ylim=[-1, 5], yticks=[.5, 3.5],
           yticklabels=['Data\ngroup', 'Data\nclass'], xlabel="Sample index")


visualize_groups(y, groups, 'no groups')
plot cv indices

Definir una función para visualizar el comportamiento de la validación cruzada

Definiremos una función que nos permita visualizar el comportamiento de cada objeto de validación cruzada. Realizaremos 4 divisiones de los datos. En cada división, visualizaremos los índices elegidos para el conjunto de entrenamiento (en azul) y el conjunto de prueba (en rojo).

def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
    """Create a sample plot for indices of a cross-validation object."""

    # Generate the training/testing visualizations for each CV split
    for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)):
        # Fill in indices with the training/test groups
        indices = np.array([np.nan] * len(X))
        indices[tt] = 1
        indices[tr] = 0

        # Visualize the results
        ax.scatter(range(len(indices)), [ii + .5] * len(indices),
                   c=indices, marker='_', lw=lw, cmap=cmap_cv,
                   vmin=-.2, vmax=1.2)

    # Plot the data classes and groups at the end
    ax.scatter(range(len(X)), [ii + 1.5] * len(X),
               c=y, marker='_', lw=lw, cmap=cmap_data)

    ax.scatter(range(len(X)), [ii + 2.5] * len(X),
               c=group, marker='_', lw=lw, cmap=cmap_data)

    # Formatting
    yticklabels = list(range(n_splits)) + ['class', 'group']
    ax.set(yticks=np.arange(n_splits+2) + .5, yticklabels=yticklabels,
           xlabel='Sample index', ylabel="CV iteration",
           ylim=[n_splits+2.2, -.2], xlim=[0, 100])
    ax.set_title('{}'.format(type(cv).__name__), fontsize=15)
    return ax

Veamos cómo se ve el objeto de validación cruzada KFold:

fig, ax = plt.subplots()
cv = KFold(n_splits)
plot_cv_indices(cv, X, y, groups, ax, n_splits)
KFold

Out:

<AxesSubplot:title={'center':'KFold'}, xlabel='Sample index', ylabel='CV iteration'>

Como puedes ver, por defecto el iterador de validación cruzada KFold no tiene en cuenta ni la clase del punto de datos ni el grupo. Podemos cambiar esto utilizando el StratifiedKFold así.

fig, ax = plt.subplots()
cv = StratifiedKFold(n_splits)
plot_cv_indices(cv, X, y, groups, ax, n_splits)
StratifiedKFold

Out:

<AxesSubplot:title={'center':'StratifiedKFold'}, xlabel='Sample index', ylabel='CV iteration'>

En este caso, la validación cruzada mantuvo la misma razón(ratio) de clases en cada división de CV. A continuación, visualizaremos este comportamiento para una serie de iteradores de CV.

Visualizar los índices de validación cruzada para muchos objetos de CV

Vamos a comparar visualmente el comportamiento de la validación cruzada para muchos objetos de validación cruzada de scikit-learn. A continuación, haremos un bucle a través de varios objetos de validación cruzada comunes, visualizando el comportamiento de cada uno.

Ten en cuenta que algunos utilizan la información del grupo/clase mientras que otros no lo hacen.

cvs = [KFold, GroupKFold, ShuffleSplit, StratifiedKFold,
       GroupShuffleSplit, StratifiedShuffleSplit, TimeSeriesSplit]


for cv in cvs:
    this_cv = cv(n_splits=n_splits)
    fig, ax = plt.subplots(figsize=(6, 3))
    plot_cv_indices(this_cv, X, y, groups, ax, n_splits)

    ax.legend([Patch(color=cmap_cv(.8)), Patch(color=cmap_cv(.02))],
              ['Testing set', 'Training set'], loc=(1.02, .8))
    # Make the legend fit
    plt.tight_layout()
    fig.subplots_adjust(right=.7)
plt.show()
  • KFold
  • GroupKFold
  • ShuffleSplit
  • StratifiedKFold
  • GroupShuffleSplit
  • StratifiedShuffleSplit
  • TimeSeriesSplit

Tiempo total de ejecución del script: (0 minutos 1.446 segundos)

Galería generada por Sphinx-Gallery