{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Regression with Hundred Hammers" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook we explain how to use the HundredHammers library to perform basic model selection and hyperparameter optimization for a regression problem.\n", "\n", "To do this, we use the diabetes dataset, one of the example datasets bundled with scikit-learn." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2023-09-22T19:36:44.106350300Z", "start_time": "2023-09-22T19:36:42.905977300Z" } }, "outputs": [], "source": [ "import logging\n", "\n", "import hundred_hammers as hh\n", "from hundred_hammers.model_zoo import (\n", " DummyRegressor,\n", " Ridge,\n", " DecisionTreeRegressor,\n", " KNeighborsRegressor,\n", ")\n", "\n", "from sklearn.datasets import load_diabetes\n", "from sklearn.metrics import mean_squared_error" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we load the dataset and store the inputs in X and the target in y." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2023-09-22T19:36:44.122081400Z", "start_time": "2023-09-22T19:36:44.107350900Z" } }, "outputs": [], "source": [ "data = load_diabetes()\n", "X = data.data\n", "y = data.target" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We start by training some models with their default configuration. 
If you don't specify which models to use, a default set of regression models is chosen for you.\n", "\n", "To see which models are chosen by default, check the ```DEFAULT_REGRESSION_MODELS``` variable:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2023-09-22T19:36:44.127474400Z", "start_time": "2023-09-22T19:36:44.123081800Z" } }, "outputs": [ { "data": { "text/plain": [ "[('Dummy Mean', DummyRegressor(), {}),\n", " ('Dummy Median', DummyRegressor(strategy='median'), {}),\n", " ('Linear Regression', LinearRegression(), {}),\n", " ('Decision Tree', DecisionTreeRegressor(), {}),\n", " ('SVR', SVR(), {}),\n", " ('Linear SVR', LinearSVR(), {}),\n", " ('Ridge', Ridge(), {}),\n", " ('Passive Aggressive', PassiveAggressiveRegressor(), {}),\n", " ('KNN', KNeighborsRegressor(), {}),\n", " ('Neural Network Regressor', MLPRegressor(), {}),\n", " ('Gaussian Process', GaussianProcessRegressor(), {}),\n", " ('Random Forest', RandomForestRegressor(), {}),\n", " ('AdaBoost', AdaBoostRegressor(), {}),\n", " ('Gradient Boosting', GradientBoostingRegressor(), {})]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "hh.model_zoo.DEFAULT_REGRESSION_MODELS" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice that it is a list of tuples. Each tuple contains the name given to the regressor, an instance of the class that implements the regression model, and a grid of hyperparameters (empty for now; hyperparameter grids are explained later).\n", "\n", "These are the models used in the evaluation that follows." 
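The following rough illustration shows the same (name, estimator, grid) structure. The sketch builds a two-entry list in that shape from scikit-learn estimators and scores each one with 5-fold cross-validation. It does not call Hundred Hammers. The list name `my_models` is hypothetical, and `cross_val_score` is plain scikit-learn; the library may evaluate models differently internally.

```python
# Sketch: a model list in the same (name, estimator, grid) shape as
# DEFAULT_REGRESSION_MODELS, evaluated with plain scikit-learn.
# This does not use Hundred Hammers; `my_models` is a hypothetical name.
from sklearn.datasets import load_diabetes
from sklearn.dummy import DummyRegressor
from sklearn.linear_model import Ridge
from sklearn.model_selection import cross_val_score

X, y = load_diabetes(return_X_y=True)

# Empty grids: no hyperparameter tuning in this sketch
my_models = [
    ('Dummy Mean', DummyRegressor(), {}),
    ('Ridge', Ridge(), {}),
]

cv_r2 = {}
for name, model, grid in my_models:  # grid is ignored here
    # Mean R2 over 5 unshuffled folds
    scores = cross_val_score(model, X, y, cv=5, scoring='r2')
    cv_r2[name] = scores.mean()
    print(f'{name}: mean R2 = {cv_r2[name]:.3f}')
```

A constant-prediction baseline cannot do better than an R2 of zero on a held-out fold. It is therefore a useful reference point when reading the R2 columns of the results table.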
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation with default models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, create the ```HundredHammersRegressor``` object:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2023-09-22T19:36:44.130563800Z", "start_time": "2023-09-22T19:36:44.129062700Z" } }, "outputs": [], "source": [ "hh_models = hh.HundredHammersRegressor(show_progress_bar=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then evaluate the models. Apart from the data itself (the variables X and y), ```evaluate``` accepts other parameters. ```optim_hyper``` controls whether the hyperparameters of the models are optimized, and ```n_grid_points``` controls how many values of each hyperparameter are tried during the optimization.\n", "\n", "Since we don't want to optimize the hyperparameters yet, ```optim_hyper``` is set to ```False```." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2023-09-22T19:37:09.814714700Z", "start_time": "2023-09-22T19:36:44.131564Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Evaluating models...: 100%|██████████| 14/14 [00:52<00:00, 3.77s/it]\n" ] } ], "source": [ "# Configure the logger to show only warnings and errors\n", "hh.hh_logger.setLevel(logging.WARNING)\n", "\n", "# Evaluate the models and store the results in a variable\n", "df_results = hh_models.evaluate(X, y, optim_hyper=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice the line before the call to ```evaluate```. It configures the logger to show only warnings (there should be none). In an interactive environment you will most likely want ```logging.INFO``` instead, since it reports information about each model in \"real time\".\n", "\n", "For more detailed information, set the level to ```logging.DEBUG```. 
This produces a lot of output, but it can be useful if you run into a bug.\n", "\n", "For the purposes of this notebook, the level is kept at ```logging.WARNING```, but you are welcome to change it if you run the notebook locally." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now display the results of the evaluation:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2023-09-22T19:37:09.830188Z", "start_time": "2023-09-22T19:37:09.815715300Z" } }, "outputs": [ { "data": { "text/html": [ "
| | Model | Avg R2 (Validation Train) | Std R2 (Validation Train) | Avg R2 (Validation Test) | Std R2 (Validation Test) | Avg R2 (Train) | Std R2 (Train) | Avg R2 (Test) | Std R2 (Test) | Avg MSE (Validation Train) | ... | Avg MSE (Test) | Std MSE (Test) | Avg MAE (Validation Train) | Std MAE (Validation Train) | Avg MAE (Validation Test) | Std MAE (Validation Test) | Avg MAE (Train) | Std MAE (Train) | Avg MAE (Test) | Std MAE (Test) |\n", "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | Dummy Mean | 0.000000 | 0.000000 | -0.023206 | 0.033628 | 0.000000 | 0.000000e+00 | -0.001337 | 0.000000e+00 | 6125.118931 | ... | 5134.783503 | 0.000000e+00 | 67.301695 | 1.188564 | 67.615074 | 3.943941 | 67.339534 | 1.421085e-14 | 59.227456 | 7.105427e-15 |\n", "
| 1 | Dummy Median | -0.027476 | 0.007274 | -0.050684 | 0.071611 | -0.025922 | 0.000000e+00 | -0.045202 | 0.000000e+00 | 6293.183322 | ... | 5359.719101 | 9.094947e-13 | 66.517990 | 1.217067 | 66.993618 | 4.993507 | 66.566572 | 0.000000e+00 | 59.044944 | 0.000000e+00 |\n", "
| 2 | Linear Regression | 0.556522 | 0.015235 | 0.514558 | 0.067539 | 0.553925 | 0.000000e+00 | 0.332233 | 0.000000e+00 | 2715.353061 | ... | 3424.259334 | 0.000000e+00 | 42.404995 | 0.829168 | 43.952508 | 3.145022 | 42.593344 | 0.000000e+00 | 46.173585 | 0.000000e+00 |\n", "
| 3 | Decision Tree | 1.000000 | 0.000000 | -0.046363 | 0.200661 | 1.000000 | 0.000000e+00 | -0.452581 | 7.741044e-02 | 0.000000 | ... | 7448.729213 | 3.969551e+02 | 0.000000 | 0.000000 | 61.592362 | 5.524460 | 0.000000 | 0.000000e+00 | 72.787640 | 2.030020e+00 |\n", "
| 4 | SVR | 0.159545 | 0.009264 | 0.126855 | 0.062307 | 0.186951 | 2.775558e-17 | 0.128119 | 0.000000e+00 | 5147.222487 | ... | 4470.939683 | 0.000000e+00 | 59.987041 | 1.025494 | 61.080993 | 4.554385 | 58.932403 | 0.000000e+00 | 53.268617 | 0.000000e+00 |\n", "
| 5 | Linear SVR | -0.480290 | 0.021770 | -0.499608 | 0.177843 | -0.380897 | 2.104066e-03 | -0.515761 | 2.670427e-03 | 9068.418326 | ... | 7772.710116 | 1.369376e+01 | 72.936544 | 1.610525 | 73.117766 | 7.733969 | 71.098948 | 3.656171e-02 | 67.626756 | 6.017854e-02 |\n", "
| 6 | Ridge | 0.445599 | 0.013815 | 0.419461 | 0.046400 | 0.465084 | 0.000000e+00 | 0.340980 | 5.551115e-17 | 3394.642031 | ... | 3379.406308 | 0.000000e+00 | 49.381703 | 0.756940 | 50.079365 | 3.114845 | 48.381085 | 7.105427e-15 | 46.566795 | 0.000000e+00 |\n", "
| 7 | Passive Aggressive | 0.506681 | 0.026801 | 0.471776 | 0.060789 | 0.517667 | 9.373723e-03 | 0.358801 | 5.003282e-03 | 3020.352370 | ... | 3288.020684 | 2.565646e+01 | 45.207007 | 1.403516 | 46.656725 | 3.233495 | 44.703236 | 6.186625e-01 | 45.272923 | 2.780478e-01 |\n", "
| 8 | KNN | 0.615279 | 0.021038 | 0.410247 | 0.083013 | 0.618820 | 0.000000e+00 | 0.172488 | 0.000000e+00 | 2354.986676 | ... | 4243.422022 | 0.000000e+00 | 37.785156 | 1.130082 | 46.627299 | 3.687450 | 37.339377 | 0.000000e+00 | 49.492135 | 0.000000e+00 |\n", "
| 9 | Neural Network Regressor | -2.925185 | 0.119621 | -2.988113 | 0.350351 | -2.917495 | 9.059604e-02 | -3.644418 | 1.081171e-01 | 24040.326288 | ... | 23816.235196 | 5.544168e+02 | 134.976709 | 3.010464 | 135.014455 | 9.449213 | 134.932019 | 1.912264e+00 | 137.552821 | 1.910063e+00 |\n", "
| 10 | Gaussian Process | 0.995085 | 0.001516 | -13.683252 | 6.935421 | 0.984183 | 1.110223e-16 | -9.864145 | 0.000000e+00 | 30.042210 | ... | 55710.543741 | 7.275958e-12 | 2.979676 | 0.408915 | 182.608347 | 32.059452 | 5.759758 | 0.000000e+00 | 144.352684 | 2.842171e-14 |\n", "
| 11 | Random Forest | 0.925400 | 0.003492 | 0.451311 | 0.081209 | 0.924440 | 1.996863e-03 | 0.261153 | 1.513057e-02 | 456.784823 | ... | 3788.750941 | 7.758846e+01 | 17.203293 | 0.450834 | 46.492412 | 3.162357 | 17.225159 | 2.502492e-01 | 48.277247 | 5.898557e-01 |\n", "
| 12 | AdaBoost | 0.687707 | 0.017497 | 0.460084 | 0.072768 | 0.662306 | 5.843726e-03 | 0.279929 | 1.969705e-02 | 1911.433685 | ... | 3692.470967 | 1.010051e+02 | 38.272124 | 1.042228 | 46.951305 | 3.079916 | 39.829220 | 4.565226e-01 | 47.444248 | 5.013379e-01 |\n", "
| 13 | Gradient Boosting | 0.889395 | 0.008559 | 0.441646 | 0.077073 | 0.857853 | 1.110223e-16 | 0.208258 | 1.787313e-03 | 677.065400 | ... | 4059.994938 | 9.165209e+00 | 20.760164 | 0.796809 | 46.463974 | 3.126900 | 23.559587 | 3.552714e-15 | 49.229710 | 1.154214e-01 |\n", "
14 rows × 25 columns
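The table shows the gap between training and held-out scores. Decision Tree and Gaussian Process reach an average R2 of 1.000 and 0.995 on the validation-train folds, yet both score below zero on the validation-test folds, which points to overfitting. The sketch below ranks models by the validation-test column instead. It assumes `df_results` is a pandas DataFrame with the column names shown above. To stay self-contained, it rebuilds a small frame from five rows of the table, and the `R2 gap` column is a hypothetical addition.

```python
import pandas as pd

# A few rows copied from the results table above (R2 columns only)
results = pd.DataFrame({
    'Model': ['Linear Regression', 'Decision Tree', 'KNN',
              'Gaussian Process', 'Random Forest'],
    'Avg R2 (Validation Train)': [0.556522, 1.000000, 0.615279, 0.995085, 0.925400],
    'Avg R2 (Validation Test)': [0.514558, -0.046363, 0.410247, -13.683252, 0.451311],
})

# Rank by the held-out (validation test) score rather than the training score
ranked = results.sort_values('Avg R2 (Validation Test)', ascending=False)
print(ranked[['Model', 'Avg R2 (Validation Test)']].to_string(index=False))

# Hypothetical helper column: a large train/test gap suggests overfitting
results['R2 gap'] = (results['Avg R2 (Validation Train)']
                     - results['Avg R2 (Validation Test)'])
most_overfit = results.loc[results['R2 gap'].idxmax(), 'Model']
print(most_overfit)  # Gaussian Process
```

Ranked this way, Linear Regression comes first among these five rows, while the two models with the highest training scores end up at the bottom.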
\n", "| | Model | Avg R2 (Validation Train) | Std R2 (Validation Train) | Avg R2 (Validation Test) | Std R2 (Validation Test) | Avg R2 (Train) | Std R2 (Train) | Avg R2 (Test) | Std R2 (Test) | Avg MSE (Validation Train) | ... | Avg MSE (Test) | Std MSE (Test) | Avg MAE (Validation Train) | Std MAE (Validation Train) | Avg MAE (Validation Test) | Std MAE (Validation Test) | Avg MAE (Train) | Std MAE (Train) | Avg MAE (Test) | Std MAE (Test) |\n", "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | Dummy | 0.000000 | 0.000000 | -0.023206 | 0.033628 | 0.000000 | 0.000000e+00 | -0.001337 | 0.000000e+00 | 6125.118931 | ... | 5134.783503 | 0.000000e+00 | 67.301695 | 1.188564 | 67.615074 | 3.943941 | 67.339534 | 1.421085e-14 | 59.227456 | 7.105427e-15 |\n", "
| 1 | Ridge | 0.553573 | 0.015243 | 0.516564 | 0.064871 | 0.551601 | 1.110223e-16 | 0.333235 | 0.000000e+00 | 2733.434135 | ... | 3419.120423 | 0.000000e+00 | 42.664323 | 0.836894 | 43.975160 | 3.176279 | 42.767356 | 0.000000e+00 | 46.036471 | 0.000000e+00 |\n", "
| 2 | Decision Tree | 0.705049 | 0.025423 | 0.227255 | 0.138329 | 0.663131 | 2.495672e-03 | -0.061039 | 1.694207e-02 | 1806.214920 | ... | 5440.928933 | 8.687768e+01 | 29.342648 | 1.290889 | 52.845684 | 4.971056 | 31.059490 | 3.552714e-15 | 57.152247 | 7.021638e-01 |\n", "
| 3 | KNN | 0.465824 | 0.017405 | 0.436415 | 0.061569 | 0.485286 | 5.551115e-17 | 0.315831 | 5.551115e-17 | 3270.483891 | ... | 3508.367072 | 4.547474e-13 | 48.434238 | 0.888525 | 49.248325 | 2.949155 | 47.282775 | 0.000000e+00 | 47.117041 | 7.105427e-15 |\n", "
4 rows × 25 columns
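This results table covers four models: Dummy, Ridge, Decision Tree and KNN. These match the estimators imported from `hundred_hammers.model_zoo` at the top of the notebook. A custom model list uses the same (name, estimator, grid) tuples as `DEFAULT_REGRESSION_MODELS`, but with non-empty grids. The sketch below writes such a list and uses scikit-learn's `ParameterGrid` to count how many combinations each grid contains when fully expanded. The grid values and the name `custom_models` are hypothetical, and the estimators are imported from scikit-learn directly. The count covers the full Cartesian product and says nothing about how Hundred Hammers samples values under `n_grid_points`.

```python
from sklearn.dummy import DummyRegressor
from sklearn.linear_model import Ridge
from sklearn.model_selection import ParameterGrid
from sklearn.neighbors import KNeighborsRegressor
from sklearn.tree import DecisionTreeRegressor

# Hypothetical custom list: grid values are illustrative only
custom_models = [
    ('Dummy', DummyRegressor(), {'strategy': ['mean', 'median']}),
    ('Ridge', Ridge(), {'alpha': [0.01, 0.1, 1.0]}),
    ('Decision Tree', DecisionTreeRegressor(),
     {'max_depth': [2, 4, 8], 'min_samples_leaf': [1, 5]}),
    ('KNN', KNeighborsRegressor(), {'n_neighbors': [5, 10, 20]}),
]

# A full grid is the Cartesian product of its value lists
n_candidates = {name: len(ParameterGrid(grid)) for name, _, grid in custom_models}
print(n_candidates)  # {'Dummy': 2, 'Ridge': 3, 'Decision Tree': 6, 'KNN': 3}
print(sum(n_candidates.values()))  # 14
```

Adding a hyperparameter multiplies the size of a full grid, so a model with several hyperparameters dominates the total quickly. Limiting the number of values per hyperparameter keeps a search tractable.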
\n", "| | Model | Avg R2 (Validation Train) | Std R2 (Validation Train) | Avg R2 (Validation Test) | Std R2 (Validation Test) | Avg R2 (Train) | Std R2 (Train) | Avg R2 (Test) | Std R2 (Test) | Avg MSE (Validation Train) | ... | Avg MSE (Test) | Std MSE (Test) | Avg MAE (Validation Train) | Std MAE (Validation Train) | Avg MAE (Validation Test) | Std MAE (Validation Test) | Avg MAE (Train) | Std MAE (Train) | Avg MAE (Test) | Std MAE (Test) |\n", "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | Dummy | -0.027476 | 0.007274 | -0.050684 | 0.071611 | -0.025922 | 0.000000e+00 | -0.045202 | 0.0 | 6293.183322 | ... | 5359.719101 | 9.094947e-13 | 66.517990 | 1.217067 | 66.993618 | 4.993507 | 66.566572 | 0.0 | 59.044944 | 0.000000e+00 |\n", "
| 1 | Ridge | 0.555347 | 0.015218 | 0.515900 | 0.066570 | 0.553033 | 1.110223e-16 | 0.329983 | 0.0 | 2722.561310 | ... | 3435.796416 | 4.547474e-13 | 42.494128 | 0.829984 | 43.926355 | 3.180674 | 42.664912 | 0.0 | 46.170390 | 0.000000e+00 |\n", "
| 2 | Decision Tree | 0.473673 | 0.020150 | 0.373744 | 0.120035 | 0.477605 | 5.551115e-17 | 0.020330 | 0.0 | 3222.626891 | ... | 5023.676966 | 0.000000e+00 | 43.848744 | 0.862175 | 47.866137 | 4.431692 | 44.014164 | 0.0 | 54.353933 | 7.105427e-15 |\n", "
| 3 | KNN | 0.554573 | 0.020198 | 0.459955 | 0.074056 | 0.548160 | 0.000000e+00 | 0.284378 | 0.0 | 2727.046694 | ... | 3669.657350 | 9.094947e-13 | 41.663537 | 1.146837 | 45.521591 | 3.630353 | 41.737832 | 0.0 | 46.806946 | 0.000000e+00 |\n", "
4 rows × 25 columns
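The notebook also imports scikit-learn's `mean_squared_error`. The following stand-alone sketch gives a point of comparison for the Ridge rows above. It tunes `Ridge` over a small grid of `alpha` values with scikit-learn's `GridSearchCV`, scored by negative MSE, and then reports the MSE on a held-out split with `mean_squared_error`. This is plain scikit-learn, not a description of how Hundred Hammers optimizes hyperparameters. The grid, the 80/20 split and `random_state=0` are arbitrary choices.

```python
from sklearn.datasets import load_diabetes
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import GridSearchCV, train_test_split

X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0
)

# Illustrative grid of regularization strengths
search = GridSearchCV(
    Ridge(),
    param_grid={'alpha': [0.001, 0.01, 0.1, 1.0, 10.0]},
    scoring='neg_mean_squared_error',  # scikit-learn maximizes scores, so MSE is negated
    cv=5,
)
search.fit(X_train, y_train)

# Evaluate the refitted best model on the untouched test split
test_mse = mean_squared_error(y_test, search.predict(X_test))
print(search.best_params_, round(test_mse, 1))
```

By default, `GridSearchCV` refits the best configuration on the whole training split. That is why `search.predict` can be called directly after `fit`.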
\n", "