diff --git "a/visualization/.ipynb_checkpoints/ModelSelection-checkpoint.ipynb" "b/visualization/.ipynb_checkpoints/ModelSelection-checkpoint.ipynb" new file mode 100644--- /dev/null +++ "b/visualization/.ipynb_checkpoints/ModelSelection-checkpoint.ipynb" @@ -0,0 +1,1902 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Evaluate Hyperparameter Tuning Results" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Simple wide residual network has 16 hidden layers, and it is defined as WRN-N-k\n", + "
  • k is the width factor
  • \n", + "
  • N is the number of residual blocks in wide residual network
  • \n", + "
  • to convert to residual network to wide, k should be selected greater than 1.
  • " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Python libraries" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "import numpy as np\n", + "import warnings\n", + "import os\n", + "import seaborn as sns\n", + "warnings.filterwarnings(\"ignore\")\n", + "plt.rcParams.update({'font.size': 12})" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "folder =\"/Users/sefika/adversarial_examples_parseval_net/src/data/GridCV/\"\n", + "pds = []\n", + "for file in os.listdir(folder):\n", + " if file !=\".DS_Store\":\n", + " results = pd.DataFrame()\n", + " results = pd.read_csv(folder+file, index_col=False, delimiter=\";\")\n", + " pds.append(results)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "data = pd.concat(pds,ignore_index=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
    \n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
    Unnamed: 0momentumlearning ratebatch sizeloss1acc1loss2acc2loss3acc3...acc9epoch_stoppedloss10loss4loss5loss6loss7loss8loss9reg_penalty
    000.90.164.01.4739380.6369980.9790150.7469460.9722180.731239...0.727749150.00.9329181.0413250.9396651.0564631.1805501.0179231.0292650.0010
    110.90.164.00.9324590.6963350.8120010.7399650.8545540.719023...0.67364750.00.8659430.8205580.9005130.9092240.8964330.8238860.9279580.0001
    220.90.164.00.9481630.7050610.8096050.7801050.8975090.755672...0.743455100.00.8552870.8242490.9232840.8417110.9814351.0830570.8322160.0001
    330.90.164.00.9855740.7207680.9871470.7294940.9974080.741710...0.746946150.00.9297080.9021300.9013950.8717811.0644661.0050451.0270670.0001
    440.90.1128.01.1210000.5916231.3097930.5724261.3627460.504363...0.68411950.01.0877461.0153181.1715761.0520591.6722531.0469051.0320660.0100
    550.90.1128.00.9917780.7137871.4180780.5305411.0149280.699825...0.666667100.01.3151970.9877701.1135481.3148081.5417451.0268221.1344540.0100
    600.90.1128.01.0382890.7050611.1697470.6422341.2349490.631763...0.705061150.01.3337411.1032431.9059911.3753891.3818971.2389001.0349170.0100
    710.90.1128.01.4366330.6684121.4336250.6771381.4685480.656195...0.55846450.01.5384381.4066621.4228321.4930371.5755571.5865861.7376390.0010
    820.90.1128.01.2554270.6771381.1650120.7137871.1867960.685864...0.678883100.01.2370721.6476361.1663181.9153781.2822391.1658921.2122690.0010
    930.90.1128.01.0233220.7155321.2085560.6631760.9414100.727749...0.663176150.00.9873220.9872391.0003751.0149351.2707350.9892091.2098830.0010
    \n", + "

    10 rows × 27 columns

    \n", + "
    " + ], + "text/plain": [ + " Unnamed: 0 momentum learning rate batch size loss1 acc1 \\\n", + "0 0 0.9 0.1 64.0 1.473938 0.636998 \n", + "1 1 0.9 0.1 64.0 0.932459 0.696335 \n", + "2 2 0.9 0.1 64.0 0.948163 0.705061 \n", + "3 3 0.9 0.1 64.0 0.985574 0.720768 \n", + "4 4 0.9 0.1 128.0 1.121000 0.591623 \n", + "5 5 0.9 0.1 128.0 0.991778 0.713787 \n", + "6 0 0.9 0.1 128.0 1.038289 0.705061 \n", + "7 1 0.9 0.1 128.0 1.436633 0.668412 \n", + "8 2 0.9 0.1 128.0 1.255427 0.677138 \n", + "9 3 0.9 0.1 128.0 1.023322 0.715532 \n", + "\n", + " loss2 acc2 loss3 acc3 ... acc9 epoch_stopped \\\n", + "0 0.979015 0.746946 0.972218 0.731239 ... 0.727749 150.0 \n", + "1 0.812001 0.739965 0.854554 0.719023 ... 0.673647 50.0 \n", + "2 0.809605 0.780105 0.897509 0.755672 ... 0.743455 100.0 \n", + "3 0.987147 0.729494 0.997408 0.741710 ... 0.746946 150.0 \n", + "4 1.309793 0.572426 1.362746 0.504363 ... 0.684119 50.0 \n", + "5 1.418078 0.530541 1.014928 0.699825 ... 0.666667 100.0 \n", + "6 1.169747 0.642234 1.234949 0.631763 ... 0.705061 150.0 \n", + "7 1.433625 0.677138 1.468548 0.656195 ... 0.558464 50.0 \n", + "8 1.165012 0.713787 1.186796 0.685864 ... 0.678883 100.0 \n", + "9 1.208556 0.663176 0.941410 0.727749 ... 0.663176 150.0 \n", + "\n", + " loss10 loss4 loss5 loss6 loss7 loss8 loss9 \\\n", + "0 0.932918 1.041325 0.939665 1.056463 1.180550 1.017923 1.029265 \n", + "1 0.865943 0.820558 0.900513 0.909224 0.896433 0.823886 0.927958 \n", + "2 0.855287 0.824249 0.923284 0.841711 0.981435 1.083057 0.832216 \n", + "3 0.929708 0.902130 0.901395 0.871781 1.064466 1.005045 1.027067 \n", + "4 1.087746 1.015318 1.171576 1.052059 1.672253 1.046905 1.032066 \n", + "5 1.315197 0.987770 1.113548 1.314808 1.541745 1.026822 1.134454 \n", + "6 1.333741 1.103243 1.905991 1.375389 1.381897 1.238900 1.034917 \n", + "7 1.538438 1.406662 1.422832 1.493037 1.575557 1.586586 1.737639 \n", + "8 1.237072 1.647636 1.166318 1.915378 1.282239 1.165892 1.212269 \n", + "9 0.987322 0.987239 1.000375 1.014935 1.270735 0.989209 1.209883 \n", + "\n", + " reg_penalty \n", + "0 0.0010 \n", + "1 0.0001 \n", + "2 0.0001 \n", + "3 0.0001 \n", + "4 0.0100 \n", + "5 0.0100 \n", + "6 0.0100 \n", + "7 0.0010 \n", + "8 0.0010 \n", + "9 0.0010 \n", + "\n", + "[10 rows x 27 columns]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.head(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Simple Residual Network\n", + "Simple Residual Netwok, witdh factor (k) = 1" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "del data[data.columns[0]]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "RangeIndex: 40 entries, 0 to 39\n", + "Data columns (total 26 columns):\n", + " # Column Non-Null Count Dtype \n", + "--- ------ -------------- ----- \n", + " 0 momentum 40 non-null float64\n", + " 1 learning rate 40 non-null float64\n", + " 2 batch size 40 non-null float64\n", + " 3 loss1 40 non-null float64\n", + " 4 acc1 40 non-null float64\n", + " 5 loss2 40 non-null float64\n", + " 6 acc2 40 non-null float64\n", + " 7 loss3 40 non-null float64\n", + " 8 acc3 40 non-null float64\n", + " 9 widing factor 40 non-null float64\n", + " 10 acc10 40 non-null float64\n", + " 11 acc4 40 non-null float64\n", + " 12 acc5 40 non-null float64\n", + " 13 acc6 40 non-null float64\n", + " 14 acc7 40 non-null float64\n", + " 15 acc8 40 non-null float64\n", + " 16 acc9 40 non-null float64\n", + " 17 epoch_stopped 40 non-null float64\n", + " 18 loss10 40 non-null float64\n", + " 19 loss4 40 non-null float64\n", + " 20 loss5 40 non-null float64\n", + " 21 loss6 40 non-null float64\n", + " 22 loss7 40 non-null float64\n", + " 23 loss8 40 non-null float64\n", + " 24 loss9 40 non-null float64\n", + " 25 reg_penalty 40 non-null float64\n", + "dtypes: float64(26)\n", + "memory usage: 8.2 KB\n" + ] + } + ], + "source": [ + "data.drop_duplicates()\n", + "data.info()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
    \n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
    momentumlearning ratebatch sizeloss1acc1loss2acc2loss3acc3widing factor...acc9epoch_stoppedloss10loss4loss5loss6loss7loss8loss9reg_penalty
    00.90.164.01.4739380.6369980.9790150.7469460.9722180.7312391.0...0.727749150.00.9329181.0413250.9396651.0564631.1805501.0179231.0292650.0010
    10.90.164.00.9324590.6963350.8120010.7399650.8545540.7190231.0...0.67364750.00.8659430.8205580.9005130.9092240.8964330.8238860.9279580.0001
    20.90.164.00.9481630.7050610.8096050.7801050.8975090.7556721.0...0.743455100.00.8552870.8242490.9232840.8417110.9814351.0830570.8322160.0001
    30.90.164.00.9855740.7207680.9871470.7294940.9974080.7417101.0...0.746946150.00.9297080.9021300.9013950.8717811.0644661.0050451.0270670.0001
    40.90.1128.01.1210000.5916231.3097930.5724261.3627460.5043631.0...0.68411950.01.0877461.0153181.1715761.0520591.6722531.0469051.0320660.0100
    \n", + "

    5 rows × 26 columns

    \n", + "
    " + ], + "text/plain": [ + " momentum learning rate batch size loss1 acc1 loss2 \\\n", + "0 0.9 0.1 64.0 1.473938 0.636998 0.979015 \n", + "1 0.9 0.1 64.0 0.932459 0.696335 0.812001 \n", + "2 0.9 0.1 64.0 0.948163 0.705061 0.809605 \n", + "3 0.9 0.1 64.0 0.985574 0.720768 0.987147 \n", + "4 0.9 0.1 128.0 1.121000 0.591623 1.309793 \n", + "\n", + " acc2 loss3 acc3 widing factor ... acc9 epoch_stopped \\\n", + "0 0.746946 0.972218 0.731239 1.0 ... 0.727749 150.0 \n", + "1 0.739965 0.854554 0.719023 1.0 ... 0.673647 50.0 \n", + "2 0.780105 0.897509 0.755672 1.0 ... 0.743455 100.0 \n", + "3 0.729494 0.997408 0.741710 1.0 ... 0.746946 150.0 \n", + "4 0.572426 1.362746 0.504363 1.0 ... 0.684119 50.0 \n", + "\n", + " loss10 loss4 loss5 loss6 loss7 loss8 loss9 \\\n", + "0 0.932918 1.041325 0.939665 1.056463 1.180550 1.017923 1.029265 \n", + "1 0.865943 0.820558 0.900513 0.909224 0.896433 0.823886 0.927958 \n", + "2 0.855287 0.824249 0.923284 0.841711 0.981435 1.083057 0.832216 \n", + "3 0.929708 0.902130 0.901395 0.871781 1.064466 1.005045 1.027067 \n", + "4 1.087746 1.015318 1.171576 1.052059 1.672253 1.046905 1.032066 \n", + "\n", + " reg_penalty \n", + "0 0.0010 \n", + "1 0.0001 \n", + "2 0.0001 \n", + "3 0.0001 \n", + "4 0.0100 \n", + "\n", + "[5 rows x 26 columns]" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.head(5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluation of the hyperparameter tuning results" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "data[\"loss_mean\"] = (data[\"loss1\"]+data[\"loss2\"]+data[\"loss3\"]+data[\"loss4\"]+data[\"loss5\"]+data[\"loss6\"]+data[\"loss7\"]+data[\"loss8\"]+data[\"loss9\"]+data[\"loss10\"])/10" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "data[\"acc_mean\"] = (data[\"acc1\"]+data[\"acc2\"]+data[\"acc3\"]+data[\"acc4\"]+data[\"acc5\"]+data[\"acc6\"]+data[\"acc7\"]+data[\"acc8\"]+data[\"acc9\"]+data[\"acc10\"])/10" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "data['epoch'] = data['epoch_stopped']\n", + "data['weight_decay'] = data['reg_penalty']" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "column_list = [\"momentum\", \"learning rate\", \"epoch\",\"batch size\",\"weight_decay\",\"loss_mean\", \"acc_mean\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\\begin{tabular}{lrrrrrrr}\n", + "\\toprule\n", + "{} & momentum & learning rate & epoch & batch size & weight\\_decay & loss\\_mean & acc\\_mean \\\\\n", + "\\midrule\n", + "1 & 0.9 & 0.10 & 50.0 & 64.0 & 0.0001 & 0.874353 & 0.707679 \\\\\n", + "2 & 0.9 & 0.10 & 100.0 & 64.0 & 0.0001 & 0.899651 & 0.732461 \\\\\n", + "27 & 0.9 & 0.01 & 150.0 & 64.0 & 0.0001 & 0.937435 & 0.682373 \\\\\n", + "\\bottomrule\n", + "\\end{tabular}\n", + "\n" + ] + } + ], + "source": [ + "print(data.sort_values(axis=0, by=\"loss_mean\", ascending=True)[column_list].head(3).to_latex())" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "data[\"loss_na\"] = data.loc[:,[\"loss1\",\"loss2\", \"loss3\", \"loss4\",\"loss5\",\"loss6\", \"loss7\",\"loss8\", \"loss9\",\"loss10\"]].isnull().sum(1)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/html": [ + "
    \n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
    momentumlearning ratebatch sizeloss1acc1loss2acc2loss3acc3widing factor...loss6loss7loss8loss9reg_penaltyloss_meanacc_meanepochweight_decayloss_na
    00.90.164.01.4739380.6369980.9790150.7469460.9722180.7312391.0...1.0564631.1805501.0179231.0292650.00101.0623280.724782150.00.00100
    10.90.164.00.9324590.6963350.8120010.7399650.8545540.7190231.0...0.9092240.8964330.8238860.9279580.00010.8743530.70767950.00.00010
    20.90.164.00.9481630.7050610.8096050.7801050.8975090.7556721.0...0.8417110.9814351.0830570.8322160.00010.8996510.732461100.00.00010
    \n", + "

    3 rows × 31 columns

    \n", + "
    " + ], + "text/plain": [ + " momentum learning rate batch size loss1 acc1 loss2 \\\n", + "0 0.9 0.1 64.0 1.473938 0.636998 0.979015 \n", + "1 0.9 0.1 64.0 0.932459 0.696335 0.812001 \n", + "2 0.9 0.1 64.0 0.948163 0.705061 0.809605 \n", + "\n", + " acc2 loss3 acc3 widing factor ... loss6 loss7 \\\n", + "0 0.746946 0.972218 0.731239 1.0 ... 1.056463 1.180550 \n", + "1 0.739965 0.854554 0.719023 1.0 ... 0.909224 0.896433 \n", + "2 0.780105 0.897509 0.755672 1.0 ... 0.841711 0.981435 \n", + "\n", + " loss8 loss9 reg_penalty loss_mean acc_mean epoch weight_decay \\\n", + "0 1.017923 1.029265 0.0010 1.062328 0.724782 150.0 0.0010 \n", + "1 0.823886 0.927958 0.0001 0.874353 0.707679 50.0 0.0001 \n", + "2 1.083057 0.832216 0.0001 0.899651 0.732461 100.0 0.0001 \n", + "\n", + " loss_na \n", + "0 0 \n", + "1 0 \n", + "2 0 \n", + "\n", + "[3 rows x 31 columns]" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.head(3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualization" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "loss = [data[\"loss1\"][1],data[\"loss2\"][1],data[\"loss3\"][1],data[\"loss4\"][1],data[\"loss5\"][1],data[\"loss6\"][1],\n", + " data[\"loss7\"][1],data[\"loss8\"][1],data[\"loss9\"][1],data[\"loss10\"][1],data[\"loss1\"][2],\n", + " data[\"loss2\"][2],data[\"loss3\"][2],data[\"loss4\"][2],data[\"loss5\"][2],data[\"loss6\"][2],\n", + " data[\"loss7\"][2],data[\"loss8\"][2],data[\"loss9\"][2],data[\"loss10\"][2],\n", + " data[\"loss1\"][26],data[\"loss2\"][26],data[\"loss3\"][26],data[\"loss4\"][26],\n", + " data[\"loss5\"][26],data[\"loss6\"][26],data[\"loss7\"][26],data[\"loss8\"][26],\n", + " data[\"loss9\"][26],data[\"loss10\"][26]]" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "boxplot_data = {'NeuralNetwork': [\"Model1\", \"Model1\", \"Model1\",\"Model1\", \"Model1\", \"Model1\",\"Model1\", \"Model1\", \"Model1\",\"Model1\",\n", + " \"Model2\", \"Model2\", \"Model2\",\"Model2\", \"Model2\", \"Model2\",\"Model2\", \"Model2\", \"Model2\",\"Model2\",\n", + " \"Model3\", \"Model3\", \"Model3\",\"Model3\", \"Model3\", \"Model3\",\"Model3\", \"Model3\", \"Model3\",\"Model3\"],\n", + " 'Loss': loss}\n", + "boxplot_df = pd.DataFrame(data=boxplot_data)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## BoxPlot" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0.5, 1.0, 'Boxplot of the Best 3 Models')" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAY8AAAEdCAYAAAD0NOuvAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAjL0lEQVR4nO3de5xcdX3/8dc795AQruGSSIgsN4kUakO1IqDYilL9SYt4gwiKYEnx1lShGiRykQoGL5WgsVBuKt6AH9YCttWgURsNGKzLJWQhIbIRloTEbEgIJJ/+8f2OTIbZ3TnJzszuzvv5eMxj53zP93znc+bMzud8zzlzvooIzMzMihjW7ADMzGzwcfIwM7PCnDzMzKwwJw8zMyvMycPMzApz8jAzs8KcPKxpJIWk05r4+odL+qWkTZKWF1x2uaTZdQrNCpA0R9KygstcJ+m/6hVTK3DyaCH5HybKHusk/ULSic2OrRaSXpPjntpPTV4O/AE4FDiqh9ecXTSx7AhJUyu20fOSVkqaJ2nnfn6tZZLm1FDvhPw5eSon2g5Jl0ga1cdyC/I6zK0y7yN5XqEvfRs4nDxaz0+BffPjVcC9wG2S2poaVXMcBNwdEcsjoqvZwVR4K2kbTQXOytNfaFIsfwC+CLwWOASYBZwNfLaGZR8DTq+SaM4CVvRjjNZgTh6tZ3NE/D4/HgDOB0YCf1KqIGlnSV+V1JX3NBdLekOeN1rSryXdVlZ/rKTfSvpWni7tPc+Q9N+SNkp6VNKpvQUmaV9JN0tam5dZIGl6qU1S4gN4NLe/YHvbkhRAG3BRbmtOlTbOAC4G9i/rCZTXGyXpi5LWSHpC0uckDa9o44OSHszv48OSPilpRG/vQ7Ymb6PfRcSdwM3A9Iq2/0zSDyV15211i6T9y+a/RNL3co9ho6RHJH0sz1uQ1//CsnWbWi2QiPhFRNwcEb+NiBURcRvwdVIy6ct/A+uBvymL6zXAfsB3KitLOl3S/ZKelfS73MMZUTZ/tKSrc6/5aUlXA6OrtPNOSUvy+75c0pWSxvUUpKRpku7Kn5cNkh6QNKOG9WtdEeFHizyA64D/KpseBfwDsAnYv6z8O8By4ATgZaS9zs3AoXn+waQvhHPz9NeAR4Bd8vRUIIBO4FTS3uolwFZgetnrBHBafi5gEbAEeA1wOPAt4GlgT2A48P/yMkcB+wC797CetbS1D7AS+Of8fHyVdsbm+StznT/Wy+/P06TkexDwDuB54L1ly88h7V3/DfBS4ETSnvjFvWyj0nv3mrKyA4B24OqyssOAbuDTpMNuh+ftthQYk+vcDvwXcGRu93XAu/K83YFHgc+VrdvwGj9HhwIPAlf2UW8B8K/AbLb93N0AfCW/P8vKyv8a2AL8E+kz9o78Hl9cVufzwJOkntihOf4/VLRzRl5uRn7vjgV+A9zYy//Cb4Bv5Pf1AOBNwJub/T87kB9ND8CPBm7s9A/zfP7S6SZ9mXcDby+rc2D+8jqxYtl7gWvLpk8nJZ2LSInlz8vmlb4AL65o4+fATWXT5cnj9Xn6sLL5o4FVwKfy9Gtynal9rGefbeWy5cDsPtqaDSyvUr4cuL2i7E7gm/n5TsAzwBsr6rwHWNvL65Xeu2fyttmUp39EWYLL2/LmimVH5+VOytP3AXN6ea1lvc2vUv93wLM5nq/SR7LhheSxb/6MtAG75hhfwYuTx0+Bb1e08WFgI2lHZ1x+P86qqLO4op3lwN9V1Dk2x71b2ftXnjzWAWc0639zMD582Kr1LCLtiR5J+ge+CLhe0gl5/mH5708qlvsJMK00ERHXA/8fuAC4ICJ+WeW1flEx/bOy9itNA1ZHxP1lr/FsjndaD8v0pD/b6s2SiunHgb3LYhgLfC8fVuqW1E360t1F0sQ+2n4vaRv9CakHOB64XVLpf/Yo4G8q2l4NjCH1hCCdI/mEpEWSPivp2O1cz5JjSJ+ZGcCbgU/VslBErAL+AziTlDwfiIh7q1Sdxos/d3eT1qktP0aTdkLKLSw9ye/r/sCVFe/NHbnKgT2E+TngX/PhzTmSXlHLurWyWo692tCyMSLKr3BZIun1wCeBu3pZTqQ9tzQhjSd9kWwhHWKohfqYX+0Wz+qhvC/92VZPNld5zdKXe+nvKaRDSZXW9NH242Xbaamk9aQvzdeSeiHDgBtJh9UqrQaIiH+TdCfwRtIhqzsk3RoR23V5dEQ8mp+2S9oC3CTp8ojYUMPi84FrSOv9pd5epmJaZeXqoU650vv+YeDHVeb/ruqLRlws6euk9+p4UtK9PCJ8OXYP3PMwSIeydsrP2/Pfyr3UY8rmAVxNShzHA6dJemeVdl9VMf0XwAM9xNAO7Cnpjz0TSaOBPy973dKX9XB6V0tbtdpcw+v1FMMm4ICIWFblsaVge8/nv6XttJjUK+mo0vbTpYUiYlVE/FtEvIe053+qpAk7uG6QvjuGkS62qMWdpENe+5POLVTTDhxXUXYs6bDVI6TDbJuBoyvqvLr0JCKeIJ2jOqSH931TTwFGxCMRMS8i3kbqVZ1T47q1JPc8Ws8oSfvk5+NIh0ROAC4EiIgOSd8B5kn6AOmE7znAy4F3Ayj9sO8U4FURsUTSJ4CvSlpUtncKcKakB0lfdKeRksdHeojrR8AvgW9I+nvSMegLSIcsrs51VpDO05yYr+x6NiLWbWdbtXoU2EfSXwAPA89ExDN9LRQR3ZI+A3xGEsB/kv7fDgf+NCLO66OJ3fN2Gkb6wr2cdKK4dMjmM6R1vEnSF4Eu0vmSk4AvRsQjkr5MOlz0EGnd/5b0xbq+bN2OljSFdB5iTURsrQxE0izSCfKlpL3+6Tme2yNibV/vRX4/tko6HBgWEet7qHYZ8H1J5wO3kA7bzQHmRsRmYLOkrwCXSHoir9eZpBPnT5a180ngGklrgduA50gXfrwpIj5QZf3Gky47/l5+T3Yl9UDur6xrZZp90sWPxj1IJwmj7PEMaW/vH0n/1KV6E0jH5rtIe4uLgTfkeQeSrm75YFl9kY4pLyLtiU7N7c8gnTTdRDqJOaMinj+eMM/T+5IuSV1L2tu8m7Krs3Kdj5POLWwBFvSyrrW0tZy+T5iPJO0pr8nxzulpWdLJ4QUVZWeSzo1sIl0BtAg4p5fXK713pcdW4PfArcDLK+oeTjrv9HRex2Wkw0O75/lXkb7wN5IOZf0AmFa2/HTgnjy/xwsRSFeU3Q9sICWe3wKfAHbq471bAPxrL/PnUHaiO5edTuqdbs7b+VJgRNn8saTP5rr8mE9KOpXtnEQ65/YM6fO6hG0vlriOfMKclFi/QUocm0iJ6FvAfs3+nx3ID+U3z6zf5N8LPAocExEL+6huZoOQz3mYmVlhTh5mZlaYD1uZmVlh7nmYmVlhLXOp7p577hlTp05tdhhmZoPKPffc81REvOiOCC2TPKZOncrixYubHYaZ2aAiqeqt833YyszMCnPyMDOzwpw8zMysMCcPMzMrzMnDzMwKc/IwM7PCnDzMzKywlvmdh9mOmDdvHh0dHf3ebmdnJwCTJk3q97YB2tramDlzZl3attbm5GHWRBs3bmx2CGbbxcnDrAb12nufNWsWAHPnzq1L+2b14nMeZmZWmJOHmZkV5uRhZmaFOXmYmVlhTh5mZlaYk4eZmRXm5GFmZoU5eZiZWWFOHmZmVpiTh5mZFebkYWZmhTl5mJlZYU4eZmZWmJOHmZkV5luym9mQNxgH8xroA3k5eZiZbadWHszLycPMhjwP5tX/fM7DzMwKc/IwM7PCGpY8JJ0rabGkZyVd10fdj0r6vaR1kq6VNLps3gJJmyR158dDdQ/ezMy20cieRydwCXBtb5UknQCcD7wemAocAHy6otq5ETE+Pw6pQ6xmZtaLhiWPiLglIm4DVvdR9XTgmohoj4ingYuBM+ocnpmZFTAQz3lMA+4rm74P2FvSHmVll0l6StLPJL22p4YknZ0PlS3u6uqqT7RmZi1oICaP8cC6sunS853z3/NIh7ImA/OB70tqq9ZQRMyPiOkRMX3ixIn1itfMrOUMxOTRDUwomy49Xw8QEYsiYn1EPBsR1wM/A05scIxmZi1tICaPduCIsukjgCcioqdzJQGo7lGZmdkfNfJS3RGSxgDDgeGSxkiq9gv3G4AzJR0maTdgNnBdbmNXSSeUlpV0KnAscFeDVsPMzGhsz2M2sJF0Ge5p+flsSVPy7zWmAETEncDlwI+BFflxYW5jJOly3y7gKeCDwEkR4d96mJk1UMPubRURc4A5PcweX1H3SuDKKm10AUf1d2xmZlbMQDznYWZmA5yTh5mZFebkYWZmhTl5mJlZYU4eZmZWmJOHmZkV5uRhZmaFOXmYmVlhTh5mZlaYk4eZmRXm5GFmZoU5eZiZWWFOHmZmVpiTh5mZFebkYWZmhTl5mJlZYU4eZmZWmJOHmZkV5uRhZmaFOXmYmVlhTh5mZlaYk4eZmRXm5GFmZoU5eZiZWWFOHmZmVpiTh5mZFdaw5CHpXEmLJT0r6bo+6n5U0u8lrZN0raTRZfN2l3SrpA2SVkh6d92DNzOzbTSy59EJXAJc21slSScA5wOvB6YCBwCfLqtyFbAZ2Bs4Fbha0rQ6xGtmZj1oWPKIiFsi4jZgdR9VTweuiYj2iHgauBg4A0DSOOBk4IKI6I6IhcDtwIy6BW5mZi8yEM95TAPuK5u+D9hb0h7AwcCWiFhaMb9qz0PS2flQ2eKurq66BWxm1moGYvIYD6wrmy4937nKvNL8nas1FBHzI2J6REyfOHFivwdqZtaqBmLy6AYmlE2Xnq+vMq80f30D4jIzs2wgJo924Iiy6SOAJyJiNbAUGCHpoIr57Q2Mz8ys5TXyUt0RksYAw4HhksZIGlGl6g3AmZIOk7QbMBu4DiAiNgC3ABdJGifpaOCtwI0NWQkzMwMa2/OYDWwkXYZ7Wn4+W9IUSd2SpgBExJ3A5cCPgRX5cWFZOzOBscCTwDeBcyLCPQ8zswaqtudfFxExB5jTw+zxFXWvBK7soZ01wEn9GJqZmRU0EM95mJnZAOfkYWZmhTl5mJlZYU4eZmZWmJOHmZkV5uRhZmaFOXmYmVlhTh5mZlZYw34kaGbWm3nz5tHR0dHsMAopxTtr1qwmR1K7trY2Zs6cucPtOHmY2YDQ0dHBkgd/w5ZBNHpC6dDNPat/09Q4ajW8H4c1cvIwswFjy0RY9zYfTa+XXb67td/a8lYyM7PCnDzMzKwwJw8zMyvMycPMzApz8jAzs8KcPMzMrDAnDzMzK8zJw8zMCnPyMDOzwpw8zMysMCcPMzMrzMnDzMwKqzl5SHqdpJfm5/tKul7StZL2qV94ZmY2EBXpecwDtuTnc4GRQADz+zsoMzMb2Irckn1yRDwmaQRwArA/sBnorEtkZmY2YBXpefxB0t7AccD9EdGdy0fWsrCk3SXdKmmDpBWS3t1DvdGSPi+pU9LTkuZJGlk2f4GkTZK68+OhAutgZmb9oEjy+BfgV8DXgaty2dHAgzUufxWpp7I3cCpwtaRpVeqdD0wHXg4cDLwCmF1R59yIGJ8fhxRYBzMz6wc1J4+I+Czwl8DREXFzLn4ceH9fy0oaB5wMXBAR3RGxELgdmFGl+luAL0XEmojoAr4EvK/WOM3MrP4KXaobEUsjogPS1VfAPhHxvzUsejCwJSKWlpXdB1TreSg/yqdfImmXsrLLJD0l6WeSXtvTi0o6W9JiSYu7uvpx8F4zsxZX5FLduyUdnZ+fB9wMfFPSJ2pYfDywrqJsHbBzlbp3AB+WNDFfBvyhXL5T/nsecAAwmXSl1/cltVV70YiYHxHTI2L6xIkTawjTzMxqUaTn8XLgf/Lzs4DXAq8C/q6GZbuBCRVlE4D1VepeCvwaWAL8HLgNeA54EiAiFkXE+oh4NiKuB34GnFhgPczMbAcVSR7DgMh7+YqIByJiJbBbDcsuBUZIOqis7AigvbJiRGyMiHMjYnJEHACsBu6JiC2VdUuLsO1hLjMzq7Miv/NYCHwZ2Be4FSAnkqf6WjAiNki6BbhI0vuBI4G3Aq+urCtpMikhrAJeCVwAnJnn7ZrL7gaeB94BHAt8pMB6mJnZDiqSPM4AZgFdwBW57FDgizUuPxO4lnT4aTVwTkS0S5oC3A8cFhGPAW3ADcBewErg/Ij4YW5jJHBJft0tpMuET4oI/9bDAJg3bx4dHR3NDqNmpVhnzZrV5EiKaWtrY+bMmc0Ow5qo5uQREauBT1SU/aDA8muAk6qUP0Y6oV6a/gkwtYc2uoCjan1Naz0dHR083P5r9hu3udmh1GTkc+n3r5uWL2pyJLVbuWFUs0OwAaDm5JF/5T2b9NuMSaTbktwIXBoRg+M/1VrCfuM287Fpjzc7jCHrivbJzQ7BBoAih60uB/6cdHXVCtK9rS4gXTX10f4PzczMBqoiyeMU4Ih8+ArgIUn3kn7s5+RhZtZCilyq29PlsL5M1sysxRRJHt8h/Zr7BEkvk/RG0g/4vl2XyMzMbMAqctjq46QT5leRTpg/TrpFyeg6xDXk1PMS0s7ONKTKpEmT+r1tX5JpZtUUuVR3M/Cp/ABA0hhgAymxWJNs3Lix2SGYWYsp0vOoxrcGqVE9995LPzCbO3du3V7DzKzcjiYPSAnEzGyHdHZ2MvwPsMt3tzY7lCFreBd0Pts/I4f3mTwkHd/LbP/U1MysBdXS87imj/mP9UcgZtbaJk2axKrRT7HubYXGqLMCdvnuVibt0T8X1vSZPCLipf3ySmZmNmQ4xZuZWWFOHmZmVpiTh5mZFebkYWZmhTl5mJlZYU4eZmZWmJOHmZkV5uRhZmaFOXmYmVlhTh5mZlaYk4eZmRXWH7dkH1LqOeJfvZTiLY3rMRh4hEKzwc3Jo0JHRwe//s3DbB62X7NDqdnIrSMBWPTbTU2OpDajtq5sdghmtoOcPKrYPGw/Vo39WLPDGLL23XhFs0Mwsx3UsHMeknaXdKukDZJWSHp3D/VGS/q8pE5JT0uaJ2lk0XbMzKx+GnnC/CpgM7A3cCpwtaRpVeqdD0wHXg4cDLwCmL0d7ZiZWZ00JHlIGgecDFwQEd0RsRC4HZhRpfpbgC9FxJqI6AK+BLxvO9oxM7M6aVTP42BgS0QsLSu7D6jWY1B+lE+/RNIuBdsxM7M6aVTyGA+sqyhbB+xcpe4dwIclTZS0D/ChXL5TwXaQdLakxZIWd3V1bXfwZma2rUYlj25gQkXZBGB9lbqXAr8GlgA/B24DngOeLNgOETE/IqZHxPSJEydub+xmZlahUZfqLgVGSDooIh7OZUcA7ZUVI2IjcG5+IOls4J6I2CKp5nbMbPAZ3gW7fHdrs8Oo2bC16e/WXZsZRe2GdwF79E9bDUkeEbFB0i3ARZLeDxwJvBV4dWVdSZOBAFYBrwQuAM4s2o6ZDS5tbW3NDqGwjrXp7g5tewyS2Pfov/e5kT8SnAlcSzr8tBo4JyLaJU0B7gcOi4jHgDbgBmAvYCVwfkT8sK92GrcaZlYPg/F2NaVbAs2dO7fJkTRew5JHRKwBTqpS/hjpRHhp+ifA1KLtmJlZ4/iuumZmVpiTh5mZFebkYWZmhTl5mJlZYU4eZmZWmJOHmZkV5sGgKnR2djJq6wYPWFRHo7aupLNzXLPDMLMd4J6HmZkV5p5HhUmTJrFyzSYPQ1tH+268gkmTxjQ7DDPbAe55mJlZYe552JDS2dnJhg2juKJ9crNDGbJWbhjFuM7OZodhTeaeh5mZFeaehw0pkyZNYtPmlXxs2uPNDmXIuqJ9MmMmTWp2GNZk7nmYmVlhTh5mZlaYk4eZmRXm5GFmZoU5eZiZWWFOHmZmVpiTh5mZFebkYWZmhTl5mJlZYU4eZmZWmJOHmZkV5uRhZmaFOXmYmVlhTh5mZlZYw5KHpN0l3Sppg6QVkt7dQz1JukTS45LWSVogaVrZ/AWSNknqzo+HGrUOZmaWNLLncRWwGdgbOBW4ujwplDkFeB9wDLA78Avgxoo650bE+Pw4pI4xm5lZFQ1JHpLGAScDF0REd0QsBG4HZlSp/lJgYUQ8EhFbgJuAwxoRp5mZ1aZRIwkeDGyJiKVlZfcBx1WpezPwDkkHA48CpwN3VtS5TNI/Aw8Bn4yIBdVeVNLZwNkAU6ZM2aEVMLPBa968eXR0dPR7u6U2Z82a1e9tt7W1MXPmzH5vt780KnmMB9ZVlK0Ddq5SdxXwU1Ji2AKsBI4vm38ecD/pENg7ge9LOjIiXvTJiIj5wHyA6dOnxw6ug5nZNsaOHdvsEJqmUcmjG5hQUTYBWF+l7oXAUcB+wO+B04AfSZoWEc9ExKKyutdLehdwIvAv/R+2mQ0FA3kPfrBq1AnzpcAISQeVlR0BtFepewTwrYj4XUQ8HxHXAbvR83mPANSfwZqZWe8akjwiYgNwC3CRpHGSjgbeyouvogL4FXCKpL0lDZM0AxgJLJO0q6QTJI2RNELSqcCxwF2NWA8zM0saddgKYCZwLfAksBo4JyLaJU0hncM4LCIeAz4L7AUsAcYBy4CTI2KtpInAJcChpPMhDwInRYR/62Fm1kANSx4RsQY4qUr5Y6QT6qXpTcDf50dl3S7S+RAzM2si357EzMwKc/IwM7PCnDzMzKywRp4wN2uIlRtGcUX75GaHUZMnN40EYK8xzzU5ktqt3DCKg/quZkOck4cNKW1tbc0OoZDn8u0txkwdPHEfxOB7n63/OXlUMWrrSvbdeEWzw6jZyK1PAvDcsL2aHEltRm1dCXXadx1svyQu3RNp7ty5TY7ErBgnjwqDcY+qoyMd8mhrG9PkSGp10KB8n83sBU4eFQbbnit479XMGs9XW5mZWWFOHmZmVpiTh5mZFebkYWZmhTl5mJlZYU4eZmZWmJOHmZkV5uRhZmaFOXmYmVlhTh5mZlaYk4eZmRXme1s1yLx58+jIt9/ub6V2S/e46k9tbW2D8n5fZlZfTh5DwNixY5sdgpm1GEVEs2NoiOnTp8fixYubHYYNUvXqOZbarNct6t1ztB0l6Z6ImF5Z7p6HWRO512iDlZOHWQ289262LV9tZWZmhTl5mJlZYQ1LHpJ2l3SrpA2SVkh6dw/1JOkSSY9LWidpgaRpRdsxM7P6aWTP4ypgM7A3cCpwdXlSKHMK8D7gGGB34BfAjdvRjpmZ1UlDkoekccDJwAUR0R0RC4HbgRlVqr8UWBgRj0TEFuAm4LDtaMfMzOqkUT2Pg4EtEbG0rOw+oFqP4WbgQEkHSxoJnA7cuR3tIOlsSYslLe7q6trhlTAzs6RRl+qOB9ZVlK0Ddq5SdxXwU+AhYAuwEjh+O9ohIuYD8yH9SHB7AjczsxdrVM+jG5hQUTYBWF+l7oXAUcB+wBjg08CPJO1UsB0zM6uTRvU8lgIjJB0UEQ/nsiOA9ip1jwC+FRG/y9PXSfoC6bzHAwXa2cY999zzlKQVO7ISA9yewFPNDsK2i7fd4DbUt9/+1Qobdm8rSTcDAbwfOBL4D+DVEdFeUe9C4K9IJ8a7SFdUfQWYHBFra22n1UhaXO3+MzbwedsNbq26/Rp5qe5MYCzwJPBN4JyIaJc0RVK3pCm53mdJJ8GXAGuBjwInR8Ta3tpp1EqYmVkL3VV3qGvVvZ+hwNtucGvV7efbkwwd85sdgG03b7vBrSW3n3seZmZWmHseZmZWmJOHmZkV5uQxBEiaKikk9fm7HUlnSFrYiLisNt5+g1crbzsnjyaQtFzSZkl7VpQvyR/EqQ2M5WJJ/yvpeUlzGvW6g9lA2X6S9pL0TUmdefiCn0l6ZSNee7AaKNsuv+aPJXVJ+oOk+yS9tVGv3R+cPJrnUeBdpQlJh5N+v9Joy4CPAz9owmsPZgNh+40HfgX8GWn4guuBH0ga3+A4BpuBsO0APgzsGxETgLOBmyTt24Q4touTR/PcCLynbPp04IbShKRdJN2Q90xWSJotaVieN1zS5yQ9JekR4K/LG87LXiNpVR5U6xJJw6sFERHXR8Qd+P5gRTV9++VhC66MiFURsSXfCHQUcEg9VngIafq2A4iI30TE86VJYCTpnn6DgpNH8/wPMEHSy/KH6x2ksUtK/gXYBTgAOI70YX9vnncW8GbgT4HpwNsq2r4eeB44MNd5A+l2LtZ/Btz2k3QkKXks2641ah0DZttJ+ndJm4BFwAJg8Y6sWCM5eTRXaQ/or4AHgcdzeekD/U8RsT4ilgNzeWHQq7cDX4iIlRGxBris1KCkvYE3AR+JiA0R8STweeCdDVifVjNgtp+kCTmeT0dE5bAF9mIDYttFxJtJQ0qcCNwVEVv7bxXrq1F31bXqbgR+Qho98Yay8j1Je5DldwFeAUzOzyeRxjkpn1eyP6n7u0pSqWxYRX3rHwNi+0kaC3wf+J+IuKyneraNAbHtACLiOeAOSR+W1BERtxdbleZw8miiiFgh6VHSXseZZbOeAp4jfRjvz2VTeGHvaBXbHhudUvZ8JfAssGfZ8VSrg4Gw/SSNBm7LbX+g+Fq0poGw7aoYAbRtx3JN4cNWzXcmcHxEbCgr2wJ8G7hU0s6S9gf+gReOy34b+JCkl0jaDTi/tGBErAJ+CMyVNEHSMEltko6r9uKSRkoaQ/osjJA0pqcTfFZV07af0jDN3wU2Au8ZTIc8BohmbrtDJb1J0tj8P3gacCxwd13WtA6cPJosIjoiotpJsg8CG4BHgIXAN4Br87yvAXeRbl1/L3BLxbLvIXW97weeJn3B9HQJ4NdIXz7vAj6Zn8/ooa5VaPL2ezXp5O0bgLVKQxt0Szpmh1aqRTR52wmYQxpaoot02e47IuLe7V+jxvKNEc3MrDD3PMzMrDAnDzMzK8zJw8zMCnPyMDOzwpw8zMysMCcPMzMrzMnDrI4kzZF0U981Bx6l8S0ObHYcNjA5ediQozTgzxOSxpWVvV/SgiaGVT7q3A8qym9SjQNx5XX7y7oEaFaAk4cNVSNIv9qtG9Uw9GgPXiXp6H4Nph/twHpZC3HysKHqCuAfJe1aOSPfV+g/Ja2R9JCkt5fNWyDp/WXT24w7nXsOfy/pYeDhXPZFSSuVhhO9p4bbg1wOXNLTTElvVhoWda2kn0v6k1x+I+lGfN/PtyH5uKTrJc3K8yfn+Gbm6QPzOipPnyVpWS67XdKk3tarIqbX5HV8XR/rZi3CycOGqsWkwXX+sbwwH8r6T9L9ivYi3dNrnqRpBdo+CXglcFie/hVwJGko2G8A38k3m+zJVcDB1Q4/SXoF6T5KHwD2AL4K3C5pdETMAB4D3hIR4yPictKN9F6bFz+OdD+m0o34jgV+GhEh6XjS2BNvJ91raQVwcx/rVYrpBOCbwMkR8eNe1staiJOHDWWfAj4oaWJZ2ZuB5RHxbxHxfL4R3fd48YhwvbksItZExEaAiLgpIlbn9uYCo+l9KNhNwKVU732cBXw1IhbloWWvJ93m+1U9tHU3cIzSMKnHkno1pUNix/HCXVpPBa6NiHsj4lngn4C/kDS1p/XKTgHmAydGxC97WSdrMU4eNmRFxG+Bf6fsttmkcRpemQ8JrZW0lvTFuk+BprcZ3EfSLEkPSFqX29uFNKhQb74G7C3pLRXl+wOzKuLbjzQI0YtERAfQTer5HENa305Jh7Bt8phE2cBFEdENrOaFQY5etF7ZR4BvR8T/9rE+1mJ8YsyGugtJt86em6dXAndHxF/1UH8DsFPZdLWk8sdbUefzG+cBrwfaI2KrpKdJt9zuUUQ8J+nTwMVAe9mslcClEXFpT4tWKbub1HMaFRGPS7qbdGvw3YAluU4nKTGV4h5HOiz2eFk71do+BbhG0uMR8YXe1slai3seNqRFxDLgW8CHctG/k843zMiD8IyUdJSkl+X5S4C/lbRT/o3DmS9udRs7A8+TxmQYIelTwIQaw7uRdIjrjWVlXwP+TtIrlYyT9NeSds7znwAOqGjnbuBc0rCqkM71fBBYGBFbctk3gPdKOlJp9MHPAIvyGN296SQlxg+VTsSbgZOHtYaLgHEAEbGeNHjSO0lfjL8HPkv6Egf4PLCZ9CV9PfD1Ptq+C7gDWEo6LLSJGseLz1/sF5JOtJfKFpPOe3yZNJjQMuCMssUuA2bnQ1qliwHuJiWxUvJYSOo9laaJiP8GLiCd31lFGu70nTXG+RgpgZxXfiWatTYPBmVmZoW552FmZoU5eZiZWWFOHmZmVpiTh5mZFebkYWZmhTl5mJlZYU4eZmZWmJOHmZkV9n+3yteAZmnvqAAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
    " + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "sns.boxplot(x=boxplot_df[\"NeuralNetwork\"], y=boxplot_df[\"Loss\"], palette=\"bright\")\n", + "plt.title(\"Boxplot of the Best 3 Models\")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "best_model=pd.DataFrame()\n", + "best_model =data.sort_values(axis=0, by=\"loss_mean\", ascending=True)[column_list][:3]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Consequently, it can be seen that increasing the widing factor results in high loss values and loss mean on the results of hyperparameter tuning.\n", + " \n", + " However, we will train the simple residual network and simple wide residual networks on the rest of the project." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualize the parameters of the Best Model " + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
    \n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
    momentumlearning rateepochbatch sizeweight_decayloss_meanacc_mean
    10.90.1050.064.00.00010.8743530.707679
    20.90.10100.064.00.00010.8996510.732461
    260.90.01150.064.00.00010.9374350.682373
    \n", + "
    " + ], + "text/plain": [ + " momentum learning rate epoch batch size weight_decay loss_mean \\\n", + "1 0.9 0.10 50.0 64.0 0.0001 0.874353 \n", + "2 0.9 0.10 100.0 64.0 0.0001 0.899651 \n", + "26 0.9 0.01 150.0 64.0 0.0001 0.937435 \n", + "\n", + " acc_mean \n", + "1 0.707679 \n", + "2 0.732461 \n", + "26 0.682373 " + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "best_model" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAoAAAADQCAYAAACX3ND9AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAY2ElEQVR4nO3de7RkZXnn8e8Pm2tDKyiCoCggRIEMKO0Vb3hBdI1CFqNjQBAvECEToy5A4goieJlIHB2NgEMiKnhDIyhewDgDIkQRm2ijrdIBhIbm1g3SdDcooX3mj9odi7IO51R3VfU5Z38/a9Vi73e/e++nDvrw1Lsvb6oKSZIktcdGGzoASZIkjZcFoCRJUstYAEqSJLWMBaAkSVLLWABKkiS1jAWgJElSy1gAap0lWTWGc7w1yRGjPk/POQ9OsseIz7Fvkp8luS7Jx5Nkgn5/0/S5NsnLu9o/kOTmcfw7kGYT89Z6nWPSvJXk0UkuTbIqySdGGY/WjwWgNrgkj5hoW1V9sqrOGec5gYOBkSZS4EzgaGC35nNgb4cmmb8O2LPZfkZX3N8AnjniGCVNwLzVP28BvwVOAo4bcSxaTxaAGookxyf5cZJrkpzS1f61JFcnWZTk6K72VUlOTfIj4DnN+geSLExyZZLtmn7vTXJcs/y9JB9KclWSxUme37RvkeTLzbnPS/KjJPP7xHhjkvckuQJ4TZKjmpgXJvlqc5znAq8G/j7JT5Ps2nwubr7H5Umesp5/q8cB86rqh9V5E/s5dJJ3r4OAL1XV76rq18B1NEVfVV1ZVbetTxxS25m3BvpbTSlvVdXqqrqCTiGoacwCUOstyQF0fg0+E9gH2DfJC5rNb6qqfYH5wNuSPLppnwv8vKqe1SSLucCVVbU38H3gqAlON6eqngm8HTi5aTsW+E1V/RfgfcC+DxPub6vqeVX1JeD8qnpGc85fAm+uqh8AFwLHV9U+VXU9cBbwV833OA44o8/fYP8m8fZ+ftAnhh2BW7rWb2na+vW7eQr9JA3IvDWyvKUZYs6GDkCzwgHN5yfN+pZ0Euv36STPP2van9C03wWsAb7adYwHgG82y1cDL5vgXOd39XlSs/w84GMAVfXzJNc8TKzndS3vleT9wKOamL/T2znJlsBzga/kD7e7bNrbr6oupfMfkanod79fvzkZp9pP0uDMW6PJW5ohLAA1DAH+Z1X9n4c0Ji8CXgo8p6ruS/I9YLNm82+rak1X9/+oP0xMvYaJ/7f5uz59+j5AMYHVXcufAQ6uqoVJjgRe1Kf/RsA9VbXPwx00yf7AR/tsuq+qntvTdgvw+K71xwO39tn3Fjr/8Zmsn6TBmbdGk7c0Q3gJWMPwHeBNza9OkuyY5LHAI+lc4rivuf/k2SM6/xXAa5tz7wH86RT32wq4LcnGwGFd7SubbVTVvcCvk7ymOX6S7N17oKq6tLn00vvpTaI09+6tTPLsdH6eHwF8vU98FwKvS7Jpkp3pjEJcNcXvJunhmbdGk7c0Q1gAar1V1b8AXwB+mORnwD/TSUQXA3OaSxvvA64cUQhnANs253kXcA2wYgr7nQT8CPgu8Kuu9i8Bxyf5SZJd6STZNydZCCyi83DG+joG+Cc6D3ZcD1wEkOTVSU4FqKpFwJeBX9D5W/7l2tGHJKcluQXYIsktSd47hJik1jBvrZNJ81azfiPwEeDIJj+N+ulkrYP8YfRampnSeTXCxlX12ybx/T9g96p6YAOHJkl9mbe0oXkPoGaDLYBLm0siAY4xiUqa5sxb2qAcAZQkSWoZ7wGUJElqGQtASZKklpk19wAeeOCBdfHFF2/oMCS1xyDvcftP5ipJY9Y3V82aEcDly5dv6BAkaVLmKknTwawpACVJkjQ1FoCSJEktYwEoSZLUMhaAkiRJLWMBKEmS1DKz5jUwG8K+x5+zoUOQBFz990ds6BAkaUZxBFCSJKllLAAlSZJaxgJQkiSpZSwAJUmSWsaHQCRJmsZOOOEEbr/9drbffntOO+20DR2OZgkLQEmSprHbb7+dpUuXbugwNMt4CViSJKllHAGUpFnOd5bObFstX8kjgCXLV/rvcoabTu8sdQRQkiSpZRwBlCRpGvv9JnMf8k9pGCwAJUmaxlbvdsCGDkGzkJeAJUmSWsYCUJIkqWUsACVJklrGAlCSJKllLAAlSZJaZmwFYJJtklyQZHWSm5IcOkG/JHl/kqVJViT5XpI9xxWnJEnSbDfOEcDTgQeA7YDDgDMnKOxeA7wJeD6wDfBD4NxxBSlJkjTbjaUATDIXOAQ4qapWVdUVwIXA4X267wxcUVU3VNUa4HPAHuOIU5IkqQ3GNQK4O7CmqhZ3tS0E+o0Afgl4cpLdk2wMvAG4uN9BkxydZEGSBcuWLRt60JI0DOYqSdPNuArALYEVPW0rgK369L0NuBy4FrifziXhd/Q7aFWdVVXzq2r+tttuO8RwJWl4zFWSpptxFYCrgHk9bfOAlX36ngw8A3gCsBlwCnBJki1GGqEkSVJLjKsAXAzMSbJbV9vewKI+ffcGzquqW6rqwar6DLA13gcoSZI0FGMpAKtqNXA+cGqSuUn2Aw6i/9O9PwZek2S7JBslORzYGLhuHLFKkiTNdnPGeK5jgbOBO4G7gGOqalGSnYBfAHtU1RLgQ8BjgZ8Cc+kUfodU1T1jjFWSJGnWGlsBWFV3Awf3aV9C5yGRteu/Bf6y+UiSJGnInApOkiSpZSwAJUmSWsYCUJIkqWUsACVJklrGAlCSJKllLAAlSZJaxgJQkiSpZSwAJUmSWsYCUJIkqWUsACVJklrGAlCSJKllBpoLOMkmwJHAPnTN3wtQVUcMLSpJkiSNzEAFIPBZYG/gG8Adww9HkiRJozZoAXggsHNV3TOCWCRJkjQGg94DuATYdBSBSJIkaTwGHQE8B/h6ko/Rcwm4qi4ZWlSSJEkamUELwP/R/PODPe0F7LL+4UiSJGnUBioAq2rnUQUiSZKk8fA9gJIkSS0zUAGYZF6SjyS5OslNSZas/Uxh322SXJBkdbPvoQ/Td5ck30yyMsnyJKcNEqckSZImNugI4BnA04FTgW2Av6LzZPBHp7Dv6cADwHbAYcCZSfbs7dS8bPq7wCXA9sDjgc8NGKckSZImMOhDIAcAT62qu5KsqaqvJ1lA58XQExaBSeYChwB7VdUq4IokFwKHAyf2dD8SuLWqPtLVds2AcUqSJGkCg44AbgSsaJZXJXkUcBvw5En22x1YU1WLu9oWAn80Agg8G7gxyUXN5d/vJfnTfgdNcnSSBUkWLFu2bKAvIknjYq6SNN0MWgAuBF7YLF9O57LumcDiCffo2JI/FI5rrQC26tP38cDrgI8DOwDfovPuwU16O1bVWVU1v6rmb7vttlP+EpI0TuYqSdPNoAXgUcCNzfLbgPuBRwFHTLLfKmBeT9s8YGWfvvcDV1TVRVX1APBh4NHAUweMVZIkSX0M+h7AG7qWlwFvmeKui4E5SXarqn9v2vYGFvXpew2w3yBxSZIkaeoGfQ1MkhyV5JIk1zRtL0jy2ofbr6pWA+cDpyaZm2Q/4CDg3D7dPwc8O8lLkzwCeDuwHPjlILFKkiSpv0EvAZ8KvBk4C9ipabsFeNcU9j0W2By4E/gicExVLUqyU5JVSXYCqKprgdcDnwR+Q6dQfHVzOViSJEnradDXwBwJPK2qlic5s2n7NVOYB7iq7gYO7tO+hM5DIt1t59MZMZQkSdKQDToC+Ag6D3QAVPPPLbvaJEmSNM0NWgB+G/hIkk2hc08g8D46L4KWJEnSDDBoAfhOOu/mWwE8ks7I3xOZ2j2AkiRJmgYGfQ3MvcDBSbaj8xDIzVV1+0gikyRJ0kgMOgK41v3AUmCjJDsk2WGIMUmSJGmEBhoBTPJSOq+AeSKQrk1F5wERSZIkTXODjgB+Cvggnfv/Nu76/NE8vZIkSZqeBn0P4GbAp6tqzSiCkSRJ0ugNOgL4UeCE5vUvkiRJmoEGHQH8KvAd4G+SLO/eUFWTzgYiSZKkDW/QAvCfgcuBr9B5EliSJEkzzKAF4M505gL+/SiCkSRJ0ugNeg/g14EXjyIQSZIkjcegI4CbAhcmuRy4o3tDVR0xtKgkSZI0MoMWgIuajyRJkmaoQecCPmWyPklOrKq/W/eQJEmSNErrOhfww3n3CI4pSZKkIRlFAehLoiVJkqaxURSANYJjSpIkaUhGUQD2lWSbJBckWZ3kpiSHTmGfS5JUkkEfVpEkSdIERlFYTXQJ+HTgAWA7YB/gW0kWVlXfp4qTHDai+CRJklptFCOAl/c2JJkLHAKcVFWrquoK4ELg8H4HSPJI4GTghBHEJ0mS1GoDFYBJ9k+yc7P8uCSfTXJ2ku3X9qmqV/bZdXdgTVUt7mpbCOw5wak+CJwJ3D5IfJIkSZrcoCOAZwBrmuX/BWxM56GPsybZb0tgRU/bCmCr3o5J5gP7Af8wWTBJjk6yIMmCZcuWTdZdkjYIc5Wk6WbQe+x2rKolzUMZLweeSOe+vlsn2W8VMK+nbR6wsrshyUZ0isy/rqoHk4d/o0xVnUVTfM6fP9+njyVNS+YqSdPNoCOA9ybZDngh8IuqWtW0bzzJfouBOUl262rbmz+eVm4eMB84L8ntwI+b9luSPH/AWCVJktTHoCOA/0CnKNsEeHvTth/wq4fbqapWJzkfODXJW+g8BXwQ8NyeriuAHbrWnwBcBewLeN1EkiRpCAadC/hDSS6g80DH9U3zUuAtU9j9WOBs4E7gLuCYqlqUZCfgF8AeVbWErgc/kmzWLN5RVQ8OEqskSZL6G/g9e91P8ibZn04x+P0p7Hc3cHCf9iV0HhLpt8+NOLWcJEnSUA36GpjLkuzXLL8L+BLwxSTvHkVwkiRJGr5BHwLZC7iyWT4KeBHwbOCtQ4xJkiRJIzToJeCNgEqyK5Cq+iVAkq2HHpkkSZJGYtAC8ArgE8DjgAsAmmJw+ZDjkiRJ0ogMegn4SOAe4BrgvU3bU4CPDS0iSZIkjdSgr4G5C3h3T9u3hhqRJEmSRmrQp4A3TnJKkhuS/Lb55ylJNhlVgJIkSRquQe8BPA14Jp2nfm+iMxfwSXSmcHvHcEOTJEnSKAxaAL4G2Lu5FAxwbZJ/AxZiAShJkjQjDPoQyESzcjhbhyRJ0gwxaAH4FeAbSV6e5KlJDgS+1rRLkiRpBhj0EvAJwN8CpwM7AEvpTAf3viHHJUmSpBGZtABM8uKepu81nwDVtD0PuGSYgUmSJGk0pjIC+KkJ2tcWf2sLwV2GEpEkSZJGatICsKp2HkcgkiRJGo9BHwKRJEnSDGcBKEmS1DIWgJIkSS1jAShJktQyYysAk2yT5IIkq5PclOTQCfq9IcnVSe5NckuS05IM+r5CSZIkTWCcI4CnAw8A2wGHAWcm2bNPvy2AtwOPAZ4FvAQ4bkwxSpIkzXpjGVlLMhc4BNirqlYBVyS5EDgcOLG7b1Wd2bW6NMnngf3HEackSVIbjGsEcHdgTVUt7mpbCPQbAez1AmDRSKKSJElqoXEVgFsCK3raVgBbPdxOSd4IzAc+PMH2o5MsSLJg2bJlQwlUkobNXCVpuhlXAbgKmNfTNg9YOdEOSQ4G/g54RVUt79enqs6qqvlVNX/bbbcdVqySNFTmKknTzbgKwMXAnCS7dbXtzQSXdpMcCPwj8Kqq+tkY4pMkSWqNsRSAVbUaOB84NcncJPsBBwHn9vZN8mLg88AhVXXVOOKTJElqk3G+BuZYYHPgTuCLwDFVtSjJTklWJdmp6XcS8Ejg2037qiQXjTFOSZKkWW1sL1iuqruBg/u0L6HzkMjadV/5IkmSNEJOBSdJktQyFoCSJEktYwEoSZLUMhaAkiRJLWMBKEmS1DIWgJIkSS1jAShJktQyFoCSJEktYwEoSZLUMhaAkiRJLWMBKEmS1DIWgJIkSS1jAShJktQyFoCSJEktYwEoSZLUMhaAkiRJLWMBKEmS1DIWgJIkSS1jAShJktQyFoCSJEktM7YCMMk2SS5IsjrJTUkOfZi+70hye5IVSc5Osum44pQkSZrtxjkCeDrwALAdcBhwZpI9ezsleTlwIvAS4EnALsAp4wtTkiRpdhtLAZhkLnAIcFJVraqqK4ALgcP7dH8D8KmqWlRVvwHeBxw5jjglSZLaIFU1+pMkTwN+UFWbd7UdB7ywql7V03ch8MGqOq9ZfwywDHhMVd3V0/do4Ohm9U+Aa0f3LTRLPQZYvqGD0Iy0vKoOnEpHc5WGwFylddU3V80Z08m3BFb0tK0AtppC37XLWwEPKQCr6izgrCHFqBZKsqCq5m/oODS7mau0vsxVGrZx3QO4CpjX0zYPWDmFvmuX+/WVJEnSgMZVAC4G5iTZrattb2BRn76Lmm3d/e7ovfwrSZKkdTOWArCqVgPnA6cmmZtkP+Ag4Nw+3c8B3pxkjyRbA38LfGYccaqVvCwnaSYwV2moxvIQCHTeAwicDbyMzr18J1bVF5LsBPwC2KOqljR93wm8C9gc+Crw1qr63VgClSRJmuXGVgBKkiRpenAqOEmSpJaxAJQkSWoZC0DNOMOaV3qy4yR5SZJfJbkvyaVJnti1bf+mbUWSG0fyRSXNaOYqTWcWgJqJhjWv9ITHaWagOR84CdgGWACc17XvajoPNR0/xO8laXYxV2na8iEQzSjNvNK/AfaqqsVN27nA0qo6safvF4Abq+rdzfpLgM9X1faTHaeZuuvIqnpu13mXA0+rql91neOlwD9V1ZNG+sUlzSjmKk13jgBqptkdWLM2ETYWAn/0q7ppW9jTb7skj57CcR6yb/Muy+snOI8k9TJXaVqzANRMM6x5pSc7ziDnkaRe5ipNaxaAmmmGNa/0ZMcZ5DyS1MtcpWnNAlAzzbDmlZ7sOA/Zt7mvZtcJziNJvcxVmtYsADWjDGte6Skc5wJgrySHJNkMeA9wzdqbqpNs1LRv3FnNZkk2GdHXljTDmKs03VkAaiY6ls480XcCXwSOqapFSXZKsqqZX5qquhg4DbgUuKn5nDzZcZp9lwGHAB+g8wTes4DXde37AuB+4NvATs3yv4zk20qaqcxVmrZ8DYwkSVLLOAIoSZLUMhaAkiRJLWMBKEmS1DIWgJIkSS1jAShJktQyFoCSJEktYwGoaS/JJ5OcNMW+n0ny/hHFUUmePIpjS5rZzFOaaeZs6ACkyVTVW4d1rCQF7FZV1w3rmJJkntJM4wigJElSy1gAaqSSvDHJN7rWr0vy5a71m5Psk+QpSb6b5O4k1yZ5bVefh1wuSXJCktuS3JrkLX0ueWyd5FtJVib5UZJdm/2+32xf2EzD9N8nif34rvO8qWfbpkk+nGRJkjuayz+bd20/KMlPk9yb5PokB3b9PX7ZxHZDkr/o2ufnSV7Vtb5xkuVJ9pn0Dy1pnZmnzFNtZAGoUbsMeH46E5I/js6E5PsBJNkF2BL4d+C7wBeAxwJ/DpyRZM/egzUJ6p3AS4EnAy/sc84/B04BtgauozNHJlX1gmb73lW1ZVWdN1HQzXmOA14G7Nacr9uHgN2BfZo4dqQzCTtJnklncvfjgUfRmYvzxma/O4H/CswD3gh8NMnTm23nAK/vOscrgduq6qcTxSlpKMxT5qnWsQDUSFXVDcBKOgnohcB3gKVJntKsX04n0dxYVZ+uqger6t+ArwL/rc8hXwt8uqoWVdV9dBJor/Or6qqqehD4fHPuQa09z8+rajXw3rUbkgQ4CnhHVd1dVSuBD/KHCdjfDJxdVd+tqt9X1dKq+lXz9/hWVV1fHZfRmZT9+c1+nwNemWRes344cO46xC5pAOYp81Qb+RCIxuEy4EV0foFeBtxDJ6k+p1l/IvCsJPd07TOH/kllB2BB1/rNffrc3rV8H51f74PaAbi6a/2mruVtgS2Aqzs5FoAAj2iWnwB8u99Bk7wCOJnOr/KNmuP8DKCqbk3yr8AhSS4AXgH89TrELmlw5qm1ncxTrWABqHG4DHgVsDOdX6D3AIfRSayfoHPp4rKqetkUjnUb8Piu9ScMNdKHnqf72Dt1LS8H7gf2rKqlffa9Gdi1tzHJpnRGDI4Avl5V/5Hka3SS8lqfBd5C5/+bP5zg+JKGzzyFeapNvASscbgM2B/YvKpuoXM55UDg0cBPgG8Cuyc5vLmheOMkz0jy1D7H+jLwxiRPTbIFzf0sA7gD2GUK/b4MHJlkj+Y8J6/dUFW/B/6Rzn0xjwVIsmOSlzddPtXE+JLmnqIdm0tJmwCbAsuAB5tf2Qf0nPdrwNPp/KI+Z8DvJmndmafMU61iAaiRq6rFwCo6CZWquhe4AfjXqlrT3JtyAJ17U26lc2nkQ3SSUO+xLgI+DlxK58bpHzabfjfFcN4LfDbJPd1P8E1wnv8NXNKc55KeLu9q2q9Mci/wf4E/afa9iubGaWAFzeWj5nu+jU7S/g1wKHBhz3nvp/Pre2fg/Cl+J0nryTxlnmqbVNWGjkFaZ82v758DmzY3U894Sd4D7F5Vr5+0s6Rpzzyl6cgRQM04Sf4sySZJtqbzC/wbsyipbkPn6byzNnQsktadeUrTnQWgZqK/oHN/yvXAGuCYdTlIknc3L1rt/Vw0zGAHiOcoOjdmX1RV35+sv6RpzTylac1LwJIkSS3jCKAkSVLLWABKkiS1jAWgJElSy1gASpIktYwFoCRJUsv8f6hb2NbuiOFxAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
    " + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "g = sns.FacetGrid(best_model, col=\"learning rate\", sharey=True, aspect=1.5, margin_titles=True)\n", + "g.map(sns.barplot, \"learning rate\", \"loss_mean\", order = [0.01, 0.1])\n", + "plt.rcParams.update({'font.size': 10})\n", + "g.map(sns.barplot, \"weight_decay\", \"loss_mean\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAT0AAADICAYAAACaszaDAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAYqElEQVR4nO3de7gdVXnH8e+PhBSF4lNIQ0IwhEKr5aIgEQGNgjQaQSkCCggtASUELGojpGoRBAuREG5KaUxaGrUCkUvrjXuogIBAogW53xIQQkKCoFwSCOTtH2sdmOyz9zl7TvY+t/l9nmee7JlZs+adM+e8WWtm9hpFBGZmVbFeXwdgZtabnPTMrFKc9MysUpz0zKxSnPTMrFKc9MysUpz0KkrSLySd3wf7HSspJI1rYZ175DqHt6pOG7yc9KzH+lGyuRUYBTzbx3F0SdLGkr4taYmkVyQ9IunTDcp+Lf9su/2PSdIOkm6UtFLSU5JOkqTWH8HgMLSvAzBbVxHxKrC0r+PoiqT1gWuB54BPA08CWwCv1Cm7K3AUcHcT9W4MXAfcBLwXeAcwF3gJOKs10Q8ubulV21BJ50l6Lk9nSnrjd0LSYZLulPSCpGckXSppdF43FvjfXHR5bpXMzesk6cuSHs4tmiclTa/Z95aSrpP0sqT7JE3oKlBJH5T0K0kvSvqDpNslbZ/XrdXilLQ4z9dOY/P6t0manY/phdxKall3u4EjgBHA30bELyNicf73zprjfBvwQ+CzpATZnUOBtwKHR8Q9EXE5cAYw1a29+pz0qu1Q0u/AbsDRwGTgS4X1w4CTgXcDHweGAxfndb8DDsiftyN1L7+Y508Hvg5Mz+s+lcsXnQZ8O9d9J3CJpI3qBSlpKPBj4Je5/PuA84DXGxzXe3M8HdPPgAeAZTkR/BwYnY9pJ1Ir6QZJoxrUh6RZOeF2NY1ptD2wH3AL8B1JS3Oi/0ZuARbNBi6LiBu6qKtoN+DmiFhZWHYNsDkwtsk6qiUiPFVwAn4BPASosOxE4MkutnknEMAWeX6PPD+8UGYjYBUwpUEdY/M2RxeWjc7LPtBgm03y+g81WN8pjsK6fwJWAFvn+Q8DLwJvqSn3f8C0Lo59BLBNN9PQLrZ/IP9cLgR2Jv2HsRSYWShzFLAQGFY4R+d3cx6vBS6sWTYm/zx26+vfs/44+Zpetf0q8l9JdhvwTUkbR8QfJb2H1NLbkZR4OrpLY0jXpOrZFvgTYH43+y5er1qS/x1Rr2BE/D53na+RND/XfWlE1LYe1yLpE8ApwEcj4tG8eGdSd3B5Te9vA2DrRnVFxDPAM13trxvr5e2PiojXgYWSNgXOkXQC8FekFvL4SNcoy6gdNUQNlhu+kWENSNqQ1E26Hvg70h/scOBmUre34aZN7mJ1x4eIiJyAGl5uiYgjJJ0LTAT2BU6TtF9EXNMg/u1J18b+ISJuLKxaD1gGjK+z2R8b7V/SLOCwhkeTbBsRTzRY9zSwOie8DveTEvBwUjd1OHBPIRkPAT4oaQqwYUR0uulBai2OrFnW8Z/Hsm7irSQnvWp7nyQVWnu7AktyK29n0h/h1yJiEYCk/Wu272iRDCksu490R3Iv4OFWBhsRdwF3AWdIugo4nJSY15JvaPwEmBMR/16z+tfAZsCaiHisxO5PAmZ2U2ZJF+tuAT4jab2IWJOX/RXwMqn7/T/Agppt/pP0MzydN3/WtW4j/Tw2iIhVedmEHMvibuKtpr7uX3vqm4l0vegF0g2BdwAHAs8Dx+f1f066BnUW8BfAPsC9pC7THrnMaGANcGQuv1FefgbpzuMRpC7jLsAxed3YXMe4mngCOLBBrFsB3wJ2B7YE9gSeAk7M6/egcE0PuJF002NzUiuoYxpCaoneDPwW+FiuezdSN3h8G3/ebye1JL+Tf94fJV0iOLObc3R+zbLpwPzC/NtIrb1LgO2B/fN+vtzXv2P9derzADz10YlPf1CzgPNzsnsuJ7ghhTIHAY/m5HdH/kN9I+nlMl8ndd3WAHPzsvWArwCPkVoovwNOy+t6kvQ2A67Iie4V4AlgBrB+Xl+b9KLBNDav/1NSsn+yEN8l5JsdbfyZ70p6kHolsAg4lXzTootzVJv05gKLa5btQLoDvSqfi5Mp3KDytPak/EMzM6sEP6dnZpXipGdmleKkZ2aV4qRnZpXipGdmlVLph5MnTpwYV199dV+HYWat1/CbQZVu6a1YsaKvQzCzXlbppGdm1eOkZ2aV4qRnZpXipGdmlVLpu7c9sfMJ3+/rEAxYeObf93UINkC5pWdmleKkZ2aV4qRnZpXia3pmg9S0adNYunQpI0eOZMaMGX0dTr/hpGc2SC1dupSnnnqqr8Pod9y9NbNKcUvPrIGB/njSn654gSHAEyteGPDH0spHlNzSM7NKcdIzs0px99ZskFozbMO1/rXESc9skHrpLz/S1yH0S+7emlmlOOmZWaU46ZlZpTjpmVmlOOmZWaU46ZlZpTjpmVmlOOmZWaU46ZlZpfR60pN0rKRFklZJWihpfBdlvyEpGkwjcpk9Gqx/Z+8dlZkNFL36NTRJBwHnAccCv8z/XiVp24h4os4mM4FZNcsuASIinqlZvh3w+8L88tZEbWaDSW+39KYCcyNiTkTcHxHHAU8Dx9QrHBEvRsTSjglYHxgPzKlT/Jli2Yh4vW1HYWYDVq8lPUnDgJ2Ba2tWXQvs3mQ1nwWeBy6vs26BpKclzZe0Z48DNbNBrTdbesOBIcCymuXLgJHdbSxpPeBI4PsR8UphVUdL8QBgf+BBYL6kDzaoZ7KkBZIWLF/uHrBZ1ZS6pidpA+CLwF7ACGqSZkS8q4lqorbaOsvq+RjwduDfa/b5ICnRdbhN0ljgeOCmTjuPmA3MBhg3blwz+zWzQaTsjYwLgE8ClwK30lyy6rACeJ3OrboRdG791TMZuDUi7m2i7O3AwSViM7OKKJv09gM+FRHXl91RRLwqaSEwgZQ0O0yg/jW6N0jaHNgH+FyTu9uR1O01M1tL2aT3MvC7ddjf2cAPJN0B3AJMATYnP5YiaTqwS0TsVbPdkcBLwI9qK5T0JWAxcC8wDDiMlJwPWIc4zWyQKpv0ZgBTJR0TEWvK7iwi5knaFDgRGAXcA+wdEY/nIqOArYvbSBLpru0PI+LlOtUOIz3PNxpYSUp++0TElWXjM7PBr2zSm0B6Tm6ipPuA1cWVEbFvdxVExAWka4P11k2qsyyArbqobwYpGZuZdats0lsB/Hc7AjEz6w2lkl5EHNGuQMzMeoNHWTGzSik94ICkI4BDgDGkmwhviIi/aFFcZmZtUaqlJ+kE4CxgITAW+B/SHdhNgAtbHJuZWcuV7d4eBUyOiK+S7tyen+/YngVs2ergzMxarWzS2wK4I39eCWycP1+MHwY2swGgbNJbShotBeBxYLf8eRvKfQ/XzKxPlE16NwAdDyD/B3C2pP8F5gFXtDIwM7N2KHv3djI5UUbELEnPAe8nDRjw3RbHZmbWcmUfTl4DrCnMzyO18szMBoTSDydL2kHS+ZKukjQqL9tP0k6tD8/MrLXKPqf3EeBO0ogmHwbekldtDZzc2tDMzFqvbEvvm8DUiPgk8Gph+S+AXVoVlJlZu5RNetsB9cap+z3pWxlmZv1a2aT3HKlrW+s9wJPrHo6ZWXuVTXoXAWdK2oL0MPJQSR8ijVz8/VYHZ2bWamWT3onAItK3MTYC7iM9sPxL4LTWhmZm1npln9NbDRwq6SRgJ1LS/E1EPNyO4MzMWq30eHoAEfEo8GiLYzEza7ueDCL6SWBP0ku61+oeR8SnWxSXmVlblH04+SzS1852yIter5nMzPq1sjcyDgc+FRF7RsTBEXFIcWqmAknHSlokaZWkhZLGd1F2rKSoM02sKfehXNcqSY9JmlLyuMysIsomvZeBB3q6M0kHAecBp5NuhNwKXCVpTDebTiS9CLxjuqFQ51akB6ZvzXVOB74jyYOamlknZZPet4Bpknp0AwSYCsyNiDkRcX9EHAc8DRzTzXbPRsTSwlT8CtwUYElEHJfrnAN8Dzi+hzGa2SBWNnnNAT4OPCXpIdJ7Mt4QER9utKGkYcDOpAeZi64Fdu9mv1dI2gB4GDgnIi4rrNst11F0DXC4pPXzYzZmZkD5pDcLGA9cDSyj3BDxw4EhebuiZcDfNNjmRVKL7RbgNdKozfMkHR4R/5XLjASur1Pn0LzPp4srJE0mDYbKmDHd9arNbLApm/Q+DewfEdetwz5rE6XqLEsFI1aQ3rTWYYGk4cA04L+KRevUWW85ETEbmA0wbtw4v9fDrGLKXtNbATzVw32tID3WMrJm+Qg6t/66cjvwl4X5pQ3qfA14tmSMZjbIlU16JwOnStqo7I7yzYeFwISaVRNId16btSNrd1lvo3P3eAKwwNfzzKxW2e7tCcBYYJmkJ+h8I+Nd3Wx/NvADSXeQrtNNATYnXStE0nRgl4jYK88fnvfxG9K7OT4BfB74p0Kds4B/kHQu6eVE7wcmAU09N2hm1VI26V3WfZHGImKepE1Jo7WMAu4B9o6Ix3ORUaSh54tOBLYkdY0fAo4s3MQgIhZJ2hs4h/ToyxLgCxFx+brEamaDU9lRVk5pppykQ4CfRMRLdeq4ALigQf2Taua/R3rmrru4biQNZGpm1qXSb0Nr0neBzdpUt5lZj7Ur6an7ImZmva9dSc/MrF9y0jOzSnHSM7NKcdIzs0ppV9J7nJoHl83M+oNSz+lJ+nOAiFie53cADgLujYiLO8pFxPatDNLMrFXKtvR+RPoqGHm0k5uATwKzJH25xbGZmbVc2aT3LuBX+fOBwCMRsR3w98DRrQzMzKwdyia9t5AG9oQ0sslP8udfA29vVVBmZu1SNuk9DOwv6e3AR3hzmPbNgOdbGJeZWVuUTXqnAGcAi4FfRcTteflHScM/mZn1a2VHWbkiv65xc+CuwqrrAQ/lZGb9XulXOUbEMgrDu0vaBrgrIla1MjAzs3Yo1b2VdHoezRgl15EG9nxa0vvaEaCZWSuVvaZ3KPBg/vwx0vsqdgW+T3oRuJlZv1a2e7sZ8GT+vDfwo4i4Q9LvgQUtjczMrA3KtvSeJb2vAtIjKzfkz0PxwKFmNgCUbeldDlwk6SFgE+DqvHxH4JEWxmVm1hZlk95U0ggqY4BphRf/jAL+rZWBmZm1Q9nn9F4Dzqqz/JyWRWRm1kalx9OTtJmkUyVdJulSSadIGlFi+2MlLZK0StJCSeO7KLuHpB9LelrSy5LulnRknTJRZ3pn2WMzs8Gv7HN67yddu/sMsBJYRXqM5RFJuzWx/UHAecDpwE7ArcBV+Vse9ewO/JY0osv2pC70bEmfqVN2O1I3u2N6uPkjM7OqKHtNbyZwMTAlItYASFoPmEXq9u7ezfZTgbkRMSfPHydpInAM8NXawhFxes2if5O0J3AAcFHNumciYkWZgzGz6inbvd0ROKsj4QHkz2eTWm4NSRoG7MybI7N0uJbuk2XRxsBzdZYvyN3g+Tkxmpl1Ujbp/QHYqs7yreh+aKnhwBAK39vNlgEjm9m5pI8DewGzC4ufJrUUDwD2J31jZL6kDzaoY7KkBZIWLF++vJndmtkgUrZ7ewnwH5Kmka7HBfAB0lfQLu5qw4KomVedZZ3k64kXAV+IiDveqCziQd78ahzAbZLGAseThrNfe+cRs8lJc9y4cd3u18wGl7JJbxopSV3Im9/CeJV0g+Er3Wy7Anidzq26EXRu/a1F0geAK4GTIqKZ5wFvBw5uopyZVUyp7m1EvBoRXwT+jHR9b0dgk4j4x4h4tbttgYXAhJpVE0itxrpyN/Uq4JSIOLfJUHckdXvNzNbSbUtP0k+aKANAROzbTdGzgR9IugO4BZhCGpB0Vq5nOrBLROyV5/cAfg5cAPxQUkcr8fXCayi/RBrJ+V5gGHAYsB/pGp+Z2Vqa6d4+26qdRcQ8SZsCJ5KepbsH2DsiHs9FRgFbFzaZBLyVdH3u+MLyx4Gx+fMw0qM0o0nPDt4L7BMRV7YqbjMbPLpNehFxRCt3GBEXkFpu9dZNqjM/qV7ZQpkZwIzWRGdmg13pr6GZmQ1kTnpmVilOemZWKU56ZlYpTnpmVilOemZWKU56ZlYpTnpmVilOemZWKU56ZlYpTnpmVilOemZWKU56ZlYpTnpmVilOemZWKU56ZlYpTnpmVilOemZWKU56ZlYpTnpmVilOemZWKb2e9CQdK2mRpFWSFkoa3035HSTdKGmlpKcknaSOF+2+WeZDua5Vkh6TNKW9R2FmA1WvJj1JBwHnAacDOwG3AldJGtOg/MbAdcAy4L3AF4ATgKmFMlsBV+a6dgKmA9+R5Jd9m1knvd3SmwrMjYg5EXF/RBwHPA0c06D8oaSXfR8eEfdExOXAGcDUQmtvCrAkIo7Ldc4BvsfaLwc3MwN6MelJGgbsDFxbs+paYPcGm+0G3BwRKwvLrgE2B8YWytTWeQ0wTtL66xKzmQ0+vdnSGw4MIXVVi5YBIxtsM7JB+Y51XZUZmvdpZvaGoX2wz6iZV51l3ZWvXd5MmbRCmgxMzrMvSnqwi30PVsOBFX0dxLrQzMP7OoSBYsCfa+jR+b46IibWW9GbSW8F8DqdW3Uj6NxS67C0QXkK2zQq8xrwbG2FETEbmN1cyIOTpAURMa6v47D287nurNe6txHxKrAQmFCzagLpzms9twHjJW1QU34JsLhQ5m/q1LkgIlavS8xmNvj09t3bs4FJkj4n6a8lnUe6KTELQNJ0SfML5S8CXgbmStpe0v7AV4CzI6Kj6zoL2ELSubnOzwGTgJm9dExmNoD06jW9iJgnaVPgRGAUcA+wd0Q8nouMArYulP+DpAnAvwILgOeAs0jJs6PMIkl7A+eQHn1ZAnwhP95i9VW6e18xPtc19GaDycxs8PN3b82sUpz0zKxSnPQGuFYP4CBplKSLJD0g6XVJc9t+ENZjZc6/pA0kzZV0t6TVkn7Ri6H2G056A1g7BnAA/oT0TOW3gNvbFryts7Lnn/SNqFXA+cDPeyXIfsg3MgYwSbcDd0fEUYVlDwOXRcRX65Q/hjRgw2Yd32eWdCLprvcWUfPLIOlnwIqImNS+o7CeKnv+a7Y9H9g+IvZob5T9j1t6A1QbB3CwAaCH599w0hvI2jWAgw0MPTn/hpPeYNCOARxs4Ch7/ivPSW/gatcADjYw9OT8G056A1YbB3CwAaCH599w0hvo2jGAA5J2lLQjsDGwSZ7ftpeOyZpX9vwjadt8bocDGxXOdWX0xSCi1iLtGMAh+03N/CeAx/Ed3n6l7PnPrgS2LMx3nGtREX5Oz8wqxd1bM6sUJz0zqxQnPTOrFCc9M6sUJz0zqxQnPTOrFCc965E8GOXP+jqODpJC0oF9HYf1f056NliMAn7a10E0Iukbku7p6zjMSc/6MUlDi0PZdyUilkbEK+2OqVYe184GECc9awkl0yQ9mt+/8VtJh9WU+ZakB/P6xZJmFAc/6GgNSZok6VHgFWDD3HWdLOlSSS9JeqxO3W90byWNzfMHSLpO0suS7stfwStus0+OZ5WkmyQdnLcb28VxLs5xXijpeeCH3R2bpEnAycB2uf7Iy5D0NkmzJT0j6QWl95eM6+FpsCY46Vmr/AvwWeDzwLbAdOC7kvYplHkJOBL4a+BY4GDgn2vq2Qr4DPAp4N2kdzoAnAT8OC+bB1woaUu6dhrw7bzNncAlkjYCyO+RuIL0roh353IzmjzWqcADwDjga00c2zzSd5wfJHXDRwHzciv258Bo4OOk91zcBNwgaVSTsVhZEeHJU+kJmAv8LH/eEFgJjK8pcy5wZRd1TAEeKcx/A1hNeodHsVwA0wvzQ0mjxRxWU+bA/Hlsnj+6sH50XvaBPD8duJ/8/fO87Gu5zNguYl4M/LSJn0+9Y7unpsyHgReBt9Qs/z9gWl+f48E6eZQVa4VtgQ2AqyUVR7BYn8I4fbn7+SVgG2Aj0nDnQ2rqejIi6g2CeXfHh4h4TdJy3hwAtZG7C5+X5H87tnkncGfkLJM1+/a3BbULmjy2WjsDbwWW11y63IDOo6NYizjpWSt0XCb5BPBEzbrVAJJ2BS4BTgH+EXge2BeYWVP+pQb7WF0zH3R/eeaNbSIicmLp2GZdhlVfK8YSx1ZrPdIox/XeVfvHHsZm3XDSs1a4j3TTYcuIuKFBmfcDT0XENzsWNHFNrp3uB/62ZtkuPayrmWN7lc4tv18DmwFrIuKxHu7bSnLSs3UWES9ImgnMzBfnbyJ18XYl/UHPBh4CRks6lDRs/UeBQ/oqZtLowlNz3HOA7YCj87qyLcBmjm0xsKWk95Bawy8A1wO3AD+WNI10c2QkMBG4PiJuLntQ1j3fvbVW+TrpYv3xwL3AdcABwCKAiPgpcCbp5sbdpHc5nNQHcZLjeTzHty9wF6lbekpevarRdg3qaubYLieNWjwfWA4ckq8n7g3cQEq8DwI/At7Bm9cgrcU8crJZJumLwKnAn0XEmr6Ox9rD3VurLEmfJz2/t5zUFf86MNcJb3Bz0rMq24b0bN6mwJOk63yn9mlE1nbu3ppZpfhGhplVipOemVWKk56ZVYqTnplVipOemVWKk56ZVcr/A2eOPGZwJypBAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
    " + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.rcParams.update({'font.size': 14})\n", + "g = sns.FacetGrid(best_model, col = \"batch size\", sharey=True, aspect=1.5, margin_titles=True)\n", + "g.map(sns.barplot, \"learning rate\", \"loss_mean\", order = [0.01, 0.1])\n", + "plt.rcParams.update({'font.size': 14})\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAA8UAAADICAYAAADBREMvAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAgkUlEQVR4nO3debQlZXnv8e8PkKAxZl1FpAXbxiEqiIHQDqCNEGlFjC4ZInGkRSSAAb0EWWKIGm8EB0BRgwhqWhNRVEw0KpNoiwpCuvVqUMQJcAAaGjS5CA0Iz/2j6uBm9z7D7j5D71Pfz1rvOruqnnrrraZ5Tj+7qt5KVSFJkiRJUhdtMtcDkCRJkiRprlgUS5IkSZI6y6JYkiRJktRZFsWSJEmSpM6yKJYkSZIkdZZFsSRJkiSpsyyK1XlJFiWpJIvneiySNBHzlaRRZf7SxsyiWJojSfZofzn0t8f3xe2f5AdJ7mh/7juFvhcm+Y8kv02yJsl7k2w+c2cjaT5LsiDJWUl+mOTuJMvHiZs0XyU5IsnVSdYmWZVkyRSOv2OSryW5PcmvkrwpSabh1CTNc1PJX0mWjfNvsi364sxf85RFsTT3dgAW9LQfj21IsitwNvBxYKf256eTPHW8zpJsCnwR+CNgCfBi4ADg5JkZvqQO+ANgDfB24LJBAVPJV0kOBE4FTgB2Bi4Bzk2ycLwDJ3kQcCGwGngycBTweuDoDT0pSZ0waf5q3cZ9/z22oKrWjm00f81vFsWaU2kcm+Sn7Tdo/5XkZT3bx261eUmSb7TfzP0wybP7+tk9yWXt9tVJ3t17ZbQ9zt8m+XF7BeOXSU7sG84jk1yY5Lb2CsfSGT79MTdW1Q097e6eba8DvlpVb6uqK6vqbcCKdv14nk1TaL+8qr5dVRcCxwKvbpOzpPXQ5XxVVddU1VFVtRy4ZZyw1zF5vjoaWF5VZ7YxRwLXA4dPcPiXAg8ADqqqK6rqHOAdwNFebZGmxvw1af5qQ+/z77Eb+rabv+Yxi2LNtX8EXgW8BtgeOBH4YJLn9cW9E3gvzdWHC4HPJdkGoP15LvAdmm/uXkVzdbQ3CZ8A/H27bgfgL4Ff9B3jbe0x/hT4T+CTSR443sCTvDHJrZO0SW+rAVYmuT7JRUn27Nu2K3BB37rzgd0m6G9X4Mqq6j2/82m+Kd1lCuORNJj5amIT5qv2H867DIi5gMlz2ter6va+fh8OLNqA8UpdYv6a3P2TXNsW8l9IsnPPGMxf811V2Wxz0oA/BG4HlvStfw/wpfbzIqCAv+vZvgnwI+Af2+W3AT8BNumJWQbcQfPt3AOBtcBh44xj7Bh/3bNum3bdMyYY/4OBx0zS7j/B/o8DDqNJsrsCpwH3ALv3xNwJvKJvv1cAd0zQ7xnAV/rWBfgd8OK5/u9us41i63q+6uvrCzRXS/rXT5ivaP4RWL05rl3/JuCqCY53AfCRvnUL2752neu/Gzbbxt7MX/fpa7z8tStwEM2XAUuAz9DcTv3Ydrv5a563zZDmzvbAFsB5Sapn/f2Aa/piLx37UFX3JLms3R/gCcClVXVPT/w3gM1pEuUWNFdJL5pkPN/r+Xxd+3Or8YKr6hYmvg1nQlV1FXBVz6pLkywCjgEu7g3t2zUD1q3T/ZDrJU2s0/lqCFPJV9OR0zLOeknrMn9Noqoupefck1wC/F/gSJrngO8N7dvV/DVPWBRrLo3dvv984Od92+4aop+JElLx++QzmXuPWVXVPuox7iMGSd4IvHGSPp9bVV+f4vGhmQDir3qWbwC27ovZimbChvHcADy9b92WwKaT7CdpfOaryU2Wr9YAd08SM0y/TLKfpIb5a0hVdXeSlcBj21Xmr3nOolhz6Qc0t9w8sqq+Mkns04CvQDOJA/AUmltbxvp5UZJNer69fAbNrXw/pSkG7wCeRc/MztPgdOBTk8T8asg+d6KZtGHMpcBS4F0965bSzHg4nkuB45NsW1W/7NnnDmDVkOOR1DBfTW7CfFVVdyZZ1a77dF/MOZP0+44kW9TvZ4JdSnOF6ZoNHLPUBeavIbXn/iTgu2D+6oS5vn/b1u1GM/HDzcDBNLfe7ETznO2h7fZFNN8+/oLmtUKPo5kOfy2wbRuzDfBbmqT5BOB5NN/MndxznHcAvwZeCTyaJskf3neMxX1jK+CAGTz31wEvpPkWcgeaSSkK2K8nZjeaZ4GPAx7f/rwLeGpPzN8AP+xZ3hT4L5pfajsDe9H8snjfXP/3ttlGuXU5X7XH2KltFwOfbz9v37N9KvnqQJp/QB/Snv+pwK00/1gfizkRuKhn+Y/bP6NPAk8E9gP+B/jbuf47YbONSjN/TZq/3gw8B3hUu+0jbf56Sk+M+WsetzkfgK3bjeZWmyP5/beYN9HMdri03T6WQF9Kc7VhLc1zuM/t62d3mluP76C5HeXdwB/0bN8EeAPwszah/QJ4W98xZrsoPpZmworbaZ6V+Tqwz4C4A4AftuO+kp6iud3+Fpo7kHrXLaSZTOI2ml+C7+v987DZbMO3LuernmP0t2v6YibMV23METRXSMbuXumfuGb5gH53pPnH7Fqau2neDGSu/07YbKPSzF8T56/2PK5tz+tGmhmi15kIy/w1f1va/1jSRqmdeOpq4MlVtXKOhyNJ4zJfSRpV5i91ne8pliRJkiR1lkWxJEmSJKmzvH1akiRJktRZXimWJEmSJHWWRbEkSZIkqbM2m+sBzKW99967zjvvvLkehqRuyXR1ZA6TNAemJYeZvyTNgXHzV6evFK9Zs2auhyBJ680cJmlUmb8kbUw6XRRLkiRJkrrNoliSJEmS1FkWxZIkSZKkzrIoliRJkiR1Vqdnn5YkSVI37PL6j831ECS1Vr3rFXM9hPvwSrEkSZIkqbMsiiVJkiRJnWVRLEmSJEnqLItiSZIkSVJnOdGWJGlKnKRG2rhsbBPVSNKo8kqxJEmSJKmzLIolSZIkSZ1lUSxJkiRJ6iyfKZ5lPpMnbVx8Jk+SJKnbvFIsSZIkSeosi2JJkiRJUmdZFEuSJEmSOsuiWJIkSZLUWRbFkiRJkqTOsiiWJEmSJHWWRbEkSZIkqbMsiiVJkiRJnWVRLEmSJEnqLItiSZIkSVJnWRRLkiRJkjpr1oviJEckuTrJ2iSrkiyZIPYtSWqctlUbs8c42x8/e2clSZIkSRpFm83mwZIcCJwKHAF8o/15bpLtq+rnA3Y5CTi9b90ngaqqG/vW7wDc0rN80/SMWpIkSZI0X832leKjgeVVdWZVXVlVRwLXA4cPCq6qW6vqhrEG3A9YApw5IPzG3tiqunvGzkKSJEmSNC/MWlGcZHNgF+CCvk0XALtNsZtXAb8BzhmwbWWS65NclGTP9R6oJEmSJKkzZvNK8ZbApsDqvvWrga0n2znJJsDBwMeq6o6eTWNXmvcH9gOuAi5Ksvs4/RyaZGWSlTfd5B3WkkaLOUzSqDJ/SdpYDfVMcZItgNcCzwK2oq+orqonTaGb6u92wLpBngs8AvhQ3zGvoimEx1yaZBFwDHDxOgevOgM4A2Dx4sVTOa4kbTTMYZJGlflL0sZq2Im2TgP2BT4NXMLUitkxa4C7Wfeq8Fase/V4kEOBS6rq+1OIvQz4qyHGJkmSJEnqoGGL4hcCf1lVXx72QFV1Z5JVwFKaonrMUgY/I3yvJA8HngccMsXD7URzW7UkSZIkSeMatii+DfjFBhzvFOBfklwOfBM4DHg47WuXkpwIPKWqntW338HAb4FP9XeY5HXANcD3gc2Bl9EU7/tvwDglSZIkSR0wbFH8TuDoJIdX1T3DHqyqzk7yEOB4YAFwBbBPVV3bhiwAHt27T5LQzDr98aq6bUC3m9O8z3gb4Haa4vh5VfWlYccnSZIkSeqWYYvipTTvCd47yQ+Au3o3VtULJuugqk6jeTZ50LZlA9YVsN0E/b2TpliXJEmSJGkowxbFa4B/m4mBSJIkSZI024YqiqvqlTM1EEmSJEmSZtsmk4dIkiRJkjQ/DXv7NEleCbwYWEgzydW9qupR0zQuSZIkSZJm3FBXipO8HjgZWAUsAv6dZgbpBwMfmeaxSZIkSZI0o4a9ffrVwKFVdRzNzNPvb2ecPhl45HQPTpIkSZKkmTRsUbwtcHn7+XbgQe3nTwD7T9egJEmSJEmaDcMWxTcAW7afrwV2bT8/BqjpGpQkSZIkSbNh2KL4K8AL2s8fBk5J8lXgbOCz0zkwSZIkSZJm2rCzTx9KW0hX1elJfg08HTgH+OA0j02SJEmSpBk1VFFcVfcA9/Qsn01zlViSJEmSpJEz7O3TJNkxyfuTnJtkQbvuhUl2nv7hSZIkSZI0c4Z9T/Gzgf8EtgH+HLh/u+nRwJund2iSJEmSJM2sYa8U/x/g6KraF7izZ/0K4CnTNShJkiRJkmbDsEXxDsCXBqy/BXjwhg9HkiRJkqTZM2xR/GuaW6f7/Rnwyw0fjiRJkiRJs2fYovgs4F1JtgUK2CzJM4GTgI9N9+AkSZIkSZpJwxbFxwNXA9cCDwR+AHwF+AbwtukdmiRJkiRJM2vY9xTfBbw0yZuAnWmK6u9U1Y9nYnCSJEmSJM2koYriMVX1U+Cn0zwWSZIkSZJm1dBFcZJ9gT2Brei7/bqqXjRN45IkSZIkacYN9UxxkpOBs4Ed21V397Wp9HFEkquTrE2yKsmSCWIXJakBbe++uGe2fa1N8rMkhw1zXpIkSZKkbhr2SvFBwF9W1efW52BJDgROBY6gmZzrCODcJNtX1c8n2HVv4Ls9y7f09LkdzbuTPwK8DHgGcFqSm6rqnPUZpyRJkiSpG4adffo24IcbcLyjgeVVdWZVXVlVRwLXA4dPst/NVXVDT7uzZ9thwHVVdWTb55nAR4FjNmCckiRJkqQOGLYofjtwbJL1eRZ5c2AX4IK+TRcAu02y+2eT3Jjkm0kO6Nu264A+zwcWJ7nfsOOUJEmSJHXHsMXtmcBfAL9K8iPgrt6NVfXnE+y7JbApsLpv/Wpgr3H2uZXmiu83gd8BLwDOTnJQVf1rG7M18OUBfW7WHvP6iU5IkiRJktRdwxbFpwNLgPNoCs9aj2P275Px+qmqNcDJPatWJtkSOBb4197QAX0OWk+SQ4FDARYuXDj1UUvSRsAcJmlUmb8kbayGLYpfBOxXVReux7HW0MxQvXXf+q1Y9+rxRC4DXtmzfMM4ff4OuLl/56o6AzgDYPHixetT1EvSnDGHSRpV5i9JG6thnyleA/xqfQ7UTo61Cljat2kpcMkQXe3EfW+JvpR1b79eCqysqruQJEmSJGkcwxbFbwbemuSB63m8U4BlSQ5J8oQkpwIPp7ktmyQnJrloLDjJQUle0sY+LskxwGuA9/X0eTqwbZL3tHGHAMuAk9ZzjJIkSZKkjhj29unXA4uA1Ul+zroTbT1pop2r6uwkDwGOBxYAVwD7VNW1bcgC4NF9ux0PPJLm1usfAQf3TLJFVV2dZB/g3TSvdroOOMp3FEuSJEmSJjNsUfyZDT1gVZ0GnDbOtmV9yx+leefwZH1+DfizDR2bJEmSJKlbhiqKq+ofphKX5MXA56vqt+s1KkmSJEmSZsGwzxRP1QeBh81Q35IkSZIkTYuZKoozeYgkSZIkSXNrpopiSZIkSZI2ehbFkiRJkqTOsiiWJEmSJHWWRbEkSZIkqbNmqii+FrhrhvqWJEmSJGlaDPWe4iQPBaiqm9rlHYEDge9X1SfG4qrqidM5SEmSJEmSZsKwV4o/BTwfIMmWwMXAvsDpSf52mscmSZIkSdKMGrYofhLwrfbzAcBPqmoH4BXAX0/nwCRJkiRJmmnDFsX3B25tP+8FfL79/G3gEdM1KEmSJEmSZsOwRfGPgf2SPAJ4NnBBu/5hwG+mcVySJEmSJM24YYvifwDeAVwDfKuqLmvXPwf4zjSOS5IkSZKkGTfU7NNV9dkkC4GHA9/t2fRl4JzpHJgkSZIkSTNtqKIYoKpWA6vHlpM8BvhuVa2dzoFJkiRJkjTThrp9OskJSQ5qPyfJhcCPgOuTPHUmBihJkiRJ0kwZ9pnilwJXtZ+fC+wEPA34GPD26RuWJEmSJEkzb9jbpx8G/LL9vA/wqaq6PMktwMppHZkkSZIkSTNs2CvFNwOPbD8/G/hK+3kzINM1KEmSJEmSZsOwV4rPAc5K8iPgwcB57fqdgJ9M47gkSZIkSZpxwxbFRwPXAguBY6vqt+36BcAHpnNgkiRJkiTNtKFun66q31XVyVX12qr6Ts/6d1fVh6bSR5IjklydZG2SVUmWTBC7R5LPJbk+yW1Jvpfk4AExNaA9fphzkyRJkiR1z9DvKU7yMOA1wPZAAT8A/qmqbpzCvgcCpwJHAN9of56bZPuq+vmAXXYD/gt4J3A98BzgjCRrq+qsvtgdgFt6lm8a6sQkSZIkSZ0z7HuKn07z7PBLgNuBtTSvafpJkl2n0MXRwPKqOrOqrqyqI2mK3cMHBVfVCVV1fFV9s6p+VlUfAD4L7D8g/MaquqGn3T3MuUmSJEmSumfY2adPAj4B/ElVvbyqXg78CfBJ4OSJdkyyObALcEHfpgtorghP1YOAXw9Yv7K9zfqiJHsO0Z8kSZIkqaOGLYp3Ak6uqnvGVrSfTwF2nmTfLYFNgdV961cDW0/l4En+AngWcEbP6rErzfsD+wFXARcl2X0qfUqSJEmSumvYZ4r/G9iOpvDstR3wmyn2UX3LGbBuHe2t22cBR1XV5fd2VnVV33guTbIIOAa4eEA/hwKHAixcuHCKQ5akjYM5TNKoMn9J2lgNe6X4k8CHk7w0yXZJFiV5GXAmzW3VE1kD3M26V4W3Yt2rx/eR5BnAucCb2ueKJ3MZ8NhBG6rqjKpaXFWLH/rQh06hK0naeJjDJI0q85ekjdWwV4qPpbmy+5F23wB30ryj+A0T7VhVdyZZBSwFPt2zaSlwznj7tbdBfxF4S1W9Z4rj3InmtmpJkiRJksY1VFFcVXcCr01yHPBomqL4J1V12xS7OAX4lySXA98EDgMeDpwOkORE4ClV9ax2eQ+agvg04ONJxq4y311VN7UxrwOuAb4PbA68DHghg2eoliRJkiTpXpMWxUk+P4UYAKrqBRPFVdXZSR4CHA8sAK4A9qmqa9uQBTTF9phlwANong8+pmf9tcCi9vPmNLNib0PzmqjvA8+rqi9NNm5JkiRJUrdN5UrxzdN5wKo6jebK76BtywYsLxsU2xPzTuCd0zM6SZIkSVKXTFoUV9UrZ2MgkiRJkiTNtmFnn5YkSZIkad6wKJYkSZIkdZZFsSRJkiSpsyyKJUmSJEmdZVEsSZIkSeosi2JJkiRJUmdZFEuSJEmSOsuiWJIkSZLUWRbFkiRJkqTOsiiWJEmSJHWWRbEkSZIkqbMsiiVJkiRJnWVRLEmSJEnqLItiSZIkSVJnWRRLkiRJkjrLoliSJEmS1FkWxZIkSZKkzrIoliRJkiR1lkWxJEmSJKmzLIolSZIkSZ1lUSxJkiRJ6qxZL4qTHJHk6iRrk6xKsmSS+B2TfC3J7Ul+leRNSdIX88y2r7VJfpbksJk9C0mSJEnSfDCrRXGSA4FTgROAnYFLgHOTLBwn/kHAhcBq4MnAUcDrgaN7YrYDvtT2tTNwIvC+JPvP3JlIkiRJkuaD2b5SfDSwvKrOrKorq+pI4Hrg8HHiXwo8ADioqq6oqnOAdwBH91wtPgy4rqqObPs8E/gocMzMnookSZIkadTNWlGcZHNgF+CCvk0XALuNs9uuwNer6vaedecDDwcW9cT093k+sDjJ/TZkzJIkSZKk+W02rxRvCWxKcyt0r9XA1uPss/U48WPbJorZrD2mJEmSJEkDbTYHx6y+5QxYN1l8//qpxDQbkkOBQ9vFW5NcNcGxpfFsCayZ60Fow+Wkg2b7kOdV1d7ru7M5TNPEHDYPzEH+gg3IYeYvTRPz1zywseWv2SyK1wB3s+5V4a1Y90rvmBvGiadnn/Fifgfc3N9hVZ0BnDG1IUuDJVlZVYvnehzqHnOYpoM5THPB/KXpYP7STJi126er6k5gFbC0b9NSmpmjB7kUWJJki77464BremL2GtDnyqq6a0PGLEmSJEma32Z79ulTgGVJDknyhCSn0kyadTpAkhOTXNQTfxZwG7A8yROT7Ae8ATilqsZujT4d2DbJe9o+DwGWASfN0jlJkiRJkkbUrD5TXFVnJ3kIcDywALgC2Keqrm1DFgCP7on/7yRLgX8CVgK/Bk6mKa7HYq5Osg/wbppXO10HHNW+vkmaKd7+JWmUmcMkjSrzl6Zdfn/BVZIkSZKkbpnt26clSZIkSdpoWBRLkiRJkjrLoljzXpIjklydZG2SVUmWTBK/Y5KvJbk9ya+SvClJ+mKe2fa1NsnPkhw2oJ/9k/wgyR3tz337tu+e5PPtMSrJsmk5YUnzhvlL0qgyf2mUWBRrXktyIHAqcAKwM83rv85NsnCc+AcBF9K8B/vJwFHA64Gje2K2A77U9rUzcCLwviT798TsCpwNfBzYqf356SRP7TncA2kmm3stcPuGn62k+cT8JWlUmb80apxoS/NaksuA71XVq3vW/Rj4TFUdNyD+cOAdwMOq6vZ23fE0M5tvW1WV5B3AflX12J79PgTsUFW7tstnAw+uqqU9MV8GbqqqFw847q3A31TV8uk4b0mjz/wlaVSZvzRqvFKseSvJ5sAuwAV9my4Adhtnt12Br48l5Nb5NO/TXtQT09/n+cDiJPebJGa840rSvcxfkkaV+UujyKJY89mWwKY0t+L0Wg1sPc4+W48TP7ZtopjN2mNOFDPecSWpl/lL0qgyf2nkWBSrC/qfEciAdZPF969f3xifV5A0DPOXpFFl/tLIsCjWfLYGuJt1vx3cinW/RRxzwzjx9OwzXszvgJsniRnvuJLUy/wlaVSZvzRyLIo1b1XVncAqYGnfpqU0MxcOcimwJMkWffHXAdf0xOw1oM+VVXVXT8wwx5Wke5m/JI0q85dGkUWx5rtTgGVJDknyhCSn0kzacDpAkhOTXNQTfxZwG7A8yROT7Ae8ATilfj9V++nAtkne0/Z5CLAMOKmnn1OBP09yXJLHJzkO2BN4z1hAkgcm2SnJTjT/Ly5slwe+rkBS55i/JI0q85dGS1XZbPO6AUfQfMt4B803l7v3bFsOXNMXvyNwMbAWuB54M+3ry3pingl8u+3zauCwAcc9APghcCdwJc1rBHq370HzjEt/Wz7Xf2Y2m23jaOYvm802qs38ZRul5nuKJUmSJEmd5e3TkiRJkqTOsiiWJEmSJHWWRbEkSZIkqbMsiiVJkiRJnWVRLEmSJEnqLItiSZIkSVJnWRRr3kjyliRXDLnPiiTvn6kx9RzngCS+/0zSQOYvSaPK/KX5wKJY88lJNC91n1ZJrklyzHT3K0k9zF+SRpX5SyNvs7kegDRdqupW4Na5HockDcv8JWlUmb80H3ilWHMmyXOT/L8km7XLj01SST7QE/O2JBe2n7dP8sV2nxuTfCLJ1j2x97l9J8lmSd6d5Ndte3eSDyRZ0TeUTZKckGRN2+9JSTZp+1gBPBJ4Vzu2Kd2Ck+QVSa5NcluSLwAPGxDz/CSrkqxNcnV7rpv3bN+8Hde1Se5I8rMkR7XbNk3y4Xa/25P8OMmxPePePcldvX8+PX+e35vKOUgan/nL/CWNKvOX+UvrsijWXPo6sAWwuF3eA1gD7NkTswewIskC4GLgCuApwF7AA4HPjyWiAY4BlgGHAE+j+fv+kgFxLwV+B+wG/A3wOuDAdtt+wC+BtwIL2jahJE8FlgNnADsB/9Hu3xvzHODjwPuBHYCDgQOAE3rCPgq8AjgaeALwKuA37bZNgF8BL2q3/R3wRuCVAFV1MfDTdv+xY27SLn94snOQNCnzl/lLGlXmL/OX+lWVzTZnDbgMOK79/HHgzcDtNMnvAcCdwNNpktpFffv+L6CAp7TLbwGu6Nl+PfCGnuUAPwRW9KxbAVza1++FwId6lq8BjhninM4CLuxb96Hmf7d7ly8G/r4v5oU0tx8FeGx7bnsPcdy3A1/uWT4GuLJn+bnAHcBD5vq/u802H5r56z4x5i+bbYSa+es+MeYvm1eKNedW0HwbCc0kDecCl7frng7c1S7vAuye5NaxBvyi3e/R/Z0m+WNg63ZfoM2I8J8DxtB/O8t1wFbrdTaNJwCX9q3rX94F+Lu+8zkL+MN23DsD9wBfHe8gSQ5LsjLJTe3+/xtY2BPyUeBRSXZrlw8G/r2qbl7fE5N0Hyswf5m/pNG0AvOX+Uv3cqItzbUVwGuSbA/8EbCqXbcncBNwSVXd1d568kWab9/6rZ6g/6k8g3LXgH025AujTCFmE+AfgE8P2HbTZH0kORB4D82fxyXA/wCvAfYdi6mqm5J8Hjg4yVXAC4DnT2FskqZmBeavfuYvaTSswPzVz/zVYRbFmmtfB/4AOBb4RlXd3U6ucAZwI/ClNu7bNM9vXFtV/Ul0HVX130luoHn+5asASQI8GbhhyDHeCWw6RPwPaJ6h6dW//G3g8VX1k0EdJPk2TeLeEzhvQMgzgMuq6v09+6zzjS1wJvAZ4Gc0v7y+PJUTkDQl5q8BzF/SSDB/DWD+6i5vn9acqmYa/28DL+P3t6pcCjwCeCrNt5YA/wT8MXB2kqcmeVSSvZKckeSPxun+VODYJPsmeRxwMs2zMsO+xP0aYEmSbZJsOYX49wJ7JTkuzYyOr6bnG8TWW4GXJHlrkicmeXyaF8y/E6Cqfgx8CvhQkv2TbJdkSZKXt/v/CPizNDNIPjbJ3zP4HYEXAjfTPCv0z1V1z5DnLmkc5i/zlzSqzF/mL92XRbE2Bl+l+SZwBUBVrQW+RTMpweXtuutonnG5h+abu+/TJOo72jbIScC/AP/c9gfwb8DaIcf3JppfEj+lubVmQlX1LZqZCg+neV5mP5pJKHpjzgeeR/NN5OVtewPw856wV9A85/JemgkqltP8YgL4IE3SPovmOZ1FNL90+sdSNOd/v/anpOll/jJ/SaPK/GX+UivNfzOpG9rbYr5ZVUfO9VhmS5r3Dj6mqpbO9VgkrT/zl6RRZf7Sxs5nijVvJXkk8BzgazR/1w8F/rT9Oe+lmQFyF5pvPF80x8ORNATzl/lLGlXmL/PXKPL2ac1n99AkpMtpbt95GvDcqlq5IZ0mOTc9U/n3tTdOw7iny+eALwAfqaovzvVgJA3F/GX+kkaV+cv8NXK8fVoaUpJtgPuPs/mWqrplNscjSVNl/pI0qsxfmkkWxZIkSZKkzvL2aUmSJElSZ1kUS5IkSZI6y6JYkiRJktRZFsWSJEmSpM6yKJYkSZIkddb/B2IqxiQYr/SRAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
    " + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "g = sns.FacetGrid(best_model, col=\"epoch\", sharey=True, aspect=1.5, margin_titles=True)\n", + "g.map(sns.barplot, \"weight_decay\", \"loss_mean\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Increase the width factor to find the better model\n", + "Simple Wide Residual Netwok, witdh factor (k) = 2" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "data_wide_basic = pd.read_csv(\"/Users/sefika/adversarial_examples_parseval_net/src/data/GridCV/grid_16_2.csv\", sep=\";\")" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "del data_wide_basic[data_wide_basic.columns[0]]" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
    \n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
    momentumlearning ratebatch sizeloss1acc1loss2acc2loss3acc3widing factor...acc9epoch_stoppedloss10loss4loss5loss6loss7loss8loss9reg_penalty
    00.90.164.01.2118410.6684121.1873980.6963351.0484990.7190232.0...0.68062850.01.080411.0045011.228660.998551.0898620.9413141.0945910.0001
    \n", + "

    1 rows × 26 columns

    \n", + "
    " + ], + "text/plain": [ + " momentum learning rate batch size loss1 acc1 loss2 \\\n", + "0 0.9 0.1 64.0 1.211841 0.668412 1.187398 \n", + "\n", + " acc2 loss3 acc3 widing factor ... acc9 epoch_stopped \\\n", + "0 0.696335 1.048499 0.719023 2.0 ... 0.680628 50.0 \n", + "\n", + " loss10 loss4 loss5 loss6 loss7 loss8 loss9 \\\n", + "0 1.08041 1.004501 1.22866 0.99855 1.089862 0.941314 1.094591 \n", + "\n", + " reg_penalty \n", + "0 0.0001 \n", + "\n", + "[1 rows x 26 columns]" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data_wide_basic.head(1)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "RangeIndex: 1 entries, 0 to 0\n", + "Data columns (total 26 columns):\n", + " # Column Non-Null Count Dtype \n", + "--- ------ -------------- ----- \n", + " 0 momentum 1 non-null float64\n", + " 1 learning rate 1 non-null float64\n", + " 2 batch size 1 non-null float64\n", + " 3 loss1 1 non-null float64\n", + " 4 acc1 1 non-null float64\n", + " 5 loss2 1 non-null float64\n", + " 6 acc2 1 non-null float64\n", + " 7 loss3 1 non-null float64\n", + " 8 acc3 1 non-null float64\n", + " 9 widing factor 1 non-null float64\n", + " 10 acc10 1 non-null float64\n", + " 11 acc4 1 non-null float64\n", + " 12 acc5 1 non-null float64\n", + " 13 acc6 1 non-null float64\n", + " 14 acc7 1 non-null float64\n", + " 15 acc8 1 non-null float64\n", + " 16 acc9 1 non-null float64\n", + " 17 epoch_stopped 1 non-null float64\n", + " 18 loss10 1 non-null float64\n", + " 19 loss4 1 non-null float64\n", + " 20 loss5 1 non-null float64\n", + " 21 loss6 1 non-null float64\n", + " 22 loss7 1 non-null float64\n", + " 23 loss8 1 non-null float64\n", + " 24 loss9 1 non-null float64\n", + " 25 reg_penalty 1 non-null float64\n", + "dtypes: float64(26)\n", + "memory usage: 336.0 bytes\n" + ] + } + ], + "source": [ + "data_wide_basic.info()" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "data_wide_basic[\"loss_mean\"] = (data_wide_basic[\"loss1\"]+data_wide_basic[\"loss2\"]+data_wide_basic[\"loss3\"]+data_wide_basic[\"loss4\"]+data_wide_basic[\"loss5\"]+data_wide_basic[\"loss6\"]+data_wide_basic[\"loss7\"]+data_wide_basic[\"loss8\"]+data_wide_basic[\"loss9\"]+data_wide_basic[\"loss10\"])/10" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "data_wide_basic[\"acc_mean\"] = (data_wide_basic[\"acc1\"]+data_wide_basic[\"acc2\"]+data_wide_basic[\"acc3\"]+data_wide_basic[\"acc4\"]+data_wide_basic[\"acc5\"]+data_wide_basic[\"acc6\"]+data_wide_basic[\"acc7\"]+data_wide_basic[\"acc8\"]+data_wide_basic[\"acc9\"]+data_wide_basic[\"acc10\"])/10" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "data_wide_basic['epoch'] = data_wide_basic['epoch_stopped']\n", + "data_wide_basic['weight_decay'] = data_wide_basic['reg_penalty']" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/html": [ + "
    \n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
    momentumlearning rateepochbatch sizeweight_decayloss_meanacc_mean
    00.90.150.064.00.00011.0885620.706981
    \n", + "
    " + ], + "text/plain": [ + " momentum learning rate epoch batch size weight_decay loss_mean \\\n", + "0 0.9 0.1 50.0 64.0 0.0001 1.088562 \n", + "\n", + " acc_mean \n", + "0 0.706981 " + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data_wide_basic.sort_values(axis=0, by=\"loss_mean\", ascending=True)[column_list].head(3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Wide Residual Network\n", + "Wide Residual Netwok, witdh factor (k) = 4" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "data_wide_k_4 = pd.read_csv(\"/Users/sefika/adversarial_examples_parseval_net/src/data/GridCV/grid_16_4.csv\", sep=\";\")" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
    \n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
    Unnamed: 0momentumlearning ratebatch sizeloss1acc1loss2acc2loss3acc3...acc9epoch_stoppedloss10loss4loss5loss6loss7loss8loss9reg_penalty
    000.90.164.01.2735120.7260031.2092180.7504361.5702950.682373...0.76439850.01.2091791.4465821.1824341.0725321.2439281.2480491.0830730.0001
    \n", + "

    1 rows × 27 columns

    \n", + "
    " + ], + "text/plain": [ + " Unnamed: 0 momentum learning rate batch size loss1 acc1 \\\n", + "0 0 0.9 0.1 64.0 1.273512 0.726003 \n", + "\n", + " loss2 acc2 loss3 acc3 ... acc9 epoch_stopped \\\n", + "0 1.209218 0.750436 1.570295 0.682373 ... 0.764398 50.0 \n", + "\n", + " loss10 loss4 loss5 loss6 loss7 loss8 loss9 \\\n", + "0 1.209179 1.446582 1.182434 1.072532 1.243928 1.248049 1.083073 \n", + "\n", + " reg_penalty \n", + "0 0.0001 \n", + "\n", + "[1 rows x 27 columns]" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data_wide_k_4.head(5)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "del data_wide_k_4[data_wide_k_4.columns[0]]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
    \n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
    momentumlearning ratebatch sizeloss1acc1loss2acc2loss3acc3widing factor...acc9epoch_stoppedloss10loss4loss5loss6loss7loss8loss9reg_penalty
    00.90.164.01.2735120.7260031.2092180.7504361.5702950.6823734.0...0.76439850.01.2091791.4465821.1824341.0725321.2439281.2480491.0830730.0001
    \n", + "

    1 rows × 26 columns

    \n", + "
    " + ], + "text/plain": [ + " momentum learning rate batch size loss1 acc1 loss2 \\\n", + "0 0.9 0.1 64.0 1.273512 0.726003 1.209218 \n", + "\n", + " acc2 loss3 acc3 widing factor ... acc9 epoch_stopped \\\n", + "0 0.750436 1.570295 0.682373 4.0 ... 0.764398 50.0 \n", + "\n", + " loss10 loss4 loss5 loss6 loss7 loss8 loss9 \\\n", + "0 1.209179 1.446582 1.182434 1.072532 1.243928 1.248049 1.083073 \n", + "\n", + " reg_penalty \n", + "0 0.0001 \n", + "\n", + "[1 rows x 26 columns]" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data_wide_k_4.head(1)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
    \n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
    momentumlearning rateepochbatch sizeweight_decayloss_meanacc_mean
    00.90.150.064.00.00011.253880.734729
    \n", + "
    " + ], + "text/plain": [ + " momentum learning rate epoch batch size weight_decay loss_mean \\\n", + "0 0.9 0.1 50.0 64.0 0.0001 1.25388 \n", + "\n", + " acc_mean \n", + "0 0.734729 " + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data_wide_k_4[\"acc_mean\"] = (data_wide_k_4[\"acc1\"]+data_wide_k_4[\"acc2\"]+data_wide_k_4[\"acc3\"]+data_wide_k_4[\"acc4\"]+data_wide_k_4[\"acc5\"]+data_wide_k_4[\"acc6\"]+data_wide_k_4[\"acc7\"]+data_wide_k_4[\"acc8\"]+data_wide_k_4[\"acc9\"]+data_wide_k_4[\"acc10\"])/10\n", + "data_wide_k_4[\"loss_mean\"] = (data_wide_k_4[\"loss1\"]+data_wide_k_4[\"loss2\"]+data_wide_k_4[\"loss3\"]+data_wide_k_4[\"loss4\"]+data_wide_k_4[\"loss5\"]+data_wide_k_4[\"loss6\"]+data_wide_k_4[\"loss7\"]+data_wide_k_4[\"loss8\"]+data_wide_k_4[\"loss9\"]+data_wide_k_4[\"loss10\"])/10\n", + "\n", + "data_wide_k_4['epoch'] = data_wide_k_4['epoch_stopped']\n", + "data_wide_k_4['weight_decay'] = data_wide_k_4['reg_penalty']\n", + "data_wide_k_4.sort_values(axis=0, by=\"loss_mean\", ascending=True)[column_list].head(3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualization " + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [], + "source": [ + "column_list = [\"loss_mean\", \"acc_mean\", \"widing factor\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [], + "source": [ + "data_k_4 = data_wide_k_4.sort_values(axis=0, by=\"loss_mean\", ascending=True)[column_list].head(1)\n", + "data_k_2 = data_wide_basic.sort_values(axis=0, by=\"loss_mean\", ascending=True)[column_list].head(1)\n", + "data_k_1 = data.sort_values(axis=0, by=\"loss_mean\", ascending=True)[column_list].head(1)" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [], + "source": [ + "best_model=pd.DataFrame()\n", + "best_model = pd.concat([data_k_1, data_k_2, data_k_4], ignore_index=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conclusion:" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\\begin{tabular}{lrrr}\n", + "\\toprule\n", + "{} & loss\\_mean & acc\\_mean & widing factor \\\\\n", + "\\midrule\n", + "0 & 0.874353 & 0.707679 & 1.0 \\\\\n", + "1 & 1.088562 & 0.706981 & 2.0 \\\\\n", + "2 & 1.253880 & 0.734729 & 4.0 \\\\\n", + "\\bottomrule\n", + "\\end{tabular}\n", + "\n" + ] + } + ], + "source": [ + "print(best_model.sort_values(axis=0, by=\"loss_mean\", ascending=True)[column_list].head(3).to_latex())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + " \n", + "
    Sefika Efeoglu
    \n", + "
    Universiteat Potsdam
    \n", + "
    \n", + "
    " + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}