Reconocimiento de dígitos escritos a mano

Este ejemplo muestra cómo scikit-learn puede utilizarse para reconocer imágenes de dígitos escritos a mano, del 0 al 9.

print(__doc__)

# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
# License: BSD 3 clause

# Standard scientific Python imports
import matplotlib.pyplot as plt

# Import datasets, classifiers and performance metrics
from sklearn import datasets, svm, metrics
from sklearn.model_selection import train_test_split

Conjunto de datos de dígitos

El conjunto de datos de dígitos consiste en imágenes de 8x8 píxeles de dígitos. El atributo images del conjunto de datos almacena matrices de 8x8 de valores en escala de grises para cada imagen. Utilizaremos estos arreglos para visualizar las 4 primeras imágenes. El atributo target del conjunto de datos almacena el dígito que representa cada imagen y que se incluye en el título de los 4 gráficos siguientes.

Nota: si estuviéramos trabajando a partir de archivos de imagen (por ejemplo, archivos “png”), los cargaríamos utilizando matplotlib.pyplot.imread.

digits = datasets.load_digits()

_, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))
for ax, image, label in zip(axes, digits.images, digits.target):
    ax.set_axis_off()
    ax.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    ax.set_title('Training: %i' % label)
Training: 0, Training: 1, Training: 2, Training: 3

Clasificación

Para aplicar un clasificador a estos datos, tenemos que aplanar las imágenes, convirtiendo cada arreglo 2-D de valores en escala de grises de la forma (8, 8) en la forma (64,). Posteriormente, el conjunto de datos tendrá la forma (n_samples, n_features), donde n_samples es el número de imágenes y n_características es el número total de píxeles de cada imagen.

A continuación, podemos dividir los datos en subconjuntos de entrenamiento y de prueba y ajustar un clasificador de vectores de soporte a las muestras de entrenamiento. El clasificador ajustado puede utilizarse posteriormente para predecir el valor del dígito para las muestras del subconjunto de prueba.

# flatten the images
n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))

# Create a classifier: a support vector classifier
clf = svm.SVC(gamma=0.001)

# Split data into 50% train and 50% test subsets
X_train, X_test, y_train, y_test = train_test_split(
    data, digits.target, test_size=0.5, shuffle=False)

# Learn the digits on the train subset
clf.fit(X_train, y_train)

# Predict the value of the digit on the test subset
predicted = clf.predict(X_test)

A continuación, visualizamos las 4 primeras muestras de prueba y mostramos su valor de dígito predicho en el título.

_, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))
for ax, image, prediction in zip(axes, X_test, predicted):
    ax.set_axis_off()
    image = image.reshape(8, 8)
    ax.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    ax.set_title(f'Prediction: {prediction}')
Prediction: 8, Prediction: 8, Prediction: 4, Prediction: 9

classification_report construye un informe de texto que muestra las principales métricas de clasificación.

print(f"Classification report for classifier {clf}:\n"
      f"{metrics.classification_report(y_test, predicted)}\n")

Out:

Classification report for classifier SVC(gamma=0.001):
              precision    recall  f1-score   support

           0       1.00      0.99      0.99        88
           1       0.99      0.97      0.98        91
           2       0.99      0.99      0.99        86
           3       0.98      0.87      0.92        91
           4       0.99      0.96      0.97        92
           5       0.95      0.97      0.96        91
           6       0.99      0.99      0.99        91
           7       0.96      0.99      0.97        89
           8       0.94      1.00      0.97        88
           9       0.93      0.98      0.95        92

    accuracy                           0.97       899
   macro avg       0.97      0.97      0.97       899
weighted avg       0.97      0.97      0.97       899

También podemos trazar una matriz de confusión de los valores de los dígitos verdaderos y los valores de los dígitos predichos.

disp = metrics.plot_confusion_matrix(clf, X_test, y_test)
disp.figure_.suptitle("Confusion Matrix")
print(f"Confusion matrix:\n{disp.confusion_matrix}")

plt.show()
Confusion Matrix

Out:

Confusion matrix:
[[87  0  0  0  1  0  0  0  0  0]
 [ 0 88  1  0  0  0  0  0  1  1]
 [ 0  0 85  1  0  0  0  0  0  0]
 [ 0  0  0 79  0  3  0  4  5  0]
 [ 0  0  0  0 88  0  0  0  0  4]
 [ 0  0  0  0  0 88  1  0  0  2]
 [ 0  1  0  0  0  0 90  0  0  0]
 [ 0  0  0  0  0  1  0 88  0  0]
 [ 0  0  0  0  0  0  0  0 88  0]
 [ 0  0  0  1  0  1  0  0  0 90]]

Tiempo total de ejecución del script: (0 minutos 0.683 segundos)

Galería generada por Sphinx-Gallery