.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/compose/plot_transformed_target.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_compose_plot_transformed_target.py:


=======================================================
Effect of transforming the targets in regression models
=======================================================

In this example, we give an overview of
:class:`~sklearn.compose.TransformedTargetRegressor`. We use two examples
to illustrate the benefit of transforming the targets before learning a linear
regression model. The first example uses synthetic data, while the second
example is based on the Ames housing data set.

.. GENERATED FROM PYTHON SOURCE LINES 15-30

.. code-block:: default


    # Author: Guillaume Lemaitre
    # License: BSD 3 clause

    import numpy as np
    import matplotlib
    import matplotlib.pyplot as plt

    from sklearn.datasets import make_regression
    from sklearn.model_selection import train_test_split
    from sklearn.linear_model import RidgeCV
    from sklearn.compose import TransformedTargetRegressor
    from sklearn.metrics import median_absolute_error, r2_score
    from sklearn.utils.fixes import parse_version


.. GENERATED FROM PYTHON SOURCE LINES 31-33

Synthetic example
#############################################################################

.. GENERATED FROM PYTHON SOURCE LINES 33-40

.. code-block:: default


    # Histograms: `normed` was deprecated in favor of `density` in
    # Matplotlib 2.1, so pick the keyword supported by the installed version.
    if parse_version(matplotlib.__version__) >= parse_version('2.1'):
        density_param = {'density': True}
    else:
        density_param = {'normed': True}


.. GENERATED FROM PYTHON SOURCE LINES 41-52

A synthetic random regression dataset is generated. The targets ``y`` are
modified by:

1. translating all targets such that all entries are
   non-negative (by adding the absolute value of the lowest ``y``) and
2. dividing by 200 and applying an exponential function (``np.expm1``) to
   obtain non-linear targets that cannot be fitted well using a simple linear
   model.

Therefore, the logarithmic function ``np.log1p`` is used to transform the
targets before training a linear regression model, and its inverse, the
exponential function ``np.expm1``, maps the model's predictions back to the
original scale.

.. GENERATED FROM PYTHON SOURCE LINES 52-57

.. code-block:: default


    X, y = make_regression(n_samples=10000, noise=100, random_state=0)
    y = np.expm1((y + abs(y.min())) / 200)
    y_trans = np.log1p(y)


.. GENERATED FROM PYTHON SOURCE LINES 58-60

Below we plot the probability density of the target, estimated with
normalized histograms, before and after applying the logarithmic function.

.. GENERATED FROM PYTHON SOURCE LINES 60-79

.. code-block:: default


    f, (ax0, ax1) = plt.subplots(1, 2)

    ax0.hist(y, bins=100, **density_param)
    ax0.set_xlim([0, 2000])
    ax0.set_ylabel('Probability')
    ax0.set_xlabel('Target')
    ax0.set_title('Target distribution')

    ax1.hist(y_trans, bins=100, **density_param)
    ax1.set_ylabel('Probability')
    ax1.set_xlabel('Target')
    ax1.set_title('Transformed target distribution')

    f.suptitle("Synthetic data", y=0.06, x=0.53)
    f.tight_layout(rect=[0.05, 0.05, 0.95, 0.95])

    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)


.. image:: /auto_examples/compose/images/sphx_glr_plot_transformed_target_001.png
    :alt: Synthetic data, Target distribution, Transformed target distribution
    :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 80-85

First, a linear model is fitted on the original targets. Because of the
non-linearity, the trained model gives imprecise predictions. Then, a
logarithmic function is used to linearize the targets, which allows better
predictions with the same linear model, as reported by the coefficient of
determination (:math:`R^2`) and the median absolute error (MAE).

.. GENERATED FROM PYTHON SOURCE LINES 85-121
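
The following minimal sketch makes this mechanism concrete. It assumes a
recent scikit-learn release, and the smaller ``n_samples`` and the names
``wrapped`` and ``manual`` are illustrative choices that are not part of the
original example. It checks that wrapping
:class:`~sklearn.linear_model.RidgeCV` in a
:class:`~sklearn.compose.TransformedTargetRegressor` with ``func=np.log1p``
and ``inverse_func=np.expm1`` gives the same predictions as fitting
``RidgeCV`` on ``np.log1p(y)`` by hand and then applying ``np.expm1`` to its
predictions.

.. code-block:: python

    import numpy as np
    from sklearn.compose import TransformedTargetRegressor
    from sklearn.datasets import make_regression
    from sklearn.linear_model import RidgeCV

    # Same recipe as the synthetic data above, with fewer samples so the
    # check runs quickly (n_samples=1000 is an arbitrary choice).
    X, y = make_regression(n_samples=1000, noise=100, random_state=0)
    y = np.expm1((y + abs(y.min())) / 200)

    # Wrapped model: RidgeCV is fitted on log1p(y) and its predictions are
    # mapped back to the original scale with expm1.
    wrapped = TransformedTargetRegressor(regressor=RidgeCV(),
                                         func=np.log1p,
                                         inverse_func=np.expm1)
    wrapped.fit(X, y)
    wrapped_pred = wrapped.predict(X)

    # Manual equivalent: transform the targets ourselves ...
    manual = RidgeCV().fit(X, np.log1p(y))
    # ... and invert the transformation on the predictions.
    manual_pred = np.expm1(manual.predict(X))

    # The two sets of predictions should agree up to floating-point noise.
    print(np.allclose(wrapped_pred, manual_pred))

The wrapper mainly saves bookkeeping. ``predict`` always returns values on the
original target scale, and by default (``check_inverse=True``) the wrapper
verifies on a subset of the training targets that ``inverse_func`` undoes
``func``. It emits a warning if it does not.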

.. code-block:: default


    f, (ax0, ax1) = plt.subplots(1, 2, sharey=True)
    # Use linear model
    regr = RidgeCV()
    regr.fit(X_train, y_train)
    y_pred = regr.predict(X_test)
    # Plot results
    ax0.scatter(y_test, y_pred)
    ax0.plot([0, 2000], [0, 2000], '--k')
    ax0.set_ylabel('Target predicted')
    ax0.set_xlabel('True Target')
    ax0.set_title('Ridge regression \n without target transformation')
    ax0.text(100, 1750, r'$R^2$=%.2f, MAE=%.2f' % (
        r2_score(y_test, y_pred), median_absolute_error(y_test, y_pred)))
    ax0.set_xlim([0, 2000])
    ax0.set_ylim([0, 2000])
    # Transform targets and use same linear model
    regr_trans = TransformedTargetRegressor(regressor=RidgeCV(),
                                            func=np.log1p,
                                            inverse_func=np.expm1)
    regr_trans.fit(X_train, y_train)
    y_pred = regr_trans.predict(X_test)

    ax1.scatter(y_test, y_pred)
    ax1.plot([0, 2000], [0, 2000], '--k')
    ax1.set_ylabel('Target predicted')
    ax1.set_xlabel('True Target')
    ax1.set_title('Ridge regression \n with target transformation')
    ax1.text(100, 1750, r'$R^2$=%.2f, MAE=%.2f' % (
        r2_score(y_test, y_pred), median_absolute_error(y_test, y_pred)))
    ax1.set_xlim([0, 2000])
    ax1.set_ylim([0, 2000])

    f.suptitle("Synthetic data", y=0.035)
    f.tight_layout(rect=[0.05, 0.05, 0.95, 0.95])


.. image:: /auto_examples/compose/images/sphx_glr_plot_transformed_target_002.png
    :alt: Synthetic data, Ridge regression without target transformation, Ridge regression with target transformation
    :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 122-128

Real-world data set
##############################################################################

In a similar manner, the Ames housing data set is used to show the impact
of transforming the targets before learning a model. In this example, the
target to be predicted is the selling price of each house.

.. GENERATED FROM PYTHON SOURCE LINES 128-142
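
For the Ames data, the target transformation is passed as a transformer
object (a :class:`~sklearn.preprocessing.QuantileTransformer`) rather than as
a pair of functions. The minimal sketch below illustrates what the wrapper
then does. It uses a hypothetical right-skewed synthetic target instead of the
Ames prices, so that it runs without downloading data. The sample sizes and
the names ``qt_params``, ``wrapped`` and ``manual`` are illustrative
assumptions.

.. code-block:: python

    import numpy as np
    from sklearn.compose import TransformedTargetRegressor
    from sklearn.datasets import make_regression
    from sklearn.linear_model import RidgeCV
    from sklearn.preprocessing import QuantileTransformer

    # Hypothetical stand-in for a skewed, price-like target (not Ames data).
    X, y = make_regression(n_samples=500, n_features=5, noise=10,
                           random_state=0)
    y = np.exp((y - y.min()) / y.std())  # positive and right-skewed

    qt_params = dict(n_quantiles=300, output_distribution='normal')

    # Wrapped model: a clone of the transformer is fitted on y, RidgeCV is
    # trained on the normal-shaped targets, and predictions are mapped back
    # with inverse_transform.
    wrapped = TransformedTargetRegressor(
        regressor=RidgeCV(), transformer=QuantileTransformer(**qt_params))
    wrapped.fit(X, y)

    # Manual equivalent. QuantileTransformer expects a 2D array, hence the
    # reshape of the 1D target into a single column.
    qt = QuantileTransformer(**qt_params).fit(y.reshape(-1, 1))
    y_normal = qt.transform(y.reshape(-1, 1)).ravel()
    manual = RidgeCV().fit(X, y_normal)
    manual_pred = qt.inverse_transform(
        manual.predict(X).reshape(-1, 1)).ravel()

    print(np.allclose(wrapped.predict(X), manual_pred))

The quantile mapping is only approximately invertible at the extremes, so the
wrapper's ``check_inverse`` test may emit a warning. This does not prevent
fitting.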

.. code-block:: default


    from sklearn.datasets import fetch_openml
    from sklearn.preprocessing import QuantileTransformer, quantile_transform

    ames = fetch_openml(name="house_prices", as_frame=True)
    # Keep only numeric columns
    X = ames.data.select_dtypes(np.number)
    # Remove columns with NaN or Inf values
    X = X.drop(columns=['LotFrontage', 'GarageYrBlt', 'MasVnrArea'])
    y = ames.target
    y_trans = quantile_transform(y.to_frame(),
                                 n_quantiles=900,
                                 output_distribution='normal',
                                 copy=True).squeeze()


.. GENERATED FROM PYTHON SOURCE LINES 143-146

A :class:`~sklearn.preprocessing.QuantileTransformer` is used to map the
target distribution to a normal distribution before applying a
:class:`~sklearn.linear_model.RidgeCV` model.

.. GENERATED FROM PYTHON SOURCE LINES 146-165

.. code-block:: default


    f, (ax0, ax1) = plt.subplots(1, 2)

    ax0.hist(y, bins=100, **density_param)
    ax0.set_ylabel('Probability')
    ax0.set_xlabel('Target')
    ax0.text(s='Target distribution', x=1.2e5, y=9.8e-6, fontsize=12)
    ax0.ticklabel_format(axis="both", style="sci", scilimits=(0, 0))

    ax1.hist(y_trans, bins=100, **density_param)
    ax1.set_ylabel('Probability')
    ax1.set_xlabel('Target')
    ax1.text(s='Transformed target distribution', x=-6.8, y=0.479, fontsize=12)

    f.suptitle("Ames housing data: selling price", y=0.04)
    f.tight_layout(rect=[0.05, 0.05, 0.95, 0.95])

    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)


.. image:: /auto_examples/compose/images/sphx_glr_plot_transformed_target_003.png
    :alt: Ames housing data: selling price
    :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 166-173

The effect of the transformer is weaker than on the synthetic data. However,
the transformation results in an increase in :math:`R^2` and a large decrease
of the MAE. The residual plot (predicted target minus true target versus
predicted target) without target transformation takes on a curved, 'reverse
smile' shape because the residuals vary systematically with the value of the
predicted target.
With target transformation, the shape is more linear, indicating a better
model fit.

.. GENERATED FROM PYTHON SOURCE LINES 173-224

.. code-block:: default


    f, (ax0, ax1) = plt.subplots(2, 2, sharey='row', figsize=(6.5, 8))

    regr = RidgeCV()
    regr.fit(X_train, y_train)
    y_pred = regr.predict(X_test)

    ax0[0].scatter(y_pred, y_test, s=8)
    ax0[0].plot([0, 7e5], [0, 7e5], '--k')
    ax0[0].set_ylabel('True target')
    ax0[0].set_xlabel('Predicted target')
    ax0[0].text(s='Ridge regression \n without target transformation', x=-5e4,
                y=8e5, fontsize=12, multialignment='center')
    ax0[0].text(3e4, 64e4, r'$R^2$=%.2f, MAE=%.2f' % (
        r2_score(y_test, y_pred), median_absolute_error(y_test, y_pred)))
    ax0[0].set_xlim([0, 7e5])
    ax0[0].set_ylim([0, 7e5])
    ax0[0].ticklabel_format(axis="both", style="sci", scilimits=(0, 0))

    ax1[0].scatter(y_pred, (y_pred - y_test), s=8)
    ax1[0].set_ylabel('Residual')
    ax1[0].set_xlabel('Predicted target')
    ax1[0].ticklabel_format(axis="both", style="sci", scilimits=(0, 0))

    regr_trans = TransformedTargetRegressor(
        regressor=RidgeCV(),
        transformer=QuantileTransformer(n_quantiles=900,
                                        output_distribution='normal'))
    regr_trans.fit(X_train, y_train)
    y_pred = regr_trans.predict(X_test)

    ax0[1].scatter(y_pred, y_test, s=8)
    ax0[1].plot([0, 7e5], [0, 7e5], '--k')
    ax0[1].set_ylabel('True target')
    ax0[1].set_xlabel('Predicted target')
    ax0[1].text(s='Ridge regression \n with target transformation', x=-5e4,
                y=8e5, fontsize=12, multialignment='center')
    ax0[1].text(3e4, 64e4, r'$R^2$=%.2f, MAE=%.2f' % (
        r2_score(y_test, y_pred), median_absolute_error(y_test, y_pred)))
    ax0[1].set_xlim([0, 7e5])
    ax0[1].set_ylim([0, 7e5])
    ax0[1].ticklabel_format(axis="both", style="sci", scilimits=(0, 0))

    ax1[1].scatter(y_pred, (y_pred - y_test), s=8)
    ax1[1].set_ylabel('Residual')
    ax1[1].set_xlabel('Predicted target')
    ax1[1].ticklabel_format(axis="both", style="sci", scilimits=(0, 0))

    f.suptitle("Ames housing data: selling price", y=0.035)

    plt.show()

.. image:: /auto_examples/compose/images/sphx_glr_plot_transformed_target_004.png
    :alt: Ames housing data: selling price
    :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 1.802 seconds)
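
Both comparisons report the median absolute error (MAE) next to
:math:`R^2`. The MAE is the median of the absolute residuals
:math:`|y_i - \hat{y}_i|`, so a few very large errors barely move it. The
short sketch below uses small hypothetical arrays. It computes the MAE by hand
and with :func:`~sklearn.metrics.median_absolute_error`, then shows that a
single wild prediction leaves it unchanged while the mean absolute error jumps.

.. code-block:: python

    import numpy as np
    from sklearn.metrics import median_absolute_error

    # Small hypothetical arrays, not taken from the example above.
    y_true = np.array([3.0, -0.5, 2.0, 7.0])
    y_pred = np.array([2.5, 0.0, 2.0, 8.0])

    # Absolute residuals are [0.5, 0.5, 0.0, 1.0]; their median is 0.5.
    by_hand = np.median(np.abs(y_true - y_pred))
    print(by_hand, median_absolute_error(y_true, y_pred))  # 0.5 0.5

    # Make one prediction wildly wrong: residuals become [0.5, 0.5, 0.0, 93.0].
    y_pred_outlier = y_pred.copy()
    y_pred_outlier[-1] = 100.0
    print(median_absolute_error(y_true, y_pred_outlier))  # still 0.5
    print(np.mean(np.abs(y_true - y_pred_outlier)))  # 23.5

This robustness is one reason the MAE is a convenient companion to
:math:`R^2` for skewed targets such as house prices.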