The Algorithms logo
Algorithms
About Us | Donate

CNN Pytorch

H
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wq_TKqjUmILg"
   },
   "source": [
    "# Convolutional Neural Network (CNN)\n",
    "\n",
    "## Resources\n",
    "\n",
    "- CNN: https://en.wikipedia.org/wiki/Convolutional_neural_network\n",
    "- PyTorch: https://pytorch.org/tutorials/beginner/basics/intro.html"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Start by importing both the training and testing MNIST datasets using DataLoaders and the torchvision provided datasets. You can set both the training and testing batch size to be whatever you feel is best."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "qGEvJYHnmILh"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torchvision import datasets, transforms\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data.dataloader import DataLoader\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://127.0.0.1:8080/",
     "height": 437,
     "referenced_widgets": [
      "f1ae475d1e48411ab7cb9f49f6f673c9",
      "6e770d6f5e5f4101b2c985f8182eb79e",
      "d784c3bdbaf543a299447b17b500c2a8",
      "51567646aecc4126a3b9cb96f97d5be5",
      "f65a4ec103c34c7186fb76d3e3507795",
      "af12215b0ba24c699e1111c132dee3fe",
      "908f60d17e1b49869ac3d4ee53a74a6e",
      "f6af65953a9841e1907a47fb27c01b96",
      "44f9d2dde4424058a337a1e3d585b5fe",
      "037faad73a8a40499555406f4e586731",
      "4b80b4415ce34130a048933b2df179a5",
      "c52318a206de4565a725388bc50bdfc0",
      "d054c2e9330c433ab6bb900c6fa7dac6",
      "7c8cf011ed684ef0841d41da76c9bfc8",
      "14d009d293864f9f85bd16b5e1b6b381",
      "553cf4ae8edb4b48a58ab4446036ae87",
      "a1dd0d1a7fb14235a271992bea0d233c",
      "b9fa1995789a49fdb6a0c3e8505733de",
      "d2e31ba6e42448839959f523cc56dbcb",
      "5ecbc7b6eb6544709da61283dfc8d3c6",
      "c75452bfc2264392bf3d842bfbd4eeee",
      "0554687d97424798a86ea0a4c56cdbf8",
      "2ee0648f054c49049fdf5d6ac81ec086",
      "4be171eff56046c2a401bf02b6d704c9",
      "c8c6df38201b470bb06bffc677873f93",
      "f51f06fbe28746dea09e075256e29451",
      "7c90cc0c4c3b4cd89dbdfffdb78d463d",
      "9ee4f57a58c14bb1880789196ff7ef63",
      "74d1d431fbb84c4f948ca510e397f4e2",
      "90f550d2f7344eb894208d0373ac6f8d",
      "7837b9def9dd470cb57f96b38afbfc18",
      "b3105726704240439adcf7f13bd48cca",
      "54ce0f87fc1f46d2b9f9850a393b46cf",
      "1a89067cd68c41fb96b74e0ebf3d1931",
      "91798827c9254ffc969228862c8ee37f",
      "cdbc8de66e8c4fa79100dcac1fcfed4d",
      "2773f678ca4a45158df219d2672fa646",
      "9f7585b1023d4e20ba3649efcfcdb881",
      "3d3af4241b714b9a87253ac87b2e31b1",
      "84bda6e8c38b4b669e2825c962e9dbfb",
      "780790f42e764ec5a0b8d441a290e5d3",
      "47b04fa0a67c4fbbbcdc3d1062430259",
      "480e369fb79f4326bd13d2355fddd890",
      "898cdc388a2d428ca742a9ba2365df20"
     ]
    },
    "id": "N_61-p6ymILj",
    "outputId": "765b9e09-2ce2-47ab-eac2-6978fa820170"
   },
   "outputs": [],
   "source": [
    "# Fetch the MNIST training split from torchvision (downloaded into ./data on the first run,\n",
    "# reused afterwards). ToTensor turns each PIL image into a float tensor scaled to [0, 1].\n",
    "to_tensor = transforms.ToTensor()\n",
    "dataset = datasets.MNIST(\"./data\", train=True, transform=to_tensor, download=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "id": "UlRAhZOpmILk"
   },
   "outputs": [],
   "source": [
    "# Splitting the 60,000-image MNIST training set into 50,000 training and 10,000 testing samples.\n",
    "# A seeded generator makes the split identical on every run, so Restart & Run All reproduces it.\n",
    "split_generator = torch.Generator().manual_seed(42)\n",
    "train_dataset, test_dataset = torch.utils.data.random_split(\n",
    "    dataset, [50000, 10000], generator=split_generator\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {
     "base_uri": "https://127.0.0.1:8080/",
     "height": 332
    },
    "id": "iQWzZCRFmILl",
    "outputId": "d2992c28-d667-4653-bbc2-9b323c82def0"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/lib/python3.8/site-packages/torchvision/datasets/mnist.py:52: UserWarning: train_labels has been renamed targets\n",
      "  warnings.warn(\"train_labels has been renamed targets\")\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Text(0.5, 1.0, 'Label : 3')"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAQCklEQVR4nO3dfbBU9X3H8fcHuII8OIIEgygi1PhQTUi8UUdpa2Jj1WnUTEYb2jS0NWKTaOuEOLW2HR3bTh1TJbHVGKhEND5OwZEkTBolbY1jtF6VCD6hUowIgoagQCry8O0fe2gvePfsZc/ZPcv9fV4zO7t3v+fhy3I/9+zub8/+FBGY2cA3qOoGzKw9HHazRDjsZolw2M0S4bCbJcJhN0uEw54gSf8h6YvtXteq5bDvwyStkvTbVffRX5Iuk7RS0juS1kiaLWlI1X2lwmG3dvoe8LGIOAA4DvgI8GfVtpQOh30AkjRa0vclvSnpl9ntQ/dYbIqk/5L0tqQHJI3ptf7Jkh6VtFHSzySdVkZfEfFKRGzctRtgJ/BrZWzbGnPYB6ZBwHeAw4GJwP8A/7zHMl8A/gQ4BNgO3AggaQLwA+DvgDHA14AFkj7QaKeSpkna2GCZ35f0DvAWtSP7t/v9r7JCHPYBKCJ+ERELIuJXEbEJ+Hvgt/ZY7I6IWB4RW4C/AS6QNBj4PLA4IhZHxM6IeBDoAc7ux34fiYgDGyxzV/Y0/kPALcC6vf4HWlMc9gFI0nBJ35b0anYUfRg4MAvzLq/1uv0q0AWMpfZs4PzsKfzG7Eg9DRhfZo8R8RLwLHBzmdu1+vxO6MA0CzgKOCki3pA0FXia2uvkXQ7rdXsisI3aU+vXqB31L2pDn0OAKW3Yj+Ej+0DQJWlYr8sQYBS11+kbszferupjvc9LOlbScOAa4F8jYgfwXeDTkn5H0uBsm6f18QbfXpP0RUnjstvHAn8JLCm6Xesfh33ft5hasHddrga+AexP7Uj9GPDDPta7A7gNeAMYRjYEFhGvAecCVwJvUjvSX04/flck/YakzTmLnAosk7Ql63txth9rA/nLK8zS4CO7WSIcdrNEOOxmiXDYzRLR1nH2/TQ0hjGinbs0S8q7bOG92Kq+aoXCLulM4JvAYOBfIuLavOWHMYKTdHqRXZpZjsej/scWmn4an3308ibgLOBYYHr2QQkz60BFXrOfCLwcESsj4j3gHmofxjCzDlQk7BPY/WSK1dl9u5E0U1KPpJ5tbC2wOzMrokjY+3oT4H0fx4uIORHRHRHdXQwtsDszK6JI2Fez+5lThwJrirVjZq1SJOxPAEdKOkLSfsDngEXltGVmZWt66C0itku6BPg3akNv8yLi2dI6M7NSFRpnj4hdpymaWYfzx2XNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRbZ2y2Vrk5A/XLf33OflTZF/12fty6zesyJ91d9Oyg3LreaZc83Rufee77za9bXs/H9nNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0R4nH0f8PoVp+TWF3/5urq1iUNGFtr3H5yQPw7PCc1ve9qTF+fWRyx4vPmN2/sUCrukVcAmYAewPSK6y2jKzMpXxpH9ExHxVgnbMbMW8mt2s0QUDXsAP5L0pKSZfS0gaaakHkk929hacHdm1qyiT+NPjYg1ksYBD0p6ISIe7r1ARMwB5gAcoDFRcH9m1qRCR/aIWJNdrwfuB04soykzK1/TYZc0QtKoXbeBM4DlZTVmZuUq8jT+YOB+Sbu2c1dE/LCUrmw3h89fmVtfM3P/urWJHfxJirnXz86tXzjkq7n1Ufc+VmY7A17TvwoRsRL4SIm9mFkLeejNLBEOu1kiHHazRDjsZolw2M0S0cEDM7bL9rVv5NYvnHtp3dpDX6p/
+ivA+AanwC7aMjy3fs6IX+XW8xyzX/62135qe2591L1N7zpJPrKbJcJhN0uEw26WCIfdLBEOu1kiHHazRDjsZonwOPsAcOg/PFq39p3p+d/1fOXYF3PrL2/9YP7OR+SfflvE0Tduzq3vbNmeByYf2c0S4bCbJcJhN0uEw26WCIfdLBEOu1kiHHazRHicfYBb+E+fzK3vvFS59b8e+0KZ7eyVncO6Ktv3QOQju1kiHHazRDjsZolw2M0S4bCbJcJhN0uEw26WCI+zD3AHzf1pbv2nDx2VW//697bl1i8f88pe99Rfm6/ZklsfeWbLdj0gNTyyS5onab2k5b3uGyPpQUkvZdejW9ummRXVn6fxtwF7/g29AlgSEUcCS7KfzayDNQx7RDwMbNjj7nOB+dnt+cB55bZlZmVr9g26gyNiLUB2Pa7egpJmSuqR1LONrU3uzsyKavm78RExJyK6I6K7i6Gt3p2Z1dFs2NdJGg+QXa8vryUza4Vmw74ImJHdngE8UE47ZtYqDcfZJd0NnAaMlbQauAq4FrhP0oXAz4HzW9mkNW/9Jafk1jcelz8H+qLR9zfYQ+teCW54LP8760fSuu+sH4gahj0iptcpnV5yL2bWQv64rFkiHHazRDjsZolw2M0S4bCbJcKnuO4D9PHjc+vnzf9x3doXDvhG7rrDB+3XYO/VHQ8mLdzzlIzdecrmveMju1kiHHazRDjsZolw2M0S4bCbJcJhN0uEw26WCI+z7wN+cfzI3PrvjXqpbm34oOFlt9M2L87K7/3IGbll24OP7GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZIjzOvg8YMy9/2uVTDv1a3dpPLvp67rpjB49oqqd2GH/wxqpbGFB8ZDdLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuFx9gFg4jWP1q19+uVZueu+e2Cxv/fR4Ddowazr6tamdOWfp2/lavg/LWmepPWSlve672pJr0taml3Obm2bZlZUf/6s3wac2cf9syNianZZXG5bZla2hmGPiIeB/Hl4zKzjFXnBdomkZ7Kn+aPrLSRppqQeST3b2Fpgd2ZWRLNh/xYwBZgKrAWur7dgRMyJiO6I6O5iaJO7M7Oimgp7RKyLiB0RsROYC5xYbltmVramwi5pfK8fPwMsr7esmXWGhuPsku4GTgPGSloNXAWcJmkqEMAq4OLWtWhFHHDXY/n1ojuQcstnTK5/rv0rF9ySu+6Xj/jP3Pqdx56eW9/x3Ircemoahj0ipvdx960t6MXMWsgflzVLhMNulgiH3SwRDrtZIhx2s0T4FFcrZND+++fWGw2v5dm0Y1j+Att3NL3tFPnIbpYIh90sEQ67WSIcdrNEOOxmiXDYzRLhsJslwuPsVsgLs3+9wRL1v+a6kdkLz8mtT1qRP5W17c5HdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIh90sER5n76chEw6pW3vv9sG567618LDc+ribmh+LbrUhkyfl1h86c3aDLTQ/LfPk+36ZW9/Z9JbT5CO7WSIcdrNEOOxmiXDYzRLhsJslwmE3S4TDbpaI/kzZfBhwO/BBakObcyLim5LGAPcCk6hN23xBROQPjO7D1txcf3Ljp4+5J3fdOZfUH6MH+O7rv5tbH7Fqc25959Ln6ta2f/KE3HU3HD00t/7ZP/1xbn1KV/Pj6Ed8/6Lc+tGv1P932d7rz5F9OzArIo4BTga+IulY4ApgSUQcCSzJfjazDtUw7BGxNiKeym5vAp4HJgDnAvOzxeYD57WoRzMrwV69Zpc0Cfgo8DhwcESshdofBGBc6d2ZWWn6HXZJI4EFwGUR8c5erDdTUo+knm1sbaZHMytBv8IuqYta0O+MiIXZ3eskjc/q44H1fa0bEXMiojsiurvIfzPIzFqnYdglCbgVeD4ibuhVWgTMyG7PAB4ovz0zK4siIn8BaRrwE2AZ/39W4ZXUXrffB0wEfg6cHxEb8rZ1gMbESTq9aM+V
2HrWx+vWPvy3S3PXvfGQJwrte8Hm+sN+ALe+Pq1u7abJ9+Wue0SBoTOAHZF/ouktbx9et/aDUybnb3vj2031lLLHYwnvxAb1VWs4zh4RjwB9rgzsm8k1S5A/QWeWCIfdLBEOu1kiHHazRDjsZolw2M0S0XCcvUz78jh7nhVz64/BAwxf2ZVbf/bSm8tsp62eee/d3Prlk05uUycG+ePsPrKbJcJhN0uEw26WCIfdLBEOu1kiHHazRDjsZonwlM0l+NBF+eerDxo+PLd+1MgvFdr/iOPrf43AU933Ftr2im1bcutf/eNLc+uDearQ/q08PrKbJcJhN0uEw26WCIfdLBEOu1kiHHazRDjsZonw+exmA4jPZzczh90sFQ67WSIcdrNEOOxmiXDYzRLhsJslomHYJR0m6d8lPS/pWUl/nt1/taTXJS3NLme3vl0za1Z/vrxiOzArIp6SNAp4UtKDWW12RPxj69ozs7I0DHtErAXWZrc3SXoemNDqxsysXHv1ml3SJOCjwOPZXZdIekbSPEmj66wzU1KPpJ5tbC3WrZk1rd9hlzQSWABcFhHvAN8CpgBTqR35r+9rvYiYExHdEdHdxdDiHZtZU/oVdkld1IJ+Z0QsBIiIdRGxIyJ2AnOBE1vXppkV1Z934wXcCjwfETf0un98r8U+Aywvvz0zK0t/3o0/FfhDYJmkpdl9VwLTJU0FAlgFXNyC/sysJP15N/4RoK/zYxeX346ZtYo/QWeWCIfdLBEOu1kiHHazRDjsZolw2M0S4bCbJcJhN0uEw26WCIfdLBEOu1kiHHazRDjsZolw2M0S0dYpmyW9Cbza666xwFtta2DvdGpvndoXuLdmldnb4RHxgb4KbQ37+3Yu9UREd2UN5OjU3jq1L3BvzWpXb34ab5YIh90sEVWHfU7F+8/Tqb11al/g3prVlt4qfc1uZu1T9ZHdzNrEYTdLRCVhl3SmpBclvSzpiip6qEfSKknLsmmoeyruZZ6k9ZKW97pvjKQHJb2UXfc5x15FvXXENN4504xX+thVPf1521+zSxoMrAA+BawGngCmR8RzbW2kDkmrgO6IqPwDGJJ+E9gM3B4Rx2X3XQdsiIhrsz+UoyPiLzqkt6uBzVVP453NVjS+9zTjwHnAH1HhY5fT1wW04XGr4sh+IvByRKyMiPeAe4BzK+ij40XEw8CGPe4+F5if3Z5P7Zel7er01hEiYm1EPJXd3gTsmma80scup6+2qCLsE4DXev28ms6a7z2AH0l6UtLMqpvpw8ERsRZqvzzAuIr72VPDabzbaY9pxjvmsWtm+vOiqgh7X1NJddL436kR8THgLOAr2dNV659+TePdLn1MM94Rmp3+vKgqwr4aOKzXz4cCayroo08RsSa7Xg/cT+dNRb1u1wy62fX6ivv5P500jXdf04zTAY9dldOfVxH2J4AjJR0haT/gc8CiCvp4H0kjsjdOkDQCOIPOm4p6ETAjuz0DeKDCXnbTKdN415tmnIofu8qnP4+Itl+As6m9I/8K8FdV9FCnr8nAz7LLs1X3BtxN7WndNmrPiC4EDgKWAC9l12M6qLc7gGXAM9SCNb6i3qZRe2n4DLA0u5xd9WOX01dbHjd/XNYsEf4EnVkiHHazRDjsZolw2M0S4bCbJcJhN0uEw26WiP8Fvji1zrt7lZQAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Visualizing a sample from the underlying (unsplit) MNIST dataset.\n",
    "# `targets` replaces the deprecated `train_labels` attribute (accessing it emits a UserWarning).\n",
    "sample_idx = 10\n",
    "plt.imshow(train_dataset.dataset.data[sample_idx], cmap=\"gray\")\n",
    "plt.title(\"Label : \" + str(train_dataset.dataset.targets[sample_idx].item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "id": "9WhgIZF0mILl"
   },
   "outputs": [],
   "source": [
    "# Creating a DataLoader for training and testing.\n",
    "# Training batches are reshuffled every epoch. Evaluation order does not affect accuracy, so the\n",
    "# test loader is not shuffled and uses batches of 1000 (which divides the 10,000 test samples\n",
    "# evenly) instead of running 10,000 single-image forward passes per evaluation.\n",
    "train = DataLoader(train_dataset, batch_size=32, shuffle=True)\n",
    "test = DataLoader(test_dataset, batch_size=1000, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pWhZXmm6mILm"
   },
   "source": [
    "Define a network with the following architecture:\n",
    "\n",
    "Conv2d (input channels = 1, output channels = 15, kernel size = 5)\n",
    "$\\rightarrow$\n",
    "MaxPool (kernel size = 2)\n",
    "$\\rightarrow$\n",
    "ReLU\n",
    "$\\rightarrow$\n",
    "Conv2d (input channels = 15, output channels = 30, kernel size = 5)\n",
    "$\\rightarrow$\n",
    "Dropout2d (p = 0.5)\n",
    "$\\rightarrow$\n",
    "MaxPool (kernel size = 2)\n",
    "$\\rightarrow$\n",
    "ReLU\n",
    "$\\rightarrow$\n",
    "Linear(input dimension = 480, hidden units = 64)\n",
    "$\\rightarrow$\n",
    "ReLU\n",
    "$\\rightarrow$\n",
    "Dropout (p=0.5)\n",
    "$\\rightarrow$\n",
    "Linear(input dimension = 64, output units = 10)\n",
    "$\\rightarrow$\n",
    "LogSoftmax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "_-RTFbeKmILn"
   },
   "outputs": [],
   "source": [
    "class CNN(nn.Module):\n",
    "    \"\"\"Two conv/pool blocks followed by a two-layer fully connected classifier head.\n",
    "\n",
    "    The forward pass returns per-class log-probabilities (the last layer is LogSoftmax).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Spatial sizes assume 1x28x28 MNIST inputs: 28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4\n",
    "        feature_layers = [\n",
    "            nn.Conv2d(1, 15, kernel_size=5),\n",
    "            nn.MaxPool2d(2, 2),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(15, 30, kernel_size=5),\n",
    "            nn.Dropout2d(0.5),\n",
    "            nn.MaxPool2d(2, 2),\n",
    "            nn.ReLU(),\n",
    "        ]\n",
    "        classifier_layers = [\n",
    "            nn.Flatten(),\n",
    "            nn.Linear(480, 64),  # 30 channels * 4 * 4 spatial positions = 480 features\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(0.5),\n",
    "            nn.Linear(64, 10),\n",
    "            nn.LogSoftmax(dim=1),\n",
    "        ]\n",
    "        # One flat Sequential keeps the parameter names (cnn.0.weight, ...) unchanged\n",
    "        self.cnn = nn.Sequential(*feature_layers, *classifier_layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        log_probs = self.cnn(x)\n",
    "        return log_probs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jzwFIYidmILo"
   },
   "source": [
    "Train the network you defined in the previous step on MNIST, using whichever optimizer and number of training epochs you deem appropriate. Use a cross-entropy loss. After each epoch, evaluate your model on the test dataset and print the accuracy it achieves.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://127.0.0.1:8080/"
    },
    "id": "-hdIBNSzmILp",
    "outputId": "101815c2-c1c6-441b-bd05-ce08b702c044"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  ../c10/core/TensorImpl.h:1156.)\n",
      "  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch :0 Loss : 0.44618519723548844 Train Accuracy:0.8599048256874084 Test Accuracy : 0.9340000152587891\n",
      "Epoch :1 Loss : 0.20212347584169404 Train Accuracy:0.9430182576179504 Test Accuracy : 0.9490000009536743\n",
      "Epoch :2 Loss : 0.16394136824853056 Train Accuracy:0.952215313911438 Test Accuracy : 0.9531000256538391\n",
      "Epoch :3 Loss : 0.14257464571382256 Train Accuracy:0.9595929384231567 Test Accuracy : 0.9588000178337097\n",
      "Epoch :4 Loss : 0.12598471959617277 Train Accuracy:0.9644113779067993 Test Accuracy : 0.9648000001907349\n",
      "Epoch :5 Loss : 0.11733871379403024 Train Accuracy:0.9659308791160583 Test Accuracy : 0.9660000205039978\n",
      "Epoch :6 Loss : 0.11000267015220401 Train Accuracy:0.9681901931762695 Test Accuracy : 0.9664999842643738\n",
      "Epoch :7 Loss : 0.10582590816269605 Train Accuracy:0.9684301018714905 Test Accuracy : 0.9631999731063843\n",
      "Epoch :8 Loss : 0.09624670598793283 Train Accuracy:0.9711892008781433 Test Accuracy : 0.9706000089645386\n",
      "Epoch :9 Loss : 0.09422491162643551 Train Accuracy:0.9726687669754028 Test Accuracy : 0.9679999947547913\n"
     ]
    }
   ],
   "source": [
    "batch_size = 32  # informational only: the DataLoaders above set the actual batch sizes\n",
    "\n",
    "# Seed PyTorch's global RNG so weight initialisation, dropout masks and the order of\n",
    "# shuffled training batches are repeatable across runs\n",
    "torch.manual_seed(42)\n",
    "\n",
    "model = CNN()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "# Running the model on GPU if available\n",
    "## Refer pytorch documentation for more details about copying model and data onto the device\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)\n",
    "\n",
    "cost = []\n",
    "epochs = 10\n",
    "\n",
    "# Training the model\n",
    "for epoch in range(epochs):\n",
    "\n",
    "    # Training mode: Dropout / Dropout2d are active\n",
    "    model.train()\n",
    "    loss_epoch = []\n",
    "    train_correct, train_total = 0, 0\n",
    "\n",
    "    for x, y in train:\n",
    "        x, y = x.to(device), y.to(device).long()\n",
    "\n",
    "        # The network ends in LogSoftmax, so its outputs are log-probabilities\n",
    "        y_pred = model(x)\n",
    "\n",
    "        # NLL loss on log-probabilities is the cross-entropy loss; F.cross_entropy would\n",
    "        # apply a second, redundant log_softmax on top of the model's LogSoftmax\n",
    "        loss = F.nll_loss(y_pred, y)\n",
    "\n",
    "        # Backpropagation: zero the old gradients, compute new ones, update the weights\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Appending the loss of each batch to the epoch loss\n",
    "        loss_epoch.append(loss.item())\n",
    "\n",
    "        # Predicted class = index of the largest log-probability; counting correct samples\n",
    "        # weights the epoch accuracy per sample rather than per (possibly uneven) batch\n",
    "        t_preds = y_pred.argmax(dim=1)\n",
    "        train_correct += (t_preds == y).sum().item()\n",
    "        train_total += y.size(0)\n",
    "\n",
    "    # Evaluation mode disables dropout; without it the test accuracy would be measured\n",
    "    # on a randomly thinned network instead of the full model\n",
    "    model.eval()\n",
    "    test_correct, test_total = 0, 0\n",
    "    with torch.no_grad():\n",
    "        for x, y in test:\n",
    "            x, y = x.to(device), y.to(device)\n",
    "            t_preds = model(x).argmax(dim=1)\n",
    "            test_correct += (t_preds == y).sum().item()\n",
    "            test_total += y.size(0)\n",
    "\n",
    "    epoch_loss = sum(loss_epoch) / len(loss_epoch)\n",
    "    cost.append(epoch_loss)\n",
    "\n",
    "    print(\n",
    "        \"Epoch :{} Loss : {} Train Accuracy:{} Test Accuracy : {}\".format(\n",
    "            epoch,\n",
    "            epoch_loss,\n",
    "            train_correct / train_total,\n",
    "            test_correct / test_total,\n",
    "        )\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {
     "base_uri": "https://127.0.0.1:8080/",
     "height": 337
    },
    "id": "J7ecIEspmILq",
    "outputId": "5ec70ca4-45b0-4cf3-fedc-4f57daf7a8fa"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7fa94052aa00>]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAA3AAAAEvCAYAAAAErSPcAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAxBUlEQVR4nO3deXCc933n+c+3LwAN4gYIgAd4AeIlS7IEU6clkbREKfFG8sQztuKJZ2M7smxd2ZqtjXcytbW1Lk/NTKVmLduyZUWWEydeaxzZcuSTik5akmUR1GVSFCmS4gGSIEEQIG709ds/utFsgk2iQQJ4uhvvVxUL/TzP72l8uwqi+MH3+f1+5pwTAAAAACD/+bwuAAAAAACQGwIcAAAAABQIAhwAAAAAFAgCHAAAAAAUCAIcAAAAABQIAhwAAAAAFIiA1wVkU19f75YuXep1GQAAAADgiW3btp1wzjVMPJ+XAW7p0qXq6OjwugwAAAAA8ISZHch2nkcoAQAAAKBAEOAAAAAAoEAQ4AAAAACgQBDgAAAAAKBA5BTgzOw2M9tlZnvM7CvnGfcRM4ub2Sczzu03sz+Y2VtmxsokAAAAAHCBJl2F0sz8kh6WdIukTklbzexp59y7Wcb9N0mbs7zNeufciWmoFwAAAADmrFw6cOsk7XHO7XPORSQ9IemOLOPul/QTScensT4AAAAAQEouAW6hpEMZx52pc2lmtlDSJyQ9kuV+J+kZM9tmZndfaKEAAAAAMNflspG3ZTnnJhx/XdJfO+fiZmcNv945d8TM5kv6VzN7zzm35axvkgx3d0tSS0tLDmUBAAAAwNySSweuU9LijONFko5MGNMu6Qkz2y/pk5K+bWZ3SpJz7kjq63FJTyn5SOZZnHOPOufanXPtDQ0NU/kMM240Gtc/vLpfPYNjXpcCAAAAYA7LJcBtldRmZsvMLCTp05KezhzgnFvmnFvqnFsq6UlJX3bO/czMys2sQpLMrFzSrZK2T+snmAWdvSP6v3++Q4+9/IHXpQAAAACYwyYNcM65mKT7lFxdcqekHzvndpjZPWZ2zyS3N0p62czelvS6pF86535zsUXPttb58/TxyxboB6/uV+9QxOtyAAAAAMxR5tzE6Wzea29vdx0d+bVl3O5jA7r1/92i+ze06j/eutLrcgAAAAAUMTPb5pxrn3g+p428IV3SWKE/+lCT/v6V/To1HPW6HAAAAABzEAFuCu7f0KaBsZgef4W5cAAAAABmHwFuClY3V2rT2kY9/soH6h+lCwcAAABgdhHgpuj+DW0aGI3p71/Z73UpAAAAAOYYAtwUXbqwSh9b3ajvvfyBBujCAQAAAJhFBLgL8MDGVp0aieoHvzvgdSkAAAAA5hAC3AW4bFG11q9s0GO/3aehsZjX5QAAAACYIwhwF+iBjW3qHY7qH1+jCwcAAABgdhDgLtCHW2p04yUN+rst+zQcoQsHAAAAYOYR4C7Cgxtb1TMU0Q9fO+h1KQAAAADmAALcRbhqSa1uaK3Xd7fs00gk7nU5AAAAAIocAe4iPbCxTScGx/Sj1+nCAQAAAJhZBLiLtG5Zra5ZXqtHXtqr0ShdOAAAAAAzhwA3DR7ceImOD4zpf2495HUpAAAAAIoYAW4aXLO8VuuW1uo7L+7VWIwuHAAAAICZQYCbBmamBza2qat/VP/c0el1OQAAAACKFAFumlzfWqerltToOy/uVSSW8LocAAAAAEWIADdNxrtwh/tG9JM36MIBAAAAmH4EuGl0Y1u9Ll9crYdf2KNonC4cAAAAgOlFgJtGZqYHN7aqs3dET71x2OtyAAAAABQZAtw0W79yvj60sErfemGPYnThAAAAAEwjAtw0G58Ld/DksP7lrSNelwMAAACgiBDgZsDHVs/XmuZKunAAAAAAphUBbgaMd+E+ODGkX7xz1OtyAAAAABQJAtwMuXVNo1Y1Veibz7+veMJ5XQ4AAACAIkCAmyE+n+n+
DW3a2z2kX/6BLhwAAACAi5dTgDOz28xsl5ntMbOvnGfcR8wsbmafnOq9xej2S5vUNn+evvnc+0rQhQMAAABwkSYNcGbml/SwpNslrZF0l5mtOce4/yZp81TvLVY+n+n+jW16//igfrOjy+tyAAAAABS4XDpw6yTtcc7tc85FJD0h6Y4s4+6X9BNJxy/g3qL1xx9q1oqGcn2DLhwAAACAi5RLgFso6VDGcWfqXJqZLZT0CUmPTPXeYudPzYV7r2tAz7x7zOtyAAAAABSwXAKcZTk3sZX0dUl/7ZyLX8C9yYFmd5tZh5l1dHd351BW4fj4Zc1aVp/swjlHFw4AAADAhcklwHVKWpxxvEjSkQlj2iU9YWb7JX1S0rfN7M4c75UkOecedc61O+faGxoacqu+QAT8Pt27vlXvHu3XszuPT34DAAAAAGSRS4DbKqnNzJaZWUjSpyU9nTnAObfMObfUObdU0pOSvuyc+1ku984Vd16xQC21YbpwAAAAAC7YpAHOOReTdJ+Sq0vulPRj59wOM7vHzO65kHsvvuzCE/D7dN/6Vv3h8Cm9uKu4HhEFAAAAMDssH7tB7e3trqOjw+sypl00ntD6v31R9fNK9NSXr5NZtimCAAAAAOY6M9vmnGufeD6njbwxPYKpuXBvHerTlvdPeF0OAAAAgAJDgJtlf3rlIi2oKtVDz+5mLhwAAACAKSHAzbJQwKcvrW/VGwf79OreHq/LAQAAAFBACHAe+Hfti9RUWaqHnmVFSgAAAAC5I8B5oCTg15duXqHX95/Ua/tOel0OAAAAgAJBgPPIpz6yWPMrSvSN5973uhQAAAAABYIA55HSoF9fvGmFfrevR69/QBcOAAAAwOQIcB76s3Utqp9HFw4AAABAbghwHioL+fXFG5fr5T0ntO0AXTgAAAAA50eA89hnrmlRXXlIDz23x+tSAAAAAOQ5ApzHwqGA/vLG5dqyu1tvHuz1uhwAAAAAeYwAlwf+/JolqgkH9c3n6cIBAAAAODcCXB4oLwnoCx9druffO653Ovu8LgcAAABAniLA5YnPXrtEVWVBfYO5cAAAAADOgQCXJypKg/r8Dcv07M5j2n74lNflAAAAAMhDBLg88h+uW6qK0oC++Tz7wgEAAAA4GwEuj1SVBfW565dp845j2nm03+tyAAAAAOQZAlye+dz1yzSvJKBvsSIlAAAAgAkIcHmmKhzU/3rdUv1q+1HtPjbgdTkAAAAA8ggBLg99/oZlCgf97AsHAAAA4AwEuDxUUx7SZ69bql+8c0R7jg96XQ4AAACAPEGAy1NfuGGZSgN+PfwCXTgAAAAASQS4PFU3r0SfvXaJ/uWtw9rXTRcOAAAAAAEur33ho8sVCvj08At7vS4FAAAAQB4gwOWxhooSfebqJfrZW4d1oGfI63IAAAAAeIwAl+e+eONyBXzGXDgAAAAAuQU4M7vNzHaZ2R4z+0qW63eY2Ttm9paZdZjZDRnX9pvZH8avTWfxc8H8ylLdta5FP33jsA6dHPa6HAAAAAAemjTAmZlf0sOSbpe0RtJdZrZmwrDnJF3unLtC0uckPTbh+nrn3BXOufaLL3nuueemFfKZ6dsvMhcOAAAAmMty6cCtk7THObfPOReR9ISkOzIHOOcGnXMudVguyQnTpqmqVJ/6yGI9ue2QDveNeF0OAAAAAI/kEuAWSjqUcdyZOncGM/uEmb0n6ZdKduHGOUnPmNk2M7v7Yoqdy7508wpJ0ndeZC4cAAAAMFflEuAsy7mzOmzOuaecc6sk3SnpqxmXrnfOXankI5j3mtmNWb+J2d2p+XMd3d3dOZQ1tyyoLtO/bV+sH2/t1NFTdOEAAACAuSiXANcpaXHG8SJJR8412Dm3RdIKM6tPHR9JfT0u6SklH8nMdt+jzrl251x7Q0NDjuXPLV+6aYUSzukR5sIBAAAAc1IuAW6rpDYzW2ZmIUmflvR05gAzazUzS72+UlJIUo+ZlZtZRep8uaRbJW2fzg8w
lyyuDeuTVy3Sj7Ye0rH+Ua/LAQAAADDLJg1wzrmYpPskbZa0U9KPnXM7zOweM7snNexPJW03s7eUXLHyU6lFTRolvWxmb0t6XdIvnXO/mYHPMWd8+eZWxRNO331pn9elAAAAAJhldnrxyPzR3t7uOjrYMu5c/vd/fls/f/uIfvvX6zW/otTrcgAAAABMMzPblm0btpw28kZ+uW99q6LxhP5uC104AAAAYC4hwBWgpfXluvOKhfqn1w7qxOCY1+UAAAAAmCUEuAJ174ZWjcXi+rvf0oUDAAAA5goCXIFa0TBP/8vlC/SPvzugk0MRr8sBAAAAMAsIcAXsvvWtGonG9b2X6cIBAAAAcwEBroC1NVbojz7UrH949YD6hunCAQAAAMWOAFfgHtjQpsGxmB5/+QOvSwEAAAAwwwhwBW5lU4Vuv7RJ339lv06NRL0uBwAAAMAMIsAVgfs3tGlgLKa/f2W/16UAAAAAmEEEuCKwZkGlbl3TqO+9vE/9o3ThAAAAgGJFgCsSD2xsU/9oTD94db/XpQAAAACYIQS4InHpwiptXDVfj738gQbHYl6XAwAAAGAGEOCKyAMb29Q3HNUPfrff61IAAAAAzAACXBG5fHG1bl7ZoMd++4GG6MIBAAAARYcAV2Tu39Cmk0MR/fD3B7wuBQAAAMA0I8AVmauW1OijbfV6dMs+jUTiXpcDAAAAYBoR4IrQgxvbdGKQLhwAAABQbAhwRah9aa2uW1Gn727Zp9EoXTgAAACgWBDgitQDG9vUPTCmH71+0OtSAAAAAEwTAlyRumZ5na5eVqtHXtpLFw4AAAAoEgS4IvbgxjYd6x/TP3cc8roUAAAAANOAAFfErl1Rp/YlNfr2i3s1FqMLBwAAABQ6AlwRMzM9+LE2HT01qie3dXpdDgAAAICLRIArcje01uvDLdX69gt7FYklvC4HAAAAwEUgwBU5M9MDG9t0uG9ET71JFw4AAAAoZAS4OeDmSxp0+aIqfeuFPYrG6cIBAAAAhSqnAGdmt5nZLjPbY2ZfyXL9DjN7x8zeMrMOM7sh13sx88a7cIdOjuhnbx72uhwAAAAAF2jSAGdmfkkPS7pd0hpJd5nZmgnDnpN0uXPuCkmfk/TYFO7FLNiwar4uXViph1/YoxhdOAAAAKAg5dKBWydpj3Nun3MuIukJSXdkDnDODTrnXOqwXJLL9V7MDjPTAxvatL9nWE+/fcTrcgAAAABcgFwC3EJJmTtBd6bOncHMPmFm70n6pZJduJzvxey4ZU2jVjdX6lvP71E84Sa/AQAAAEBeySXAWZZzZ/3r3zn3lHNulaQ7JX11KvdKkpndnZo/19Hd3Z1DWZiqZBeuVftODOkX79CFAwAAAApNLgGuU9LijONFks75r3/n3BZJK8ysfir3Oucedc61O+faGxoacigLF2LT2iatbKzQN+nCAQAAAAUnlwC3VVKbmS0zs5CkT0t6OnOAmbWamaVeXykpJKknl3sxu3w+0/0bW7Xn+KB+vf2o1+UAAAAAmIJJA5xzLibpPkmbJe2U9GPn3A4zu8fM7kkN+1NJ283sLSVXnfyUS8p67wx8DkzB7Zc2q3X+PH3zuT1K0IUDAAAACoadXjwyf7S3t7uOjg6vyyhq//LWYT34xFv6zmeu1O0fava6HAAAAAAZzGybc6594vmcNvJG8fn4ZQu0vL5cDz33Pl04AAAAoEAQ4OYov89034ZWvdc1oGd3HvO6HAAAAAA5IMDNYX9y+QItrQvroefeVz4+SgsAAADgTAS4OSzg9+ne9a3acaRfz7933OtyAAAAAEyCADfH3fnhhVpcW6Zv0IUDAAAA8h4Bbo4L+n269+ZWvd15Si/u7va6HAAAAADnQYCD/s2Vi7SwukwPPUsXDgAAAMhnBDgoFPDpy+tX6K1DfXp5zwmvywEAAABwDgQ4SJI+edUiNVeV0oUDAAAA8hgBDpKkkoBfX755hToO9Op3e3u8LgcAAABAFgQ4pP3b9sVq
rCzRQ8+973UpAAAAALIgwCGtNOjXPTet0O8/OKnX9tGFAwAAAPINAQ5nuGtdixoqSvQNunAAAABA3iHA4QylQb++eONyvbq3R1v3n/S6HAAAAAAZCHA4y2euXqL6eSG6cAAAAECeIcDhLGUhv+6+cbl++/4JvXGw1+tyAAAAAKQQ4JDVZ65eotpyunAAAABAPiHAIavykoC+8NFlenFXt94+1Od1OQAAAABEgMN5fPbapaoOB+nCAQAAAHmCAIdzmlcS0BduWKbn3juu7YdPeV0OAAAAMOcR4HBen71uqSpLA3ThAAAAgDxAgMN5VZYG9bkblumZd4/p3SP9XpcDAAAAzGkEOEzqL65fpoqSgL75PF04AAAAwEsEOEyqqiyov7h+qX69vUu7uga8LgcAAACYswhwyMnnblimeSUBfYMuHAAAAOAZAhxyUh0O6T9ct0S/+sNRvX+MLhwAAADghZwCnJndZma7zGyPmX0ly/XPmNk7qT+vmtnlGdf2m9kfzOwtM+uYzuIxuz5/w3KVBf361gt7vC4FAAAAmJMmDXBm5pf0sKTbJa2RdJeZrZkw7ANJNznnLpP0VUmPTri+3jl3hXOufRpqhkdqy0P682uX6OdvH9He7kGvywEAAADmnFw6cOsk7XHO7XPORSQ9IemOzAHOuVedc72pw9ckLZreMpEv/vKjy1US8Ovh5+nCAQAAALMtlwC3UNKhjOPO1Llz+bykX2ccO0nPmNk2M7t76iUin9TPK9G/v6ZFP3vrsPafGPK6HAAAAGBOySXAWZZzLutAs/VKBri/zjh9vXPuSiUfwbzXzG48x713m1mHmXV0d3fnUBa88pc3LlfQ79PDzIUDAAAAZlUuAa5T0uKM40WSjkwcZGaXSXpM0h3OuZ7x8865I6mvxyU9peQjmWdxzj3qnGt3zrU3NDTk/gkw6+ZXlOozVy/RT988rIM9w16XAwAAAMwZuQS4rZLazGyZmYUkfVrS05kDzKxF0k8l/blzbnfG+XIzqxh/LelWSdunq3h454s3LZffZ/r2i3ThAAAAgNkyaYBzzsUk3Sdps6Sdkn7snNthZveY2T2pYf+XpDpJ356wXUCjpJfN7G1Jr0v6pXPuN9P+KTDrGitLdddHFuvJbZ3q7KULBwAAAMwGcy7rdDZPtbe3u44OtozLd0dPjeim//6iPtm+SP/lEx/yuhwAAACgaJjZtmzbsOW0kTeQTXNVmf7dRxbpnzsO6UjfiNflAAAAAEWPAIeL8qWbWyVJj7y01+NKAAAAgOJHgMNFWVhdpk9etUhPvH5IXadGvS4HAAAAKGoEOFy0L9/cqoRzdOEAAACAGUaAw0VbXBvWv7lyoX70+kEd76cLBwAAAMwUAhymxb3rWxVLOH13yz6vSwEAAACKFgEO02JJXbnuvGKhfvj7A+oeGPO6HAAAAKAoEeAwbe5dv0KRWEKP/ZYuHAAAADATCHCYNssb5ulPLl+gH/zugHoG6cIBAAAA040Ah2l134Y2jcbieuzlD7wuBQAAACg6BDhMq9b58/TxyxboB6/uV+9QxOtyAAAAgKJCgMO0u39Dq4ajcX2PLhwAAAAwrQhwmHaXNFbojy5t1t+/ul+nhqNelwMAAAAUDQIcZsR9G1o1OBbT46/QhQMAAACmCwEOM2J1c6U2rW3U4698oFMjdOEAAACA6UCAw4x5YGObBkZj+odX93tdCgAAAFAUCHCYMWsXVOljqxv1vZc/0MAoXTgAAADgYhHgMKMe3NimUyNR/eB3B7wuBQAAACh4BDjMqA8tqtKGVfP12G/3aWgs5nU5AAAAQEEjwGHG3b+hVb3DUf3ja3ThAAAAgItBgMOM+3BLjW68pEF/t2WfhiN04QAAAIALRYDDrHhwY5t6hiL64WsHvS4FAAAAKFgEOMyKq5bU6IbWen13yz6NROJelwMAAAAUJAIcZs0DG9t0YnBMP3qdLhwAAABwIQhwmDXrltXqmuW1euSlvRqN0oUD
AAAApooAh1n14MZLdHxgTP9z6yGvSwEAAAAKTk4BzsxuM7NdZrbHzL6S5fpnzOyd1J9XzezyXO/F3HLN8lqtW1qr77y4V2MxunAAAADAVEwa4MzML+lhSbdLWiPpLjNbM2HYB5Jucs5dJumrkh6dwr2YQ8xMD2xsU1f/qH7c0el1OQAAAEBByaUDt07SHufcPudcRNITku7IHOCce9U515s6fE3SolzvxdxzfWudrlpSo6/98l196Z+26V/eOqz+0ajXZQEAAAB5L5DDmIWSMicsdUq6+jzjPy/p1xd4L+YAM9PXP3WFHnlpr55595h+vb1LQb/p2hX12rS2UbesadT8ilKvywQAAADyTi4BzrKcc1kHmq1XMsDdcAH33i3pbklqaWnJoSwUssW1YX3tEx/SV++4VG8e6tMzO7q0eUeX/uap7frPP9uuq1pqtGltkzatbVJLXdjrcgEAAIC8kEuA65S0OON4kaQjEweZ2WWSHpN0u3OuZyr3SpJz7lGl5s61t7dnDXkoPj6f6aolNbpqSY2+cvsq7To2oM3bj2nzji597Vc79bVf7dSqpop0mFvdXCGzbL8XAAAAAIqfOXf+rGRmAUm7JW2UdFjSVkl/5pzbkTGmRdLzkj7rnHt1Kvdm097e7jo6Oi7oA6F4HDo5rM07uvTMjmPaeuCknJNaasPatLZRm9Y26cqWGvl8hDkAAAAUHzPb5pxrP+v8ZAEudfMfSfq6JL+kx51zXzOzeyTJOfeImT0m6U8lHUjdEhv/Ztnunez7EeAwUffAmJ7dmezMvbqnR5F4QvXzSnTLmkZtWtuo61bUKxRgW0MAAAAUh4sKcLONAIfzGRiN6oVd3dq8o0svvndcQ5G4KkoD2rBqvjatbdJNlzSovCSXp4MBAACA/ESAQ1Eajcb1yp4T2ryjS8/uPK6TQxGVBHz6aFu9bl3bpI+tblRtecjrMgEAAIApOVeAo02BglYa9Gvj6kZtXN2oWDyhjgO96Xlzz+48Lr/PtG5prTatbdSta5u0oLrM65IBAACAC0YHDkXJOafth/u1ObU9wfvHByVJly2qSq9o2Tp/nsdVAgAAANnxCCXmtH3dg9q8I7kIyluH+iRJKxrK02HuskVVbE8AAACAvEGAA1K6To3qmXeTnbnX9p1UPOHUXFWqW9c0atOlTVq3tFYBPytaAgAAwDsEOCCLvuGIntt5XJt3dOml3d0aiyVUEw5q4+rkXnMfbatXadDvdZkAAACYYwhwwCSGIzFt2d2tzTuO6bmdx9Q/GlM45NdNlzTotkubtH7VfFWWBr0uEwAAAHMAq1ACkwiHArrt0mbddmmzovGEXtvXk17R8tfbuxT0m65dUa9Naxt1y5pGza8o9bpkAAAAzDF04IBJJBJObx7q0zOpFS339wzLTLqypUa3pRZBaakLe10mAAAAigiPUALTwDmn3ccGtXlHl36zvUvvHu2XJK1qqkivaLm6uYIVLQEAAHBRCHDADDh0cjj9mOXWAyflnNRSG9amtclFUK5sqZHPR5gDAADA1BDggBl2YnBMz76b3GvulT09isQTqp9XolvWNGrT2kZdt6JeoQDbEwAAAGByBDhgFg2MRvXCrm5t3tGlF987rqFIXBWlAW1YNV+b1jbppksaVF7CGkIAAADIjgAHeGQ0Gtere09o8/Zj+tedx3RyKKJQwKcb2+p169omfWx1o2rLQ16XCQAAgDxCgAPyQCyeUMeB3vS8ucN9I/L7TOuW1mrT2kbdurZJC6rLvC4TAAAAHiPAAXnGOacdR/rTK1q+f3xQknTZoqrUipaNap1f4XGVAAAA8AIBDshz+7oHtXlHchGUtw71SZJWNJSntye4bFEV2xMAAADMEQQ4oIB0nRrVv77bpd/s6NJr+04qnnBqrirVrWuS2xOsW1argJ8VLQEAAIoVAQ4oUH3DET2387g27+jSlve7NRpNqCYc1MbVyTD30bZ6lQb9XpcJAACAaUSAA4rAcCSm
LbtP6JkdXXp25zH1j8YUDvl10yUN2rS2Se1La7SwuoxHLQEAAArcuQIcG1EBBSQcCui2S5t026VNisYTem1fT3pFy19v75IkVZQEtLKpQiubKrSquVKrUq8rS4MeVw8AAICLRQcOKAKJhNMfDp/S9iOntKtrQO91Dei9o/3qH42lxyysLkuGulSgW91cqWX15Qoylw4AACDv0IEDipjPZ7p8cbUuX1ydPuecU1f/qN47mgp0Xf3a1TWg377frWg8+YubkN+n5Q3lWt1cmQ53q5oq1VhZwmOYAAAAeYgABxQpM1NzVZmaq8q0ftX89PlILKF9Jwa1q2tAO48OaFdXv17b16On3jycHlMdDmplYyrQpcLdysYKlZfwVwYAAICX+NcYMMeEAj6taqrUqqZK3XHF6fOnhqPadSzZqRt/BPMnbxzW4NiB9JiW2nDy8cumCq1sqtSq5gotrSuX30e3DgAAYDYQ4ABIkqrCQa1bVqt1y2rT55xz6uwd0XtdyU7dzq4B7eoa0HM7jymRmj5bEvDpksaKMx7BXNlUoYaKEo8+CQAAQPHKKcCZ2W2SHpLkl/SYc+6/Tri+StL3JV0p6W+cc3+bcW2/pAFJcUmxbBPxAOQnM9Pi2rAW14Z1y5rG9PnRaFx7jg+mFkxJduxe2t2tJ7d1psfUzwulHr1MdupWNVWobX6FykLsWQcAAHChJg1wZuaX9LCkWyR1StpqZk87597NGHZS0gOS7jzH26x3zp24yFoB5InSoF+XLqzSpQurzjjfMziWXgVzPNz96PWDGonGJUk+k5bWlWtVczLYJVfDrNDimrB8PIYJAAAwqVw6cOsk7XHO7ZMkM3tC0h2S0gHOOXdc0nEz++MZqRJAQaibV6LrWkt0XWt9+lw84XTw5HDyEcyjA+nFU369vUvju5iEQ35d0nh6i4PkHL0K1ZSHPPokAAAA+SmXALdQ0qGM405JV0/hezhJz5iZk/Rd59yjU7gXQIHz+0zL6su1rL5ct13anD4/HInp/WOD6Ucwd3UN6Jl3j+mJraf/ummsLNHKpsrUoinJP63z56kkwGOYAABgbsolwGV7rmkqu39f75w7YmbzJf2rmb3nnNty1jcxu1vS3ZLU0tIyhbcHUIjCoUDWveu6B8f03ninLrV33fdf7VEklpCUDITL68u1qjnZpVvZWKFVzRVaWF3G3nUAAKDo5RLgOiUtzjheJOlIrt/AOXck9fW4mT2l5COZZwW4VGfuUUlqb2+fSkAEUCTMTPMrSjW/olQ3XtKQPh+LJ7S/Zyi1vUFyjt2bB3v187dP/1VUURJId+nS4a6pQpWlQS8+CgAAwIzIJcBtldRmZsskHZb0aUl/lsubm1m5JJ9zbiD1+lZJ/8+FFgtgbgr4fWqdX6HW+RX6+GWnzw+MRrU79RjmrlS4+/nbR/TD3x9Mj1lYXZbe4mB8ft3yhnIF/T4PPgkAAMDFmTTAOediZnafpM1KbiPwuHNuh5ndk7r+iJk1SeqQVCkpYWZ/JWmNpHpJT6UeawpI+v+cc7+ZkU8CYM6pKA3qqiU1umpJTfqcc05HT42mV8McD3dbdncrltq8Lug3rWiYp9XNlWfsX9dYWcJjmAAAIK+Zc/n3tGJ7e7vr6OjwugwARSQSS2jficH0Kpi7UounHD01mh5TVRZMhbkKrWyq1MqmeVpSV6668hDBDgAAzCoz25ZtD+2cNvIGgEIXCvhS2xNU6o4rTp8/NRxNdumOnQ52T27r1FAknh5THvKrpa5cS+vCaqkLa0ltuZbUhbWkLqzmqjL52cMOAADMEgIcgDmtKhzU1cvrdPXyuvS5RMLpcN+I3j8+oAM9wzrQM5zcy+7YgJ7beVyReCI9NuT3aVFNWSrQlaulNpx+vbi2jC0PAADAtCLAAcAEPp9pcW1Yi2vDZ12LJ5y6+kd1oGcoI9wNaf+JYW3d36vBsVh6rJnUXFmqlrqwltaVn9G9a6kLs0ImAACYMgIcAEyB32da
WF2mhdVlum7Fmdecczo5FNGBk8PpgHewZ1gHTg7r2Z3HdGIwcsb42vLQGR27JRmv6+cx7w4AAJyNAAcA08TMVDevRHXzSnRlS81Z1wfHYslA1zOUCnnJ7l3H/uSedomMNaXCIX863E3s3i2oZt4dAABzFQEOAGbJvJKA1iyo1JoFlWddi8QS6uxNdusOnEgGvIM9w9rbPaQXdnUrEjs97y7oNy2qCWft3i2uDas0yLw7AACKFQEOAPJAKODT8oZ5Wt4wT1p55rVEet5dar5d+tHMIb1xoFcDGfPuJKm5qvTMcJfq3rXUhVVVxrw7AAAKGQEOAPKcz2daUF2mBdVlunZF3RnXnHPqHY7qQM+QDp4c1v4TyWB3sGdYL+zqVvdA5xnja8JBtUyYb5cMeGE1VLCROQAA+Y4ABwAFzMxUWx5SbXlIH84y725oLKaDqfl243PvDvYM681DvfrFO2fOuysL+pMrZGbp3i2oLlXA75vFTwYAALIhwAFAESsvCWh1c6VWN2efd3e4b+SM7t3Bk0P64MSQXtrdrbGMeXcBn2lRTVnW7l0L8+4AAJg1BDgAmKNCAZ+W1ZdrWX35WdcSCadjA6PprRD2Z3bvDvZqYPTMeXeNlSVnbYUw3r2rCjPvDgCA6UKAAwCcxeczNVeVqbmqTNcsP3veXd9wNL3fXTLgJbt3L+3u1vGBsTPGV5UFtbQurJa6cs2vKFF1WVDV4aCqwyFVh4OqCYdUlTo3ryTAPDwAAM6DAAcAmBIzU015SDXlIV2xuPqs68OR0/Puxrt3B08O6+1DfeoZHNNQJH7O9w747HS4ywx6E0JfdVnqayoAhkN+gh8AYE4gwAEAplU4FNCqpkqtajp73p2UnHvXNxLRqeGoeoej6huOqG8k9TV17tRI8vXhvlG9e6RffSNRDZ8n+AX9pqqykGpSoS7zdWboqwkHVZU6VxMOqixI8AMAFBYCHABgVoUCPs2vKNX8itIp3Tcajat/ZPLQ1zscUWfvsHYcSb4ejSbO+Z4hvy/dycvs6mUNfWUh1ZQnv5aFWLQFAOANAhwAoCCUBv0qDfo1v3Lqwa9vOKq+VMDLDH2nO4HJcwdPDuvtzoh6h6OKxM4d/EoCvrNDX1lI1eWnz9WMdwIzzrFaJwDgYhHgAABFrTToV1OVX01VUwt+I5F4OvT1Dmc88jkh9PWNRPXBiSH1DfepbziqSPzcwa806Dsr9NWUJ4NeZugbn9s3Pq4kQPADACQR4AAAyKIs5FdZKLkSZ66ccxpJdfwmhr6J3b9TIxHt7R5U74Hk62jcnfN9y4L+cy7oUlEaUGVpUJVlp19XlQVUURpUZWlQpUEf8/wAoIgQ4AAAmCZmpnAooHAooAXVUwt+w5H46a7eeUJf33BUu48N6NRIclwsce7gJyVX9qwsC6qyNBXqygKqKEl+rSwNnj5XmhyTGQQrS4OaVxqQ30cABIB8QYADAMBjZqbykoDKSwJaVJP7fc45jUYT6h+Nqn8kqv7RmPpHoxoYjaWOT78eSF3rH4nqeP9Y+tr5VvccV1ESSIa6suwdv/Frma9PjwvwCCgATCMCHAAABcrMUo96+tU4xcVdxkXjCQ2mw11MA6PJ4NefDoGpcxnXjp4a1e7jA+lzkzQBFQr4Uh29gCpS3cDK83X+JoRDNngHgNMIcAAAzGFBvy+9MfuFcM5pKBJPh7xkZy8z8MWydgeP9I2kr42dZ8VPSfKZVFF6/o7f2aHwzG5g0O+7oM8HAPmGAAcAAC6YmWleSbJL1lx1Ye8xFotrYDQ26aOf6dejMR06OZw+HhyLyU3SBSwL+rN0/M6cG1hZGlQ45E9tWeFTacCvkqBfJQHf6XOp7SxKAz4FCIUAPECAAwAAnioJ+FUyz6/6eSUXdH8i4TQYyQh8Zzz6eTr4DWQ8Kto7FNGBnmENjEZ1aiR63lVAzyXgM5VmBLySVOg7I+gFfSoJZH7NuH7WfanXQX/6fUoyxpUG/SwoA4AABwAA
CpvPZ+lVMy+Ec05jsYT6R6IajsQ1GotrLJrQaDSu0VjqazR1LhZPHSc0Fkt+HT9O3nf63MBoLPUep8+NRRPn3StwMkG/pTuDyVB4ZlicGATT188IluPXzh06SzOu+QiNQF4hwAEAgDnNzNLhZTbEEy4d/s4MgWcHwWwhMR0goxlhMzWubzh6VsAci8UvqMM4LuT3ne4MZobEjEdLS7KExJKAX+GQX9XhkGrLg6otL1FtOKTaeSGVh/wsTANcIAIcAADALPL7xvcLnL3vGYsnNBY7u6s4HgRPX5sQGDNCYrawORSJqWdownukvkf8PMuThvw+1YyHuvKgasIh1aUW0xn/Oh72asMhVYdDCgWYcwhIOQY4M7tN0kOS/JIec8791wnXV0n6vqQrJf2Nc+5vc70XAAAAMyvgTy66Ul4ye7+7j8YTGo7E1TccUc9QRL1DEZ0c/zN85vHh3lM6ORRR/2jsnO9XURpQbXko+SecJeyNv079qSxl+wkUp0n/KzYzv6SHJd0iqVPSVjN72jn3bsawk5IekHTnBdwLAACAIhP0+1RV5lNVWVBL6spzuicaT6h3OKLeoeg5w97JoYiOnhrVu0f71TMUUeQc21AEfJYOdzXlQdWVlyS7fhlhL32uPKSacGjWHqMFLkYuv4ZZJ2mPc26fJJnZE5LukJQOYc6545KOm9kfT/VeAAAAQEqGvvkVpZpfkdvG9M45DUfi5w17J4ci6h2OaGdXv3qHIuobiZ5z24nykP+cnb26CR2+2nBIVWVBFnnBrMslwC2UdCjjuFPS1Tm+f873mtndku6WpJaWlhzfHgAAAHOVmam8JKDykoAW14ZzuieecOobToa6k0NRnRwaO+Nrb+qRz57BiN4/NqiTQxGNRONZ38tnUk04I9hNfLQzY/GW8S5gWYguHy5OLgEu268Vcl3KKOd7nXOPSnpUktrb2y98qSQAAADgHPw+U928EtVNYd/BkUg8a3evd8L8vr3dg+o9kHx9rjVcSoO+9KObmYu3ZC7aMt7lq0k92sn+f8iUS4DrlLQ443iRpCM5vv/F3AsAAAB4rizk18JQmRZWl+U0PpFw6h+NnifspTp+w1Ht7xlS71BUg2PZF3Axk6rKTs/dC5cEFPSZ/D5TwG8K+HwKTDj2+0wBnyUXr0ldC/pNfl/244A/dX/q3mDGcfI9sx8H/Wd+r/HX49d8JhaSmQG5BLitktrMbJmkw5I+LenPcnz/i7kXAAAAKDg+n6k6tf3B8obc7hmLxdOLt0zs7I3P7zs5GNGpkajiiYRicadYwimecIpNOI7GE6nzyePzbekw0wJZg2XGcToQ+jKuZQmIPp/8mdcyjseDZebYM49N/lSYPR1YfekgfOWSGtVPoSPrtUkDnHMuZmb3Sdqs5FYAjzvndpjZPanrj5hZk6QOSZWSEmb2V5LWOOf6s907Q58FAAAAKEglAb+aqvxqqsptAZepSCSc4u7scJcMfdmPo/HT4TCevuYUTyQyrp19HIsnMoJl8jjbteQ92Y9jqfcbisXOqiXzemy8vvjp1xeyaf0/fG6dbrokx6SdB8ydaxkeD7W3t7uOjg6vywAAAABQYBIZ4S6WcIrHnaIZQXRiGFxSF1ZFadDrss9iZtucc+0Tz8/ebo4AAAAAMMN8PlPIZwrJ53UpM6I4PxUAAAAAFCECHAAAAAAUCAIcAAAAABQIAhwAAAAAFAgCHAAAAAAUCAIcAAAAABQIAhwAAAAAFAgCHAAAAAAUCAIcAAAAABQIAhwAAAAAFAhzznldw1nMrFvSAa/ryKJe0gmviwDOg59R5Dt+RpHv+BlFvuNndO5Y4pxrmHgyLwNcvjKzDudcu9d1AOfCzyjyHT+jyHf8jCLf8TMKHqEEAAAAgAJBgAMAAACAAkGAm5pHvS4AmAQ/o8h3/Iwi3/EzinzHz+gcxxw4AAAAACgQdOAAAAAAoEAQ4HJgZreZ2S4z22NmX/G6HiCTmS02sxfMbKeZ7TCz
B72uCcjGzPxm9qaZ/cLrWoCJzKzazJ40s/dSf59e63VNQCYz+99S/5/fbmY/MrNSr2uCNwhwkzAzv6SHJd0uaY2ku8xsjbdVAWeISfqPzrnVkq6RdC8/o8hTD0ra6XURwDk8JOk3zrlVki4XP6vII2a2UNIDktqdc5dK8kv6tLdVwSsEuMmtk7THObfPOReR9ISkOzyuCUhzzh11zr2Rej2g5D86FnpbFXAmM1sk6Y8lPeZ1LcBEZlYp6UZJ35Mk51zEOdfnaVHA2QKSyswsICks6YjH9cAjBLjJLZR0KOO4U/zjGHnKzJZK+rCk33tcCjDR1yX9H5ISHtcBZLNcUrek76ce833MzMq9LgoY55w7LOlvJR2UdFTSKefcM95WBa8Q4CZnWc6xdCfyjpnNk/QTSX/lnOv3uh5gnJl9XNJx59w2r2sBziEg6UpJ33HOfVjSkCTmvCNvmFmNkk+ALZO0QFK5mf17b6uCVwhwk+uUtDjjeJFoWSPPmFlQyfD2Q+fcT72uB5jgekl/Ymb7lXwMfYOZ/ZO3JQFn6JTU6Zwbf3rhSSUDHZAvPibpA+dct3MuKumnkq7zuCZ4hAA3ua2S2sxsmZmFlJww+rTHNQFpZmZKztvY6Zz7H17XA0zknPs/nXOLnHNLlfw79HnnHL85Rt5wznVJOmRmK1OnNkp618OSgIkOSrrGzMKp/+9vFAvtzFkBrwvId865mJndJ2mzkiv+PO6c2+FxWUCm6yX9uaQ/mNlbqXP/yTn3K+9KAoCCc7+kH6Z+WbtP0l94XA+Q5pz7vZk9KekNJVefflPSo95WBa+Yc0znAgAAAIBCwCOUAAAAAFAgCHAAAAAAUCAIcAAAAABQIAhwAAAAAFAgCHAAAAAAUCAIcAAAAABQIAhwAAAAAFAgCHAAAAAAUCD+f8y14DV/AZWvAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 1080x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "### Plotting the cost vs epochs\n",
    "fig, ax = plt.subplots(figsize=(15, 5))\n",
    "plt.plot(np.array(cost))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "Assignment1-Step2_Harshit_Agarwal.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "037faad73a8a40499555406f4e586731": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "0554687d97424798a86ea0a4c56cdbf8": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "14d009d293864f9f85bd16b5e1b6b381": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_c75452bfc2264392bf3d842bfbd4eeee",
      "placeholder": "​",
      "style": "IPY_MODEL_0554687d97424798a86ea0a4c56cdbf8",
      "value": " 29696/? [00:00&lt;00:00, 395418.47it/s]"
     }
    },
    "1a89067cd68c41fb96b74e0ebf3d1931": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_91798827c9254ffc969228862c8ee37f",
       "IPY_MODEL_cdbc8de66e8c4fa79100dcac1fcfed4d",
       "IPY_MODEL_2773f678ca4a45158df219d2672fa646"
      ],
      "layout": "IPY_MODEL_9f7585b1023d4e20ba3649efcfcdb881"
     }
    },
    "2773f678ca4a45158df219d2672fa646": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_480e369fb79f4326bd13d2355fddd890",
      "placeholder": "​",
      "style": "IPY_MODEL_898cdc388a2d428ca742a9ba2365df20",
      "value": " 5120/? [00:00&lt;00:00, 57213.75it/s]"
     }
    },
    "2ee0648f054c49049fdf5d6ac81ec086": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_4be171eff56046c2a401bf02b6d704c9",
       "IPY_MODEL_c8c6df38201b470bb06bffc677873f93",
       "IPY_MODEL_f51f06fbe28746dea09e075256e29451"
      ],
      "layout": "IPY_MODEL_7c90cc0c4c3b4cd89dbdfffdb78d463d"
     }
    },
    "3d3af4241b714b9a87253ac87b2e31b1": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "44f9d2dde4424058a337a1e3d585b5fe": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "47b04fa0a67c4fbbbcdc3d1062430259": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "480e369fb79f4326bd13d2355fddd890": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "4b80b4415ce34130a048933b2df179a5": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "4be171eff56046c2a401bf02b6d704c9": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_9ee4f57a58c14bb1880789196ff7ef63",
      "placeholder": "​",
      "style": "IPY_MODEL_74d1d431fbb84c4f948ca510e397f4e2",
      "value": ""
     }
    },
    "51567646aecc4126a3b9cb96f97d5be5": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_037faad73a8a40499555406f4e586731",
      "placeholder": "​",
      "style": "IPY_MODEL_4b80b4415ce34130a048933b2df179a5",
      "value": " 9913344/? [00:00&lt;00:00, 33092818.93it/s]"
     }
    },
    "54ce0f87fc1f46d2b9f9850a393b46cf": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "553cf4ae8edb4b48a58ab4446036ae87": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "5ecbc7b6eb6544709da61283dfc8d3c6": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "6e770d6f5e5f4101b2c985f8182eb79e": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_af12215b0ba24c699e1111c132dee3fe",
      "placeholder": "​",
      "style": "IPY_MODEL_908f60d17e1b49869ac3d4ee53a74a6e",
      "value": ""
     }
    },
    "74d1d431fbb84c4f948ca510e397f4e2": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "780790f42e764ec5a0b8d441a290e5d3": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "7837b9def9dd470cb57f96b38afbfc18": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "7c8cf011ed684ef0841d41da76c9bfc8": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_d2e31ba6e42448839959f523cc56dbcb",
      "max": 28881,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_5ecbc7b6eb6544709da61283dfc8d3c6",
      "value": 28881
     }
    },
    "7c90cc0c4c3b4cd89dbdfffdb78d463d": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "84bda6e8c38b4b669e2825c962e9dbfb": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "898cdc388a2d428ca742a9ba2365df20": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "908f60d17e1b49869ac3d4ee53a74a6e": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "90f550d2f7344eb894208d0373ac6f8d": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "91798827c9254ffc969228862c8ee37f": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_3d3af4241b714b9a87253ac87b2e31b1",
      "placeholder": "​",
      "style": "IPY_MODEL_84bda6e8c38b4b669e2825c962e9dbfb",
      "value": ""
     }
    },
    "9ee4f57a58c14bb1880789196ff7ef63": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "9f7585b1023d4e20ba3649efcfcdb881": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "a1dd0d1a7fb14235a271992bea0d233c": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "af12215b0ba24c699e1111c132dee3fe": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "b3105726704240439adcf7f13bd48cca": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "b9fa1995789a49fdb6a0c3e8505733de": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "c52318a206de4565a725388bc50bdfc0": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_d054c2e9330c433ab6bb900c6fa7dac6",
       "IPY_MODEL_7c8cf011ed684ef0841d41da76c9bfc8",
       "IPY_MODEL_14d009d293864f9f85bd16b5e1b6b381"
      ],
      "layout": "IPY_MODEL_553cf4ae8edb4b48a58ab4446036ae87"
     }
    },
    "c75452bfc2264392bf3d842bfbd4eeee": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "c8c6df38201b470bb06bffc677873f93": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_90f550d2f7344eb894208d0373ac6f8d",
      "max": 1648877,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_7837b9def9dd470cb57f96b38afbfc18",
      "value": 1648877
     }
    },
    "cdbc8de66e8c4fa79100dcac1fcfed4d": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_780790f42e764ec5a0b8d441a290e5d3",
      "max": 4542,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_47b04fa0a67c4fbbbcdc3d1062430259",
      "value": 4542
     }
    },
    "d054c2e9330c433ab6bb900c6fa7dac6": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_a1dd0d1a7fb14235a271992bea0d233c",
      "placeholder": "​",
      "style": "IPY_MODEL_b9fa1995789a49fdb6a0c3e8505733de",
      "value": ""
     }
    },
    "d2e31ba6e42448839959f523cc56dbcb": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "d784c3bdbaf543a299447b17b500c2a8": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_f6af65953a9841e1907a47fb27c01b96",
      "max": 9912422,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_44f9d2dde4424058a337a1e3d585b5fe",
      "value": 9912422
     }
    },
    "f1ae475d1e48411ab7cb9f49f6f673c9": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_6e770d6f5e5f4101b2c985f8182eb79e",
       "IPY_MODEL_d784c3bdbaf543a299447b17b500c2a8",
       "IPY_MODEL_51567646aecc4126a3b9cb96f97d5be5"
      ],
      "layout": "IPY_MODEL_f65a4ec103c34c7186fb76d3e3507795"
     }
    },
    "f51f06fbe28746dea09e075256e29451": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_b3105726704240439adcf7f13bd48cca",
      "placeholder": "​",
      "style": "IPY_MODEL_54ce0f87fc1f46d2b9f9850a393b46cf",
      "value": " 1649664/? [00:00<00:00, 2122626.57it/s]"
     }
    },
    "f65a4ec103c34c7186fb76d3e3507795": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "f6af65953a9841e1907a47fb27c01b96": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
关于此算法

卷积神经网络 (CNN)

资源

CNN : https://en.wikipedia.org/wiki/Convolutional_neural_network
Pytorch : https://pytorch.ac.cn/tutorials/beginner/basics/intro.html

首先使用 DataLoader 和 torchvision 提供的数据集导入训练和测试 MNIST 数据集。您可以将训练和测试批次大小设置为您认为最合适的任何值。

import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.dataloader import DataLoader

import matplotlib.pyplot as plt
import numpy as np
# Downloading MNIST dataset from Pytorch.
# train=True: only the official training split is downloaded; it is divided
# into our own train/test subsets below.
dataset = datasets.MNIST(
    root="./data",
    download=True,
    train=True,
    transform=transforms.ToTensor(),
)

# Splitting the dataset into training and testing set.
# A fixed-seed generator makes the split identical on every kernel restart.
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [50000, 10000], generator=torch.Generator().manual_seed(42)
)

# Visualizing a sample from the *training split*. Indexing the Subset goes
# through the dataset's transform and returns (1-channel image tensor, int label),
# so image and label are guaranteed to belong to the same training example.
sample_image, sample_label = train_dataset[10]
plt.imshow(sample_image.squeeze(0), cmap="gray")
plt.title("Label : " + str(sample_label));
/opt/anaconda3/lib/python3.8/site-packages/torchvision/datasets/mnist.py:52: UserWarning: train_labels has been renamed targets
  warnings.warn("train_labels has been renamed targets")
Text(0.5, 1.0, 'Label : 3')
# Creating a DataLoader for training and testing.
# Only the training loader is shuffled: evaluation order does not affect accuracy.
# A batch of 1000 (which divides the 10,000-sample test split evenly) makes each
# evaluation pass far cheaper than one forward call per image.
train = DataLoader(train_dataset, batch_size=32, shuffle=True)
test = DataLoader(test_dataset, batch_size=1000, shuffle=False)

定义具有以下架构的网络

Conv2d (输入通道 = 1, 输出通道 = 15, 卷积核大小 = 5) → MaxPool (池化核大小 = 2) → ReLU → Conv2d (输入通道 = 15, 输出通道 = 30, 卷积核大小 = 5) → Dropout2d (p = 0.5) → MaxPool (池化核大小 = 2) → ReLU → Linear (输入维度 = 480, 隐藏单元 = 64) → ReLU → Dropout (p = 0.5) → Linear (输入维度 = 64, 输出单元 = 10) → LogSoftmax

class CNN(nn.Module):
    """Two-stage convolutional classifier producing 10-class log-probabilities.

    Layer order inside ``self.cnn`` (the indices form the ``cnn.<i>.*``
    state_dict keys):
        0 Conv2d(1->15, k=5)    1 MaxPool2d(2)     2 ReLU
        3 Conv2d(15->30, k=5)   4 Dropout2d(0.5)   5 MaxPool2d(2)   6 ReLU
        7 Flatten               8 Linear(480->64)  9 ReLU          10 Dropout(0.5)
        11 Linear(64->10)       12 LogSoftmax(dim=1)

    The 480 input features of layer 8 assume 28x28 single-channel inputs:
    28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4, and 30 * 4 * 4 = 480.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor (indices 0-6).
        feature_extractor = [
            nn.Conv2d(1, 15, kernel_size=5),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),
            nn.Conv2d(15, 30, kernel_size=5),
            nn.Dropout2d(0.5),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),
        ]
        # Fully connected classifier head (indices 7-12); ends in LogSoftmax,
        # so outputs are log-probabilities, not raw logits.
        classifier_head = [
            nn.Flatten(),
            nn.Linear(480, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 10),
            nn.LogSoftmax(dim=1),
        ]
        self.cnn = nn.Sequential(*feature_extractor, *classifier_head)

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, 10)."""
        log_probs = self.cnn(x)
        return log_probs

在 MNIST 上训练您在上一步中定义的网络,使用您认为合适的优化器和训练轮数。使用交叉熵损失(由于网络以 LogSoftmax 结尾、输出的是对数概率,交叉熵应通过对其输出使用负对数似然损失 NLLLoss 来计算)。每个 epoch 在测试数据集上测试您的模型,并打印您获得的准确率值。

# Not read by the loop below: the DataLoaders were built with their own batch size.
batch_size = 32

model = CNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Running the model on GPU if available
## Refer pytorch documentation for more details about copying model and data onto the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

cost = []
epochs = 10

# Training the model
for epoch in range(epochs):
    # Dropout must be active while training; it is switched off for evaluation below.
    model.train()

    loss_epoch = []
    train_correct = 0
    train_total = 0

    for x, y in train:
        x = x.to(device)
        y = y.to(device, dtype=torch.long)

        # The network ends in LogSoftmax, so these are per-class log-probabilities.
        y_pred = model(x)

        # Predicted class = index of the largest log-probability.
        t_preds = y_pred.argmax(dim=1)

        # Count correct predictions; dividing by the sample total at the end is
        # exact even when the last batch is smaller than the others.
        train_correct += (t_preds == y).sum().item()
        train_total += y.size(0)

        # NLL on log-probabilities *is* cross-entropy; F.cross_entropy would
        # apply log_softmax a second time on top of the model's LogSoftmax.
        loss = F.nll_loss(y_pred, y)

        # Backpropagation: reset gradients, compute new ones, update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Appending the loss of each batch to the epoch loss
        loss_epoch.append(loss.item())

    # Calculating test accuracy every epoch, with dropout disabled and no autograd.
    model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for x, y in test:
            x = x.to(device)
            y = y.to(device, dtype=torch.long)
            t_preds = model(x).argmax(dim=1)
            test_correct += (t_preds == y).sum().item()
            test_total += y.size(0)

    mean_loss = sum(loss_epoch) / len(loss_epoch)
    print(
        "Epoch :{} Loss : {} Train Accuracy:{} Test Accuracy : {}".format(
            epoch,
            mean_loss,
            train_correct / train_total,
            test_correct / test_total,
        )
    )

    cost.append(mean_loss)
/opt/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  ../c10/core/TensorImpl.h:1156.)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
Epoch :0 Loss : 0.44618519723548844 Train Accuracy:0.8599048256874084 Test Accuracy : 0.9340000152587891
Epoch :1 Loss : 0.20212347584169404 Train Accuracy:0.9430182576179504 Test Accuracy : 0.9490000009536743
Epoch :2 Loss : 0.16394136824853056 Train Accuracy:0.952215313911438 Test Accuracy : 0.9531000256538391
Epoch :3 Loss : 0.14257464571382256 Train Accuracy:0.9595929384231567 Test Accuracy : 0.9588000178337097
Epoch :4 Loss : 0.12598471959617277 Train Accuracy:0.9644113779067993 Test Accuracy : 0.9648000001907349
Epoch :5 Loss : 0.11733871379403024 Train Accuracy:0.9659308791160583 Test Accuracy : 0.9660000205039978
Epoch :6 Loss : 0.11000267015220401 Train Accuracy:0.9681901931762695 Test Accuracy : 0.9664999842643738
Epoch :7 Loss : 0.10582590816269605 Train Accuracy:0.9684301018714905 Test Accuracy : 0.9631999731063843
Epoch :8 Loss : 0.09624670598793283 Train Accuracy:0.9711892008781433 Test Accuracy : 0.9706000089645386
Epoch :9 Loss : 0.09422491162643551 Train Accuracy:0.9726687669754028 Test Accuracy : 0.9679999947547913
### Plotting the cost vs epochs
# `cost` holds the mean training batch loss of each epoch, in epoch order.
fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(np.arange(len(cost)), np.array(cost), marker="o")
ax.set(title="Training loss per epoch", xlabel="Epoch", ylabel="Mean training loss")
plt.show()
[<matplotlib.lines.Line2D at 0x7fa94052aa00>]