Ejemplo de regresión lineal

El ejemplo siguiente utiliza sólo la primera característica del conjunto de datos diabetes, para ilustrar los puntos de datos dentro del gráfico bidimensional. La línea recta puede verse en el gráfico, mostrando cómo la regresión lineal intenta dibujar una línea recta que minimice de la mejor manera posible la suma residual de cuadrados entre las respuestas observadas en el conjunto de datos y las respuestas predichas por la aproximación lineal.

Los coeficientes, la suma de cuadrados residuales y el coeficiente de determinación también se calculan.

plot ols

Out:

Coefficients:
 [938.23786125]
Mean squared error: 2548.07
Coefficient of determination: 0.47

print(__doc__)


# Code source: Jaques Grobler
# License: BSD 3 clause


import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets, linear_model
from sklearn.metrics import mean_squared_error, r2_score

# Load the diabetes dataset
diabetes_X, diabetes_y = datasets.load_diabetes(return_X_y=True)

# Use only one feature
diabetes_X = diabetes_X[:, np.newaxis, 2]

# Split the data into training/testing sets
diabetes_X_train = diabetes_X[:-20]
diabetes_X_test = diabetes_X[-20:]

# Split the targets into training/testing sets
diabetes_y_train = diabetes_y[:-20]
diabetes_y_test = diabetes_y[-20:]

# Create linear regression object
regr = linear_model.LinearRegression()

# Train the model using the training sets
regr.fit(diabetes_X_train, diabetes_y_train)

# Make predictions using the testing set
diabetes_y_pred = regr.predict(diabetes_X_test)

# The coefficients
print('Coefficients: \n', regr.coef_)
# The mean squared error
print('Mean squared error: %.2f'
      % mean_squared_error(diabetes_y_test, diabetes_y_pred))
# The coefficient of determination: 1 is perfect prediction
print('Coefficient of determination: %.2f'
      % r2_score(diabetes_y_test, diabetes_y_pred))

# Plot outputs
plt.scatter(diabetes_X_test, diabetes_y_test,  color='black')
plt.plot(diabetes_X_test, diabetes_y_pred, color='blue', linewidth=3)

plt.xticks(())
plt.yticks(())

plt.show()

Tiempo total de ejecución del script: ( 0 minutos 0.053 segundos)

Galería generada por Sphinx-Gallery