{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "wq_TKqjUmILg"
},
"source": [
"# Convolutional Neural Network (CNN)\n",
"\n",
"## Resources\n",
"\n",
" CNN : https://en.wikipedia.org/wiki/Convolutional_neural_network\n",
" Pytorch : https://pytorch.ac.cn/tutorials/beginner/basics/intro.html"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Start by importing both the training and testing MNIST datasets using DataLoaders and the torchvision provided datasets. You can set both the training and testing batch size to be whatever you feel is best."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "qGEvJYHnmILh"
},
"outputs": [],
"source": [
"import torch\n",
"from torchvision import datasets, transforms\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.utils.data.dataloader import DataLoader\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://127.0.0.1:8080/",
"height": 437,
"referenced_widgets": [
"f1ae475d1e48411ab7cb9f49f6f673c9",
"6e770d6f5e5f4101b2c985f8182eb79e",
"d784c3bdbaf543a299447b17b500c2a8",
"51567646aecc4126a3b9cb96f97d5be5",
"f65a4ec103c34c7186fb76d3e3507795",
"af12215b0ba24c699e1111c132dee3fe",
"908f60d17e1b49869ac3d4ee53a74a6e",
"f6af65953a9841e1907a47fb27c01b96",
"44f9d2dde4424058a337a1e3d585b5fe",
"037faad73a8a40499555406f4e586731",
"4b80b4415ce34130a048933b2df179a5",
"c52318a206de4565a725388bc50bdfc0",
"d054c2e9330c433ab6bb900c6fa7dac6",
"7c8cf011ed684ef0841d41da76c9bfc8",
"14d009d293864f9f85bd16b5e1b6b381",
"553cf4ae8edb4b48a58ab4446036ae87",
"a1dd0d1a7fb14235a271992bea0d233c",
"b9fa1995789a49fdb6a0c3e8505733de",
"d2e31ba6e42448839959f523cc56dbcb",
"5ecbc7b6eb6544709da61283dfc8d3c6",
"c75452bfc2264392bf3d842bfbd4eeee",
"0554687d97424798a86ea0a4c56cdbf8",
"2ee0648f054c49049fdf5d6ac81ec086",
"4be171eff56046c2a401bf02b6d704c9",
"c8c6df38201b470bb06bffc677873f93",
"f51f06fbe28746dea09e075256e29451",
"7c90cc0c4c3b4cd89dbdfffdb78d463d",
"9ee4f57a58c14bb1880789196ff7ef63",
"74d1d431fbb84c4f948ca510e397f4e2",
"90f550d2f7344eb894208d0373ac6f8d",
"7837b9def9dd470cb57f96b38afbfc18",
"b3105726704240439adcf7f13bd48cca",
"54ce0f87fc1f46d2b9f9850a393b46cf",
"1a89067cd68c41fb96b74e0ebf3d1931",
"91798827c9254ffc969228862c8ee37f",
"cdbc8de66e8c4fa79100dcac1fcfed4d",
"2773f678ca4a45158df219d2672fa646",
"9f7585b1023d4e20ba3649efcfcdb881",
"3d3af4241b714b9a87253ac87b2e31b1",
"84bda6e8c38b4b669e2825c962e9dbfb",
"780790f42e764ec5a0b8d441a290e5d3",
"47b04fa0a67c4fbbbcdc3d1062430259",
"480e369fb79f4326bd13d2355fddd890",
"898cdc388a2d428ca742a9ba2365df20"
]
},
"id": "N_61-p6ymILj",
"outputId": "765b9e09-2ce2-47ab-eac2-6978fa820170"
},
"outputs": [],
"source": [
"# Downloading MNIST dataset from Pytorch\n",
"dataset = datasets.MNIST(\n",
" root=\"./data\",\n",
" download=True,\n",
" train=True,\n",
" transform=transforms.ToTensor(),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "UlRAhZOpmILk"
},
"outputs": [],
"source": [
"# Splitting the dataset into training and testing set\n",
"train_dataset, test_dataset = torch.utils.data.random_split(dataset, [50000, 10000])"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"colab": {
"base_uri": "https://127.0.0.1:8080/",
"height": 332
},
"id": "iQWzZCRFmILl",
"outputId": "d2992c28-d667-4653-bbc2-9b323c82def0"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/anaconda3/lib/python3.8/site-packages/torchvision/datasets/mnist.py:52: UserWarning: train_labels has been renamed targets\n",
" warnings.warn(\"train_labels has been renamed targets\")\n"
]
},
{
"data": {
"text/plain": [
"Text(0.5, 1.0, 'Label : 3')"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAQCklEQVR4nO3dfbBU9X3H8fcHuII8OIIEgygi1PhQTUi8UUdpa2Jj1WnUTEYb2jS0NWKTaOuEOLW2HR3bTh1TJbHVGKhEND5OwZEkTBolbY1jtF6VCD6hUowIgoagQCry8O0fe2gvePfsZc/ZPcv9fV4zO7t3v+fhy3I/9+zub8/+FBGY2cA3qOoGzKw9HHazRDjsZolw2M0S4bCbJcJhN0uEw54gSf8h6YvtXteq5bDvwyStkvTbVffRX5Iuk7RS0juS1kiaLWlI1X2lwmG3dvoe8LGIOAA4DvgI8GfVtpQOh30AkjRa0vclvSnpl9ntQ/dYbIqk/5L0tqQHJI3ptf7Jkh6VtFHSzySdVkZfEfFKRGzctRtgJ/BrZWzbGnPYB6ZBwHeAw4GJwP8A/7zHMl8A/gQ4BNgO3AggaQLwA+DvgDHA14AFkj7QaKeSpkna2GCZ35f0DvAWtSP7t/v9r7JCHPYBKCJ+ERELIuJXEbEJ+Hvgt/ZY7I6IWB4RW4C/AS6QNBj4PLA4IhZHxM6IeBDoAc7ux34fiYgDGyxzV/Y0/kPALcC6vf4HWlMc9gFI0nBJ35b0anYUfRg4MAvzLq/1uv0q0AWMpfZs4PzsKfzG7Eg9DRhfZo8R8RLwLHBzmdu1+vxO6MA0CzgKOCki3pA0FXia2uvkXQ7rdXsisI3aU+vXqB31L2pDn0OAKW3Yj+Ej+0DQJWlYr8sQYBS11+kbszferupjvc9LOlbScOAa4F8jYgfwXeDTkn5H0uBsm6f18QbfXpP0RUnjstvHAn8JLCm6Xesfh33ft5hasHddrga+AexP7Uj9GPDDPta7A7gNeAMYRjYEFhGvAecCVwJvUjvSX04/flck/YakzTmLnAosk7Ql63txth9rA/nLK8zS4CO7WSIcdrNEOOxmiXDYzRLR1nH2/TQ0hjGinbs0S8q7bOG92Kq+aoXCLulM4JvAYOBfIuLavOWHMYKTdHqRXZpZjsej/scWmn4an3308ibgLOBYYHr2QQkz60BFXrOfCLwcESsj4j3gHmofxjCzDlQk7BPY/WSK1dl9u5E0U1KPpJ5tbC2wOzMrokjY+3oT4H0fx4uIORHRHRHdXQwtsDszK6JI2Fez+5lThwJrirVjZq1SJOxPAEdKOkLSfsDngEXltGVmZWt66C0itku6BPg3akNv8yLi2dI6M7NSFRpnj4hdpymaWYfzx2XNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRbZ2y2Vrk5A/XLf33OflTZF/12fty6zesyJ91d9Oyg3LreaZc83Rufee77za9bXs/H9nNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0R4nH0f8PoVp+TWF3/5urq1iUNGFtr3H5yQPw7PCc1ve9qTF+fWRyx4vPmN2/sUCrukVcAmYAewPSK6y2jKzMpXxpH9ExHxVgnbMbMW8mt2s0QUDXsAP5L0pKSZfS0gaaakHkk929hacHdm1qyiT+NPjYg1ksYBD0p6ISIe7r1ARMwB5gAcoDFRcH9m1qRCR/aIWJNdrwfuB04soykzK1/TYZc0QtKoXbeBM4DlZTVmZuUq8jT+YOB+Sbu2c1dE/LCUrmw3h89fmVtfM3P/urWJHfxJirnXz86tXzjkq7n1Ufc+VmY7A17TvwoRsRL4SIm9mFkLeejNLBEOu1kiHHazRDjsZolw2M0S0cEDM7bL9rVv5NYvnHtp3dpDX6p/+ivA+AanwC7aMjy3fs6IX+XW8xyzX/62135qe2591L1N7zpJPrKbJcJhN0uEw26WCIfdLBEOu1kiHHazRDjsZonwOPsAcOg/PFq39p3p+d/1fOXYF3PrL2/9YP7OR+SfflvE0Tduzq3vbNmeByYf2c0S4bCbJcJhN0uEw26WCIfdLBEOu1kiHHazRHicfYBb+E+fzK3vvFS59b8e+0KZ7eyVncO6Ktv3QOQju1kiHHazRDjsZolw2M0S4bCbJcJhN0uEw26WCI+zD3AHzf1pbv2nDx2VW//697bl1i8f88pe99Rfm6/ZklsfeWbLdj0gNTyyS5onab2k5b3uGyPpQUkvZdejW9ummRXVn6fxtwF7/g29AlgSEUcCS7KfzayDNQx7RDwMbNjj7nOB+dnt+cB55bZlZmVr9g26gyNiLUB2Pa7egpJmSuqR1LONrU3uzsyKavm78RExJyK6I6K7i6Gt3p2Z1dFs2NdJGg+QXa8vryUza4Vmw74ImJHdngE8UE47ZtYqDcfZJd0NnAaMlbQauAq4FrhP0oXAz4HzW9mkNW/9Jafk1jcelz8H+qLR9zfYQ+teCW54LP8760fSuu+sH4gahj0iptcpnV5yL2bWQv64rFkiHHazRDjsZolw2M0S4bCbJcKnuO4D9PHjc+vnzf9x3doXDvhG7rrDB+3XYO/VHQ8mLdzzlIzdecrmveMju1kiHHazRDjsZolw2M0S4bCbJcJhN0uEw26WCI+z7wN+cfzI3PrvjXqpbm34oOFlt9M2L87K7/3IGbll24OP7GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZIjzOvg8YMy9/2uVTDv1a3dpPLvp67rpjB49oqqd2GH/wxqpbGFB8ZDdLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuFx9gFg4jWP1q19+uVZueu+e2Cxv/fR4Ddowazr6tamdOWfp2/lavg/LWmepPWSlve672pJr0taml3Obm2bZlZUf/6s3wac2cf9syNianZZXG5bZla2hmGPiIeB/Hl4zKzjFXnBdomkZ7Kn+aPrLSRppqQeST3b2Fpgd2ZWRLNh/xYwBZgKrAWur7dgRMyJiO6I6O5iaJO7M7Oimgp7RKyLiB0RsROYC5xYbltmVramwi5pfK8fPwMsr7esmXWGhuPsku4GTgPGSloNXAWcJmkqEMAq4OLWtWhFHHDXY/n1ojuQcstnTK5/rv0rF9ySu+6Xj/jP3Pqdx56eW9/x3Ircemoahj0ipvdx960t6MXMWsgflzVLhMNulgiH3SwRDrtZIhx2s0T4FFcrZND+++fWGw2v5dm0Y1j+Att3NL3tFPnIbpYIh90sEQ67WSIcdrNEOOxmiXDYzRLhsJslwuPsVsgLs3+9wRL1v+a6kdkLz8mtT1qRP5W17c5HdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIh90sER5n76chEw6pW3vv9sG567618LDc+ribmh+LbrUhkyfl1h86c3aDLTQ/LfPk+36ZW9/Z9JbT5CO7WSIcdrNEOOxmiXDYzRLhsJslwmE3S4TDbpaI/kzZfBhwO/BBakObcyLim5LGAPcCk6hN23xBROQPjO7D1txcf3Ljp4+5J3fdOZfUH6MH+O7rv5tbH7Fqc25959Ln6ta2f/KE3HU3HD00t/7ZP/1xbn1KV/Pj6Ed8/6Lc+tGv1P932d7rz5F9OzArIo4BTga+IulY4ApgSUQcCSzJfjazDtUw7BGxNiKeym5vAp4HJgDnAvOzxeYD57WoRzMrwV69Zpc0Cfgo8DhwcESshdofBGBc6d2ZWWn6HXZJI4EFwGUR8c5erDdTUo+knm1sbaZHMytBv8IuqYta0O+MiIXZ3eskjc/q44H1fa0bEXMiojsiurvIfzPIzFqnYdglCbgVeD4ibuhVWgTMyG7PAB4ovz0zK4siIn8BaRrwE2AZ/39W4ZXUXrffB0wEfg6cHxEb8rZ1gMbESTq9aM+V2HrWx+vWPvy3S3PXvfGQJwrte8Hm+sN+ALe+Pq1u7abJ9+Wue0SBoTOAHZF/ouktbx9et/aDUybnb3vj2031lLLHYwnvxAb1VWs4zh4RjwB9rgzsm8k1S5A/QWeWCIfdLBEOu1kiHHazRDjsZolw2M0S0XCcvUz78jh7nhVz64/BAwxf2ZVbf/bSm8tsp62eee/d3Prlk05uUycG+ePsPrKbJcJhN0uEw26WCIfdLBEOu1kiHHazRDjsZonwlM0l+NBF+eerDxo+PLd+1MgvFdr/iOPrf43AU933Ftr2im1bcutf/eNLc+uDearQ/q08PrKbJcJhN0uEw26WCIfdLBEOu1kiHHazRDjsZonw+exmA4jPZzczh90sFQ67WSIcdrNEOOxmiXDYzRLhsJslomHYJR0m6d8lPS/pWUl/nt1/taTXJS3NLme3vl0za1Z/vrxiOzArIp6SNAp4UtKDWW12RPxj69ozs7I0DHtErAXWZrc3SXoemNDqxsysXHv1ml3SJOCjwOPZXZdIekbSPEmj66wzU1KPpJ5tbC3WrZk1rd9hlzQSWABcFhHvAN8CpgBTqR35r+9rvYiYExHdEdHdxdDiHZtZU/oVdkld1IJ+Z0QsBIiIdRGxIyJ2AnOBE1vXppkV1Z934wXcCjwfETf0un98r8U+Aywvvz0zK0t/3o0/FfhDYJmkpdl9VwLTJU0FAlgFXNyC/sysJP15N/4RoK/zYxeX346ZtYo/QWeWCIfdLBEOu1kiHHazRDjsZolw2M0S4bCbJcJhN0uEw26WCIfdLBEOu1kiHHazRDjsZolw2M0S0dYpmyW9Cbza666xwFtta2DvdGpvndoXuLdmldnb4RHxgb4KbQ37+3Yu9UREd2UN5OjU3jq1L3BvzWpXb34ab5YIh90sEVWHfU7F+8/Tqb11al/g3prVlt4qfc1uZu1T9ZHdzNrEYTdLRCVhl3SmpBclvSzpiip6qEfSKknLsmmoeyruZZ6k9ZKW97pvjKQHJb2UXfc5x15FvXXENN4504xX+thVPf1521+zSxoMrAA+BawGngCmR8RzbW2kDkmrgO6IqPwDGJJ+E9gM3B4Rx2X3XQdsiIhrsz+UoyPiLzqkt6uBzVVP453NVjS+9zTjwHnAH1HhY5fT1wW04XGr4sh+IvByRKyMiPeAe4BzK+ij40XEw8CGPe4+F5if3Z5P7Zel7er01hEiYm1EPJXd3gTsmma80scup6+2qCLsE4DXev28ms6a7z2AH0l6UtLMqpvpw8ERsRZqvzzAuIr72VPDabzbaY9pxjvmsWtm+vOiqgh7X1NJddL436kR8THgLOAr2dNV659+TePdLn1MM94Rmp3+vKgqwr4aOKzXz4cCayroo08RsSa7Xg/cT+dNRb1u1wy62fX6ivv5P500jXdf04zTAY9dldOfVxH2J4AjJR0haT/gc8CiCvp4H0kjsjdOkDQCOIPOm4p6ETAjuz0DeKDCXnbTKdN415tmnIofu8qnP4+Itl+As6m9I/8K8FdV9FCnr8nAz7LLs1X3BtxN7WndNmrPiC4EDgKWAC9l12M6qLc7gGXAM9SCNb6i3qZRe2n4DLA0u5xd9WOX01dbHjd/XNYsEf4EnVkiHHazRDjsZolw2M0S4bCbJcJhN0uEw26WiP8Fvji1zrt7lZQAAAAASUVORK5CYII=",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Visualizing a sample from dataset\n",
"plt.imshow(train_dataset.dataset.data[10])\n",
"plt.title(\"Label : \" + str(train_dataset.dataset.train_labels[10].item()))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "9WhgIZF0mILl"
},
"outputs": [],
"source": [
"# Creating a DataLoader for training and testing\n",
"train = DataLoader(train_dataset, batch_size=32, shuffle=True)\n",
"test = DataLoader(test_dataset, batch_size=1, shuffle=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pWhZXmm6mILm"
},
"source": [
"Define a network with the following architecture:\n",
"\n",
"Conv2d (input channels=1, output channels = 15,kernel size = 5)\n",
"$\\rightarrow$\n",
"MaxPool (kernel size = 2)\n",
"$\\rightarrow$\n",
"ReLU\n",
"$\\rightarrow$\n",
"Conv2d (input channels=15, output channels = 30,kernel size = 5)\n",
"$\\rightarrow$\n",
"Dropout2d (p = 0.5)\n",
"$\\rightarrow$\n",
"MaxPool (kernel size = 2)\n",
"$\\rightarrow$\n",
"ReLU\n",
"$\\rightarrow$\n",
"Linear(input dimension = 480, hidden units = 64)\n",
"$\\rightarrow$\n",
"ReLU\n",
"$\\rightarrow$\n",
"Dropout (p=0.5)\n",
"$\\rightarrow$\n",
"Linear(input dimension = 64, hidden units = 10)\n",
"$\\rightarrow$\n",
"LogSoftMax"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "_-RTFbeKmILn"
},
"outputs": [],
"source": [
"class CNN(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.cnn = nn.Sequential(\n",
" nn.Conv2d(1, 15, kernel_size=5),\n",
" nn.MaxPool2d(2, 2),\n",
" nn.ReLU(),\n",
" nn.Conv2d(15, 30, kernel_size=5),\n",
" nn.Dropout2d(0.5),\n",
" nn.MaxPool2d(2, 2),\n",
" nn.ReLU(),\n",
" nn.Flatten(),\n",
" nn.Linear(480, 64),\n",
" nn.ReLU(),\n",
" nn.Dropout(0.5),\n",
" nn.Linear(64, 10),\n",
" nn.LogSoftmax(dim=1),\n",
" )\n",
"\n",
" def forward(self, x):\n",
" return self.cnn(x)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jzwFIYidmILo"
},
"source": [
"Train the network you defined in the previous question on MNIST, using the optimizer and the number of training epochs you deem appropriate. Use a cross-entropy loss. Each epoch test your model on the testing dataset and print the value of the accuracy that you achieve. \n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://127.0.0.1:8080/"
},
"id": "-hdIBNSzmILp",
"outputId": "101815c2-c1c6-441b-bd05-ce08b702c044"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at ../c10/core/TensorImpl.h:1156.)\n",
" return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch :0 Loss : 0.44618519723548844 Train Accuracy:0.8599048256874084 Test Accuracy : 0.9340000152587891\n",
"Epoch :1 Loss : 0.20212347584169404 Train Accuracy:0.9430182576179504 Test Accuracy : 0.9490000009536743\n",
"Epoch :2 Loss : 0.16394136824853056 Train Accuracy:0.952215313911438 Test Accuracy : 0.9531000256538391\n",
"Epoch :3 Loss : 0.14257464571382256 Train Accuracy:0.9595929384231567 Test Accuracy : 0.9588000178337097\n",
"Epoch :4 Loss : 0.12598471959617277 Train Accuracy:0.9644113779067993 Test Accuracy : 0.9648000001907349\n",
"Epoch :5 Loss : 0.11733871379403024 Train Accuracy:0.9659308791160583 Test Accuracy : 0.9660000205039978\n",
"Epoch :6 Loss : 0.11000267015220401 Train Accuracy:0.9681901931762695 Test Accuracy : 0.9664999842643738\n",
"Epoch :7 Loss : 0.10582590816269605 Train Accuracy:0.9684301018714905 Test Accuracy : 0.9631999731063843\n",
"Epoch :8 Loss : 0.09624670598793283 Train Accuracy:0.9711892008781433 Test Accuracy : 0.9706000089645386\n",
"Epoch :9 Loss : 0.09422491162643551 Train Accuracy:0.9726687669754028 Test Accuracy : 0.9679999947547913\n"
]
}
],
"source": [
"batch_size = 32\n",
"\n",
"model = CNN()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
"\n",
"# Running the model on GPU if available\n",
"## Refer pytorch documentation for more details about copying model and data onto the device\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"model.to(device)\n",
"\n",
"cost = []\n",
"epochs = 10\n",
"\n",
"# Training the model\n",
"for epoch in range(epochs):\n",
"\n",
" loss_epoch = []\n",
" train_acc = []\n",
"\n",
" for x, y in train:\n",
"\n",
" # Predicting the output\n",
" y_pred = model(x.to(device))\n",
"\n",
" # Converting the predicted output from one hot encoding to a single number\n",
" _, t_preds = torch.max(y_pred, dim=1)\n",
"\n",
" # Calculating the training accuracy\n",
" train_acc.append(\n",
" torch.tensor(torch.sum(t_preds == y.to(device)).item() / len(t_preds))\n",
" )\n",
"\n",
" # Calculating the loss\n",
" loss = F.cross_entropy(y_pred, y.type(torch.LongTensor).to(device))\n",
"\n",
" # Backpropagation\n",
"\n",
" # Zeroing the gradients\n",
" optimizer.zero_grad()\n",
"\n",
" # Calculating the gradients\n",
" loss.backward()\n",
"\n",
" # Updating the weights\n",
" optimizer.step()\n",
"\n",
" # Appending the loss of each batch to the epoch loss\n",
" loss_epoch.append(loss.item())\n",
"\n",
" # Calculating test accuracy\n",
" with torch.no_grad():\n",
" if epoch % 1 == 0:\n",
" test_acc = []\n",
" for x, y in test:\n",
" y_pred = model(x.to(device))\n",
" _, t_preds = torch.max(y_pred, dim=1)\n",
" test_acc.append(\n",
" torch.tensor(\n",
" torch.sum(t_preds == y.to(device)).item() / len(t_preds)\n",
" )\n",
" )\n",
"\n",
" print(\n",
" \"Epoch :{} Loss : {} Train Accuracy:{} Test Accuracy : {}\".format(\n",
" epoch,\n",
" sum(loss_epoch) / len(loss_epoch),\n",
" sum(train_acc) / len(train_acc),\n",
" sum(test_acc) / len(test_acc),\n",
" )\n",
" )\n",
"\n",
" cost.append(sum(loss_epoch) / len(loss_epoch))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"colab": {
"base_uri": "https://127.0.0.1:8080/",
"height": 337
},
"id": "J7ecIEspmILq",
"outputId": "5ec70ca4-45b0-4cf3-fedc-4f57daf7a8fa"
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7fa94052aa00>]"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA3AAAAEvCAYAAAAErSPcAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAxBUlEQVR4nO3deXCc933n+c+3LwAN4gYIgAd4AeIlS7IEU6clkbREKfFG8sQztuKJZ2M7smxd2ZqtjXcytbW1Lk/NTKVmLduyZUWWEydeaxzZcuSTik5akmUR1GVSFCmS4gGSIEEQIG709ds/utFsgk2iQQJ4uhvvVxUL/TzP72l8uwqi+MH3+f1+5pwTAAAAACD/+bwuAAAAAACQGwIcAAAAABQIAhwAAAAAFAgCHAAAAAAUCAIcAAAAABQIAhwAAAAAFIiA1wVkU19f75YuXep1GQAAAADgiW3btp1wzjVMPJ+XAW7p0qXq6OjwugwAAAAA8ISZHch2nkcoAQAAAKBAEOAAAAAAoEAQ4AAAAACgQBDgAAAAAKBA5BTgzOw2M9tlZnvM7CvnGfcRM4ub2Sczzu03sz+Y2VtmxsokAAAAAHCBJl2F0sz8kh6WdIukTklbzexp59y7Wcb9N0mbs7zNeufciWmoFwAAAADmrFw6cOsk7XHO7XPORSQ9IemOLOPul/QTScensT4AAAAAQEouAW6hpEMZx52pc2lmtlDSJyQ9kuV+J+kZM9tmZndfaKEAAAAAMNflspG3ZTnnJhx/XdJfO+fiZmcNv945d8TM5kv6VzN7zzm35axvkgx3d0tSS0tLDmUBAAAAwNySSweuU9LijONFko5MGNMu6Qkz2y/pk5K+bWZ3SpJz7kjq63FJTyn5SOZZnHOPOufanXPtDQ0NU/kMM240Gtc/vLpfPYNjXpcCAAAAYA7LJcBtldRmZsvMLCTp05KezhzgnFvmnFvqnFsq6UlJX3bO/czMys2sQpLMrFzSrZK2T+snmAWdvSP6v3++Q4+9/IHXpQAAAACYwyYNcM65mKT7lFxdcqekHzvndpjZPWZ2zyS3N0p62czelvS6pF86535zsUXPttb58/TxyxboB6/uV+9QxOtyAAAAAMxR5tzE6Wzea29vdx0d+bVl3O5jA7r1/92i+ze06j/eutLrcgAAAAAUMTPb5pxrn3g+p428IV3SWKE/+lCT/v6V/To1HPW6HAAAAABzEAFuCu7f0KaBsZgef4W5cAAAAABmHwFuClY3V2rT2kY9/soH6h+lCwcAAABgdhHgpuj+DW0aGI3p71/Z73UpAAAAAOYYAtwUXbqwSh9b3ajvvfyBBujCAQAAAJhFBLgL8MDGVp0aieoHvzvgdSkAAAAA5hAC3AW4bFG11q9s0GO/3aehsZjX5QAAAACYIwhwF+iBjW3qHY7qH1+jCwcAAABgdhDgLtCHW2p04yUN+rst+zQcoQsHAAAAYOYR4C7Cgxtb1TMU0Q9fO+h1KQAAAADmAALcRbhqSa1uaK3Xd7fs00gk7nU5AAAAAIocAe4iPbCxTScGx/Sj1+nCAQAAAJhZBLiLtG5Zra5ZXqtHXtqr0ShdOAAAAAAzhwA3DR7ceImOD4zpf2495HUpAAAAAIoYAW4aXLO8VuuW1uo7L+7VWIwuHAAAAICZQYCbBmamBza2qat/VP/c0el1OQAAAACKFAFumlzfWqerltToOy/uVSSW8LocAAAAAEWIADdNxrtwh/tG9JM36MIBAAAAmH4EuGl0Y1u9Ll9crYdf2KNonC4cAAAAgOlFgJtGZqYHN7aqs3dET71x2OtyAAAAABQZAtw0W79yvj60sErfemGPYnThAAAAAEwjAtw0G58Ld/DksP7lrSNelwMAAACgiBDgZsDHVs/XmuZKunAAAAAAphUBbgaMd+E+ODGkX7xz1OtyAAAAABQJAtwMuXVNo1Y1Veibz7+veMJ5XQ4AAACAIkCAmyE+n+n+DW3a2z2kX/6BLhwAAACAi5dTgDOz28xsl5ntMbOvnGfcR8wsbmafnOq9xej2S5vUNn+evvnc+0rQhQMAAABwkSYNcGbml/SwpNslrZF0l5mtOce4/yZp81TvLVY+n+n+jW16//igfrOjy+tyAAAAABS4XDpw6yTtcc7tc85FJD0h6Y4s4+6X9BNJxy/g3qL1xx9q1oqGcn2DLhwAAACAi5RLgFso6VDGcWfqXJqZLZT0CUmPTPXeYudPzYV7r2tAz7x7zOtyAAAAABSwXAKcZTk3sZX0dUl/7ZyLX8C9yYFmd5tZh5l1dHd351BW4fj4Zc1aVp/swjlHFw4AAADAhcklwHVKWpxxvEjSkQlj2iU9YWb7JX1S0rfN7M4c75UkOecedc61O+faGxoacqu+QAT8Pt27vlXvHu3XszuPT34DAAAAAGSRS4DbKqnNzJaZWUjSpyU9nTnAObfMObfUObdU0pOSvuyc+1ku984Vd16xQC21YbpwAAAAAC7YpAHOOReTdJ+Sq0vulPRj59wOM7vHzO65kHsvvuzCE/D7dN/6Vv3h8Cm9uKu4HhEFAAAAMDssH7tB7e3trqOjw+sypl00ntD6v31R9fNK9NSXr5NZtimCAAAAAOY6M9vmnGufeD6njbwxPYKpuXBvHerTlvdPeF0OAAAAgAJDgJtlf3rlIi2oKtVDz+5mLhwAAACAKSHAzbJQwKcvrW/VGwf79OreHq/LAQAAAFBACHAe+Hfti9RUWaqHnmVFSgAAAAC5I8B5oCTg15duXqHX95/Ua/tOel0OAAAAgAJBgPPIpz6yWPMrSvSN5973uhQAAAAABYIA55HSoF9fvGmFfrevR69/QBcOAAAAwOQIcB76s3Utqp9HFw4AAABAbghwHioL+fXFG5fr5T0ntO0AXTgAAAAA50eA89hnrmlRXXlIDz23x+tSAAAAAOQ5ApzHwqGA/vLG5dqyu1tvHuz1uhwAAAAAeYwAlwf+/JolqgkH9c3n6cIBAAAAODcCXB4oLwnoCx9druffO653Ovu8LgcAAABAniLA5YnPXrtEVWVBfYO5cAAAAADOgQCXJypKg/r8Dcv07M5j2n74lNflAAAAAMhDBLg88h+uW6qK0oC++Tz7wgEAAAA4GwEuj1SVBfW565dp845j2nm03+tyAAAAAOQZAlye+dz1yzSvJKBvsSIlAAAAgAkIcHmmKhzU/3rdUv1q+1HtPjbgdTkAAAAA8ggBLg99/oZlCgf97AsHAAAA4AwEuDxUUx7SZ69bql+8c0R7jg96XQ4AAACAPEGAy1NfuGGZSgN+PfwCXTgAAAAASQS4PFU3r0SfvXaJ/uWtw9rXTRcOAAAAAAEur33ho8sVCvj08At7vS4FAAAAQB4gwOWxhooSfebqJfrZW4d1oGfI63IAAAAAeIwAl+e+eONyBXzGXDgAAAAAuQU4M7vNzHaZ2R4z+0qW63eY2Ttm9paZdZjZDRnX9pvZH8avTWfxc8H8ylLdta5FP33jsA6dHPa6HAAAAAAemjTAmZlf0sOSbpe0RtJdZrZmwrDnJF3unLtC0uckPTbh+nrn3BXOufaLL3nuueemFfKZ6dsvMhcOAAAAmMty6cCtk7THObfPOReR9ISkOzIHOOcGnXMudVguyQnTpqmqVJ/6yGI9ue2QDveNeF0OAAAAAI/kEuAWSjqUcdyZOncGM/uEmb0n6ZdKduHGOUnPmNk2M7v7Yoqdy7508wpJ0ndeZC4cAAAAMFflEuAsy7mzOmzOuaecc6sk3SnpqxmXrnfOXankI5j3mtmNWb+J2d2p+XMd3d3dOZQ1tyyoLtO/bV+sH2/t1NFTdOEAAACAuSiXANcpaXHG8SJJR8412Dm3RdIKM6tPHR9JfT0u6SklH8nMdt+jzrl251x7Q0NDjuXPLV+6aYUSzukR5sIBAAAAc1IuAW6rpDYzW2ZmIUmflvR05gAzazUzS72+UlJIUo+ZlZtZRep8uaRbJW2fzg8wlyyuDeuTVy3Sj7Ye0rH+Ua/LAQAAADDLJg1wzrmYpPskbZa0U9KPnXM7zOweM7snNexPJW03s7eUXLHyU6lFTRolvWxmb0t6XdIvnXO/mYHPMWd8+eZWxRNO331pn9elAAAAAJhldnrxyPzR3t7uOjrYMu5c/vd/fls/f/uIfvvX6zW/otTrcgAAAABMMzPblm0btpw28kZ+uW99q6LxhP5uC104AAAAYC4hwBWgpfXluvOKhfqn1w7qxOCY1+UAAAAAmCUEuAJ174ZWjcXi+rvf0oUDAAAA5goCXIFa0TBP/8vlC/SPvzugk0MRr8sBAAAAMAsIcAXsvvWtGonG9b2X6cIBAAAAcwEBroC1NVbojz7UrH949YD6hunCAQAAAMWOAFfgHtjQpsGxmB5/+QOvSwEAAAAwwwhwBW5lU4Vuv7RJ339lv06NRL0uBwAAAMAMIsAVgfs3tGlgLKa/f2W/16UAAAAAmEEEuCKwZkGlbl3TqO+9vE/9o3ThAAAAgGJFgCsSD2xsU/9oTD94db/XpQAAAACYIQS4InHpwiptXDVfj738gQbHYl6XAwAAAGAGEOCKyAMb29Q3HNUPfrff61IAAAAAzAACXBG5fHG1bl7ZoMd++4GG6MIBAAAARYcAV2Tu39Cmk0MR/fD3B7wuBQAAAMA0I8AVmauW1OijbfV6dMs+jUTiXpcDAAAAYBoR4IrQgxvbdGKQLhwAAABQbAhwRah9aa2uW1Gn727Zp9EoXTgAAACgWBDgitQDG9vUPTCmH71+0OtSAAAAAEwTAlyRumZ5na5eVqtHXtpLFw4AAAAoEgS4IvbgxjYd6x/TP3cc8roUAAAAANOAAFfErl1Rp/YlNfr2i3s1FqMLBwAAABQ6AlwRMzM9+LE2HT01qie3dXpdDgAAAICLRIArcje01uvDLdX69gt7FYklvC4HAAAAwEUgwBU5M9MDG9t0uG9ET71JFw4AAAAoZAS4OeDmSxp0+aIqfeuFPYrG6cIBAAAAhSqnAGdmt5nZLjPbY2ZfyXL9DjN7x8zeMrMOM7sh13sx88a7cIdOjuhnbx72uhwAAAAAF2jSAGdmfkkPS7pd0hpJd5nZmgnDnpN0uXPuCkmfk/TYFO7FLNiwar4uXViph1/YoxhdOAAAAKAg5dKBWydpj3Nun3MuIukJSXdkDnDODTrnXOqwXJLL9V7MDjPTAxvatL9nWE+/fcTrcgAAAABcgFwC3EJJmTtBd6bOncHMPmFm70n6pZJduJzvxey4ZU2jVjdX6lvP71E84Sa/AQAAAEBeySXAWZZzZ/3r3zn3lHNulaQ7JX11KvdKkpndnZo/19Hd3Z1DWZiqZBeuVftODOkX79CFAwAAAApNLgGuU9LijONFks75r3/n3BZJK8ysfir3Oucedc61O+faGxoacigLF2LT2iatbKzQN+nCAQAAAAUnlwC3VVKbmS0zs5CkT0t6OnOAmbWamaVeXykpJKknl3sxu3w+0/0bW7Xn+KB+vf2o1+UAAAAAmIJJA5xzLibpPkmbJe2U9GPn3A4zu8fM7kkN+1NJ283sLSVXnfyUS8p67wx8DkzB7Zc2q3X+PH3zuT1K0IUDAAAACoadXjwyf7S3t7uOjg6vyyhq//LWYT34xFv6zmeu1O0fava6HAAAAAAZzGybc6594vmcNvJG8fn4ZQu0vL5cDz33Pl04AAAAoEAQ4OYov89034ZWvdc1oGd3HvO6HAAAAAA5IMDNYX9y+QItrQvroefeVz4+SgsAAADgTAS4OSzg9+ne9a3acaRfz7933OtyAAAAAEyCADfH3fnhhVpcW6Zv0IUDAAAA8h4Bbo4L+n269+ZWvd15Si/u7va6HAAAAADnQYCD/s2Vi7SwukwPPUsXDgAAAMhnBDgoFPDpy+tX6K1DfXp5zwmvywEAAABwDgQ4SJI+edUiNVeV0oUDAAAA8hgBDpKkkoBfX755hToO9Op3e3u8LgcAAABAFgQ4pP3b9sVqrCzRQ8+973UpAAAAALIgwCGtNOjXPTet0O8/OKnX9tGFAwAAAPINAQ5nuGtdixoqSvQNunAAAABA3iHA4QylQb++eONyvbq3R1v3n/S6HAAAAAAZCHA4y2euXqL6eSG6cAAAAECeIcDhLGUhv+6+cbl++/4JvXGw1+tyAAAAAKQQ4JDVZ65eotpyunAAAABAPiHAIavykoC+8NFlenFXt94+1Od1OQAAAABEgMN5fPbapaoOB+nCAQAAAHmCAIdzmlcS0BduWKbn3juu7YdPeV0OAAAAMOcR4HBen71uqSpLA3ThAAAAgDxAgMN5VZYG9bkblumZd4/p3SP9XpcDAAAAzGkEOEzqL65fpoqSgL75PF04AAAAwEsEOEyqqiyov7h+qX69vUu7uga8LgcAAACYswhwyMnnblimeSUBfYMuHAAAAOAZAhxyUh0O6T9ct0S/+sNRvX+MLhwAAADghZwCnJndZma7zGyPmX0ly/XPmNk7qT+vmtnlGdf2m9kfzOwtM+uYzuIxuz5/w3KVBf361gt7vC4FAAAAmJMmDXBm5pf0sKTbJa2RdJeZrZkw7ANJNznnLpP0VUmPTri+3jl3hXOufRpqhkdqy0P682uX6OdvH9He7kGvywEAAADmnFw6cOsk7XHO7XPORSQ9IemOzAHOuVedc72pw9ckLZreMpEv/vKjy1US8Ovh5+nCAQAAALMtlwC3UNKhjOPO1Llz+bykX2ccO0nPmNk2M7t76iUin9TPK9G/v6ZFP3vrsPafGPK6HAAAAGBOySXAWZZzLutAs/VKBri/zjh9vXPuSiUfwbzXzG48x713m1mHmXV0d3fnUBa88pc3LlfQ79PDzIUDAAAAZlUuAa5T0uKM40WSjkwcZGaXSXpM0h3OuZ7x8865I6mvxyU9peQjmWdxzj3qnGt3zrU3NDTk/gkw6+ZXlOozVy/RT988rIM9w16XAwAAAMwZuQS4rZLazGyZmYUkfVrS05kDzKxF0k8l/blzbnfG+XIzqxh/LelWSdunq3h454s3LZffZ/r2i3ThAAAAgNkyaYBzzsUk3Sdps6Sdkn7snNthZveY2T2pYf+XpDpJ356wXUCjpJfN7G1Jr0v6pXPuN9P+KTDrGitLdddHFuvJbZ3q7KULBwAAAMwGcy7rdDZPtbe3u44OtozLd0dPjeim//6iPtm+SP/lEx/yuhwAAACgaJjZtmzbsOW0kTeQTXNVmf7dRxbpnzsO6UjfiNflAAAAAEWPAIeL8qWbWyVJj7y01+NKAAAAgOJHgMNFWVhdpk9etUhPvH5IXadGvS4HAAAAKGoEOFy0L9/cqoRzdOEAAACAGUaAw0VbXBvWv7lyoX70+kEd76cLBwAAAMwUAhymxb3rWxVLOH13yz6vSwEAAACKFgEO02JJXbnuvGKhfvj7A+oeGPO6HAAAAKAoEeAwbe5dv0KRWEKP/ZYuHAAAADATCHCYNssb5ulPLl+gH/zugHoG6cIBAAAA040Ah2l134Y2jcbieuzlD7wuBQAAACg6BDhMq9b58/TxyxboB6/uV+9QxOtyAAAAgKJCgMO0u39Dq4ajcX2PLhwAAAAwrQhwmHaXNFbojy5t1t+/ul+nhqNelwMAAAAUDQIcZsR9G1o1OBbT46/QhQMAAACmCwEOM2J1c6U2rW3U4698oFMjdOEAAACA6UCAw4x5YGObBkZj+odX93tdCgAAAFAUCHCYMWsXVOljqxv1vZc/0MAoXTgAAADgYhHgMKMe3NimUyNR/eB3B7wuBQAAACh4BDjMqA8tqtKGVfP12G/3aWgs5nU5AAAAQEEjwGHG3b+hVb3DUf3ja3ThAAAAgItBgMOM+3BLjW68pEF/t2WfhiN04QAAAIALRYDDrHhwY5t6hiL64WsHvS4FAAAAKFgEOMyKq5bU6IbWen13yz6NROJelwMAAAAUJAIcZs0DG9t0YnBMP3qdLhwAAABwIQhwmDXrltXqmuW1euSlvRqN0oUDAAAApooAh1n14MZLdHxgTP9z6yGvSwEAAAAKTk4BzsxuM7NdZrbHzL6S5fpnzOyd1J9XzezyXO/F3HLN8lqtW1qr77y4V2MxunAAAADAVEwa4MzML+lhSbdLWiPpLjNbM2HYB5Jucs5dJumrkh6dwr2YQ8xMD2xsU1f/qH7c0el1OQAAAEBByaUDt07SHufcPudcRNITku7IHOCce9U515s6fE3SolzvxdxzfWudrlpSo6/98l196Z+26V/eOqz+0ajXZQEAAAB5L5DDmIWSMicsdUq6+jzjPy/p1xd4L+YAM9PXP3WFHnlpr55595h+vb1LQb/p2hX12rS2UbesadT8ilKvywQAAADyTi4BzrKcc1kHmq1XMsDdcAH33i3pbklqaWnJoSwUssW1YX3tEx/SV++4VG8e6tMzO7q0eUeX/uap7frPP9uuq1pqtGltkzatbVJLXdjrcgEAAIC8kEuA65S0OON4kaQjEweZ2WWSHpN0u3OuZyr3SpJz7lGl5s61t7dnDXkoPj6f6aolNbpqSY2+cvsq7To2oM3bj2nzji597Vc79bVf7dSqpop0mFvdXCGzbL8XAAAAAIqfOXf+rGRmAUm7JW2UdFjSVkl/5pzbkTGmRdLzkj7rnHt1Kvdm097e7jo6Oi7oA6F4HDo5rM07uvTMjmPaeuCknJNaasPatLZRm9Y26cqWGvl8hDkAAAAUHzPb5pxrP+v8ZAEudfMfSfq6JL+kx51zXzOzeyTJOfeImT0m6U8lHUjdEhv/Ztnunez7EeAwUffAmJ7dmezMvbqnR5F4QvXzSnTLmkZtWtuo61bUKxRgW0MAAAAUh4sKcLONAIfzGRiN6oVd3dq8o0svvndcQ5G4KkoD2rBqvjatbdJNlzSovCSXp4MBAACA/ESAQ1Eajcb1yp4T2ryjS8/uPK6TQxGVBHz6aFu9bl3bpI+tblRtecjrMgEAAIApOVeAo02BglYa9Gvj6kZtXN2oWDyhjgO96Xlzz+48Lr/PtG5prTatbdSta5u0oLrM65IBAACAC0YHDkXJOafth/u1ObU9wfvHByVJly2qSq9o2Tp/nsdVAgAAANnxCCXmtH3dg9q8I7kIyluH+iRJKxrK02HuskVVbE8AAACAvEGAA1K6To3qmXeTnbnX9p1UPOHUXFWqW9c0atOlTVq3tFYBPytaAgAAwDsEOCCLvuGIntt5XJt3dOml3d0aiyVUEw5q4+rkXnMfbatXadDvdZkAAACYYwhwwCSGIzFt2d2tzTuO6bmdx9Q/GlM45NdNlzTotkubtH7VfFWWBr0uEwAAAHMAq1ACkwiHArrt0mbddmmzovGEXtvXk17R8tfbuxT0m65dUa9Naxt1y5pGza8o9bpkAAAAzDF04IBJJBJObx7q0zOpFS339wzLTLqypUa3pRZBaakLe10mAAAAigiPUALTwDmn3ccGtXlHl36zvUvvHu2XJK1qqkivaLm6uYIVLQEAAHBRCHDADDh0cjj9mOXWAyflnNRSG9amtclFUK5sqZHPR5gDAADA1BDggBl2YnBMz76b3GvulT09isQTqp9XolvWNGrT2kZdt6JeoQDbEwAAAGByBDhgFg2MRvXCrm5t3tGlF987rqFIXBWlAW1YNV+b1jbppksaVF7CGkIAAADIjgAHeGQ0Gtere09o8/Zj+tedx3RyKKJQwKcb2+p169omfWx1o2rLQ16XCQAAgDxCgAPyQCyeUMeB3vS8ucN9I/L7TOuW1mrT2kbdurZJC6rLvC4TAAAAHiPAAXnGOacdR/rTK1q+f3xQknTZoqrUipaNap1f4XGVAAAA8AIBDshz+7oHtXlHchGUtw71SZJWNJSntye4bFEV2xMAAADMEQQ4oIB0nRrVv77bpd/s6NJr+04qnnBqrirVrWuS2xOsW1argJ8VLQEAAIoVAQ4oUH3DET2387g27+jSlve7NRpNqCYc1MbVyTD30bZ6lQb9XpcJAACAaUSAA4rAcCSmLbtP6JkdXXp25zH1j8YUDvl10yUN2rS2Se1La7SwuoxHLQEAAArcuQIcG1EBBSQcCui2S5t026VNisYTem1fT3pFy19v75IkVZQEtLKpQiubKrSquVKrUq8rS4MeVw8AAICLRQcOKAKJhNMfDp/S9iOntKtrQO91Dei9o/3qH42lxyysLkuGulSgW91cqWX15Qoylw4AACDv0IEDipjPZ7p8cbUuX1ydPuecU1f/qN47mgp0Xf3a1TWg377frWg8+YubkN+n5Q3lWt1cmQ53q5oq1VhZwmOYAAAAeYgABxQpM1NzVZmaq8q0ftX89PlILKF9Jwa1q2tAO48OaFdXv17b16On3jycHlMdDmplYyrQpcLdysYKlZfwVwYAAICX+NcYMMeEAj6taqrUqqZK3XHF6fOnhqPadSzZqRt/BPMnbxzW4NiB9JiW2nDy8cumCq1sqtSq5gotrSuX30e3DgAAYDYQ4ABIkqrCQa1bVqt1y2rT55xz6uwd0XtdyU7dzq4B7eoa0HM7jymRmj5bEvDpksaKMx7BXNlUoYaKEo8+CQAAQPHKKcCZ2W2SHpLkl/SYc+6/Tri+StL3JV0p6W+cc3+bcW2/pAFJcUmxbBPxAOQnM9Pi2rAW14Z1y5rG9PnRaFx7jg+mFkxJduxe2t2tJ7d1psfUzwulHr1MdupWNVWobX6FykLsWQcAAHChJg1wZuaX9LCkWyR1StpqZk87597NGHZS0gOS7jzH26x3zp24yFoB5InSoF+XLqzSpQurzjjfMziWXgVzPNz96PWDGonGJUk+k5bWlWtVczLYJVfDrNDimrB8PIYJAAAwqVw6cOsk7XHO7ZMkM3tC0h2S0gHOOXdc0nEz++MZqRJAQaibV6LrWkt0XWt9+lw84XTw5HDyEcyjA+nFU369vUvju5iEQ35d0nh6i4PkHL0K1ZSHPPokAAAA+SmXALdQ0qGM405JV0/hezhJz5iZk/Rd59yjU7gXQIHz+0zL6su1rL5ct13anD4/HInp/WOD6Ucwd3UN6Jl3j+mJraf/ummsLNHKpsrUoinJP63z56kkwGOYAABgbsolwGV7rmkqu39f75w7YmbzJf2rmb3nnNty1jcxu1vS3ZLU0tIyhbcHUIjCoUDWveu6B8f03ninLrV33fdf7VEklpCUDITL68u1qjnZpVvZWKFVzRVaWF3G3nUAAKDo5RLgOiUtzjheJOlIrt/AOXck9fW4mT2l5COZZwW4VGfuUUlqb2+fSkAEUCTMTPMrSjW/olQ3XtKQPh+LJ7S/Zyi1vUFyjt2bB3v187dP/1VUURJId+nS4a6pQpWlQS8+CgAAwIzIJcBtldRmZsskHZb0aUl/lsubm1m5JJ9zbiD1+lZJ/8+FFgtgbgr4fWqdX6HW+RX6+GWnzw+MRrU79RjmrlS4+/nbR/TD3x9Mj1lYXZbe4mB8ft3yhnIF/T4PPgkAAMDFmTTAOediZnafpM1KbiPwuHNuh5ndk7r+iJk1SeqQVCkpYWZ/JWmNpHpJT6UeawpI+v+cc7+ZkU8CYM6pKA3qqiU1umpJTfqcc05HT42mV8McD3dbdncrltq8Lug3rWiYp9XNlWfsX9dYWcJjmAAAIK+Zc/n3tGJ7e7vr6OjwugwARSQSS2jficH0Kpi7UounHD01mh5TVRZMhbkKrWyq1MqmeVpSV6668hDBDgAAzCoz25ZtD+2cNvIGgEIXCvhS2xNU6o4rTp8/NRxNdumOnQ52T27r1FAknh5THvKrpa5cS+vCaqkLa0ltuZbUhbWkLqzmqjL52cMOAADMEgIcgDmtKhzU1cvrdPXyuvS5RMLpcN+I3j8+oAM9wzrQM5zcy+7YgJ7beVyReCI9NuT3aVFNWSrQlaulNpx+vbi2jC0PAADAtCLAAcAEPp9pcW1Yi2vDZ12LJ5y6+kd1oGcoI9wNaf+JYW3d36vBsVh6rJnUXFmqlrqwltaVn9G9a6kLs0ImAACYMgIcAEyB32daWF2mhdVlum7Fmdecczo5FNGBk8PpgHewZ1gHTg7r2Z3HdGIwcsb42vLQGR27JRmv6+cx7w4AAJyNAAcA08TMVDevRHXzSnRlS81Z1wfHYslA1zOUCnnJ7l3H/uSedomMNaXCIX863E3s3i2oZt4dAABzFQEOAGbJvJKA1iyo1JoFlWddi8QS6uxNdusOnEgGvIM9w9rbPaQXdnUrEjs97y7oNy2qCWft3i2uDas0yLw7AACKFQEOAPJAKODT8oZ5Wt4wT1p55rVEet5dar5d+tHMIb1xoFcDGfPuJKm5qvTMcJfq3rXUhVVVxrw7AAAKGQEOAPKcz2daUF2mBdVlunZF3RnXnHPqHY7qQM+QDp4c1v4TyWB3sGdYL+zqVvdA5xnja8JBtUyYb5cMeGE1VLCROQAA+Y4ABwAFzMxUWx5SbXlIH84y725oLKaDqfl243PvDvYM681DvfrFO2fOuysL+pMrZGbp3i2oLlXA75vFTwYAALIhwAFAESsvCWh1c6VWN2efd3e4b+SM7t3Bk0P64MSQXtrdrbGMeXcBn2lRTVnW7l0L8+4AAJg1BDgAmKNCAZ+W1ZdrWX35WdcSCadjA6PprRD2Z3bvDvZqYPTMeXeNlSVnbYUw3r2rCjPvDgCA6UKAAwCcxeczNVeVqbmqTNcsP3veXd9wNL3fXTLgJbt3L+3u1vGBsTPGV5UFtbQurJa6cs2vKFF1WVDV4aCqwyFVh4OqCYdUlTo3ryTAPDwAAM6DAAcAmBIzU015SDXlIV2xuPqs68OR0/Puxrt3B08O6+1DfeoZHNNQJH7O9w747HS4ywx6E0JfdVnqayoAhkN+gh8AYE4gwAEAplU4FNCqpkqtajp73p2UnHvXNxLRqeGoeoej6huOqG8k9TV17tRI8vXhvlG9e6RffSNRDZ8n+AX9pqqykGpSoS7zdWboqwkHVZU6VxMOqixI8AMAFBYCHABgVoUCPs2vKNX8itIp3Tcajat/ZPLQ1zscUWfvsHYcSb4ejSbO+Z4hvy/dycvs6mUNfWUh1ZQnv5aFWLQFAOANAhwAoCCUBv0qDfo1v3Lqwa9vOKq+VMDLDH2nO4HJcwdPDuvtzoh6h6OKxM4d/EoCvrNDX1lI1eWnz9WMdwIzzrFaJwDgYhHgAABFrTToV1OVX01VUwt+I5F4OvT1Dmc88jkh9PWNRPXBiSH1DfepbziqSPzcwa806Dsr9NWUJ4NeZugbn9s3Pq4kQPADACQR4AAAyKIs5FdZKLkSZ66ccxpJdfwmhr6J3b9TIxHt7R5U74Hk62jcnfN9y4L+cy7oUlEaUGVpUJVlp19XlQVUURpUZWlQpUEf8/wAoIgQ4AAAmCZmpnAooHAooAXVUwt+w5H46a7eeUJf33BUu48N6NRIclwsce7gJyVX9qwsC6qyNBXqygKqKEl+rSwNnj5XmhyTGQQrS4OaVxqQ30cABIB8QYADAMBjZqbykoDKSwJaVJP7fc45jUYT6h+Nqn8kqv7RmPpHoxoYjaWOT78eSF3rH4nqeP9Y+tr5VvccV1ESSIa6suwdv/Frma9PjwvwCCgATCMCHAAABcrMUo96+tU4xcVdxkXjCQ2mw11MA6PJ4NefDoGpcxnXjp4a1e7jA+lzkzQBFQr4Uh29gCpS3cDK83X+JoRDNngHgNMIcAAAzGFBvy+9MfuFcM5pKBJPh7xkZy8z8MWydgeP9I2kr42dZ8VPSfKZVFF6/o7f2aHwzG5g0O+7oM8HAPmGAAcAAC6YmWleSbJL1lx1Ye8xFotrYDQ26aOf6dejMR06OZw+HhyLyU3SBSwL+rN0/M6cG1hZGlQ45E9tWeFTacCvkqBfJQHf6XOp7SxKAz4FCIUAPECAAwAAnioJ+FUyz6/6eSUXdH8i4TQYyQh8Zzz6eTr4DWQ8Kto7FNGBnmENjEZ1aiR63lVAzyXgM5VmBLySVOg7I+gFfSoJZH7NuH7WfanXQX/6fUoyxpUG/SwoA4AABwAACpvPZ+lVMy+Ec05jsYT6R6IajsQ1GotrLJrQaDSu0VjqazR1LhZPHSc0Fkt+HT9O3nf63MBoLPUep8+NRRPn3StwMkG/pTuDyVB4ZlicGATT188IluPXzh06SzOu+QiNQF4hwAEAgDnNzNLhZTbEEy4d/s4MgWcHwWwhMR0goxlhMzWubzh6VsAci8UvqMM4LuT3ne4MZobEjEdLS7KExJKAX+GQX9XhkGrLg6otL1FtOKTaeSGVh/wsTANcIAIcAADALPL7xvcLnL3vGYsnNBY7u6s4HgRPX5sQGDNCYrawORSJqWdownukvkf8PMuThvw+1YyHuvKgasIh1aUW0xn/Oh72asMhVYdDCgWYcwhIOQY4M7tN0kOS/JIec8791wnXV0n6vqQrJf2Nc+5vc70XAAAAMyvgTy66Ul4ye7+7j8YTGo7E1TccUc9QRL1DEZ0c/zN85vHh3lM6ORRR/2jsnO9XURpQbXko+SecJeyNv079qSxl+wkUp0n/KzYzv6SHJd0iqVPSVjN72jn3bsawk5IekHTnBdwLAACAIhP0+1RV5lNVWVBL6spzuicaT6h3OKLeoeg5w97JoYiOnhrVu0f71TMUUeQc21AEfJYOdzXlQdWVlyS7fhlhL32uPKSacGjWHqMFLkYuv4ZZJ2mPc26fJJnZE5LukJQOYc6545KOm9kfT/VeAAAAQEqGvvkVpZpfkdvG9M45DUfi5w17J4ci6h2OaGdXv3qHIuobiZ5z24nykP+cnb26CR2+2nBIVWVBFnnBrMslwC2UdCjjuFPS1Tm+f873mtndku6WpJaWlhzfHgAAAHOVmam8JKDykoAW14ZzuieecOobToa6k0NRnRwaO+Nrb+qRz57BiN4/NqiTQxGNRONZ38tnUk04I9hNfLQzY/GW8S5gWYguHy5OLgEu268Vcl3KKOd7nXOPSnpUktrb2y98qSQAAADgHPw+U928EtVNYd/BkUg8a3evd8L8vr3dg+o9kHx9rjVcSoO+9KObmYu3ZC7aMt7lq0k92sn+f8iUS4DrlLQ443iRpCM5vv/F3AsAAAB4rizk18JQmRZWl+U0PpFw6h+NnifspTp+w1Ht7xlS71BUg2PZF3Axk6rKTs/dC5cEFPSZ/D5TwG8K+HwKTDj2+0wBnyUXr0ldC/pNfl/244A/dX/q3mDGcfI9sx8H/Wd+r/HX49d8JhaSmQG5BLitktrMbJmkw5I+LenPcnz/i7kXAAAAKDg+n6k6tf3B8obc7hmLxdOLt0zs7I3P7zs5GNGpkajiiYRicadYwimecIpNOI7GE6nzyePzbekw0wJZg2XGcToQ+jKuZQmIPp/8mdcyjseDZebYM49N/lSYPR1YfekgfOWSGtVPoSPrtUkDnHMuZmb3Sdqs5FYAjzvndpjZPanrj5hZk6QOSZWSEmb2V5LWOOf6s907Q58FAAAAKEglAb+aqvxqqsptAZepSCSc4u7scJcMfdmPo/HT4TCevuYUTyQyrp19HIsnMoJl8jjbteQ92Y9jqfcbisXOqiXzemy8vvjp1xeyaf0/fG6dbrokx6SdB8ydaxkeD7W3t7uOjg6vywAAAABQYBIZ4S6WcIrHnaIZQXRiGFxSF1ZFadDrss9iZtucc+0Tz8/ebo4AAAAAMMN8PlPIZwrJ53UpM6I4PxUAAAAAFCECHAAAAAAUCAIcAAAAABQIAhwAAAAAFAgCHAAAAAAUCAIcAAAAABQIAhwAAAAAFAgCHAAAAAAUCAIcAAAAABQIAhwAAAAAFAhzznldw1nMrFvSAa/ryKJe0gmviwDOg59R5Dt+RpHv+BlFvuNndO5Y4pxrmHgyLwNcvjKzDudcu9d1AOfCzyjyHT+jyHf8jCLf8TMKHqEEAAAAgAJBgAMAAACAAkGAm5pHvS4AmAQ/o8h3/Iwi3/EzinzHz+gcxxw4AAAAACgQdOAAAAAAoEAQ4HJgZreZ2S4z22NmX/G6HiCTmS02sxfMbKeZ7TCzB72uCcjGzPxm9qaZ/cLrWoCJzKzazJ40s/dSf59e63VNQCYz+99S/5/fbmY/MrNSr2uCNwhwkzAzv6SHJd0uaY2ku8xsjbdVAWeISfqPzrnVkq6RdC8/o8hTD0ra6XURwDk8JOk3zrlVki4XP6vII2a2UNIDktqdc5dK8kv6tLdVwSsEuMmtk7THObfPOReR9ISkOzyuCUhzzh11zr2Rej2g5D86FnpbFXAmM1sk6Y8lPeZ1LcBEZlYp6UZJ35Mk51zEOdfnaVHA2QKSyswsICks6YjH9cAjBLjJLZR0KOO4U/zjGHnKzJZK+rCk33tcCjDR1yX9H5ISHtcBZLNcUrek76ce833MzMq9LgoY55w7LOlvJR2UdFTSKefcM95WBa8Q4CZnWc6xdCfyjpnNk/QTSX/lnOv3uh5gnJl9XNJx59w2r2sBziEg6UpJ33HOfVjSkCTmvCNvmFmNkk+ALZO0QFK5mf17b6uCVwhwk+uUtDjjeJFoWSPPmFlQyfD2Q+fcT72uB5jgekl/Ymb7lXwMfYOZ/ZO3JQFn6JTU6Zwbf3rhSSUDHZAvPibpA+dct3MuKumnkq7zuCZ4hAA3ua2S2sxsmZmFlJww+rTHNQFpZmZKztvY6Zz7H17XA0zknPs/nXOLnHNLlfw79HnnHL85Rt5wznVJOmRmK1OnNkp618OSgIkOSrrGzMKp/+9vFAvtzFkBrwvId865mJndJ2mzkiv+PO6c2+FxWUCm6yX9uaQ/mNlbqXP/yTn3K+9KAoCCc7+kH6Z+WbtP0l94XA+Q5pz7vZk9KekNJVefflPSo95WBa+Yc0znAgAAAIBCwCOUAAAAAFAgCHAAAAAAUCAIcAAAAABQIAhwAAAAAFAgCHAAAAAAUCAIcAAAAABQIAhwAAAAAFAgCHAAAAAAUCD+f8y14DV/AZWvAAAAAElFTkSuQmCC",
"text/plain": [
"<Figure size 1080x360 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"### Plotting the cost vs epochs\n",
"fig, ax = plt.subplots(figsize=(15, 5))\n",
"plt.plot(np.array(cost))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "Assignment1-Step2_Harshit_Agarwal.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"037faad73a8a40499555406f4e586731": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"0554687d97424798a86ea0a4c56cdbf8": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"14d009d293864f9f85bd16b5e1b6b381": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_c75452bfc2264392bf3d842bfbd4eeee",
"placeholder": "",
"style": "IPY_MODEL_0554687d97424798a86ea0a4c56cdbf8",
"value": " 29696/? [00:00<00:00, 395418.47it/s]"
}
},
"1a89067cd68c41fb96b74e0ebf3d1931": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_91798827c9254ffc969228862c8ee37f",
"IPY_MODEL_cdbc8de66e8c4fa79100dcac1fcfed4d",
"IPY_MODEL_2773f678ca4a45158df219d2672fa646"
],
"layout": "IPY_MODEL_9f7585b1023d4e20ba3649efcfcdb881"
}
},
"2773f678ca4a45158df219d2672fa646": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_480e369fb79f4326bd13d2355fddd890",
"placeholder": "",
"style": "IPY_MODEL_898cdc388a2d428ca742a9ba2365df20",
"value": " 5120/? [00:00<00:00, 57213.75it/s]"
}
},
"2ee0648f054c49049fdf5d6ac81ec086": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_4be171eff56046c2a401bf02b6d704c9",
"IPY_MODEL_c8c6df38201b470bb06bffc677873f93",
"IPY_MODEL_f51f06fbe28746dea09e075256e29451"
],
"layout": "IPY_MODEL_7c90cc0c4c3b4cd89dbdfffdb78d463d"
}
},
"3d3af4241b714b9a87253ac87b2e31b1": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"44f9d2dde4424058a337a1e3d585b5fe": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"47b04fa0a67c4fbbbcdc3d1062430259": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"480e369fb79f4326bd13d2355fddd890": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"4b80b4415ce34130a048933b2df179a5": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"4be171eff56046c2a401bf02b6d704c9": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_9ee4f57a58c14bb1880789196ff7ef63",
"placeholder": "",
"style": "IPY_MODEL_74d1d431fbb84c4f948ca510e397f4e2",
"value": ""
}
},
"51567646aecc4126a3b9cb96f97d5be5": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_037faad73a8a40499555406f4e586731",
"placeholder": "",
"style": "IPY_MODEL_4b80b4415ce34130a048933b2df179a5",
"value": " 9913344/? [00:00<00:00, 33092818.93it/s]"
}
},
"54ce0f87fc1f46d2b9f9850a393b46cf": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"553cf4ae8edb4b48a58ab4446036ae87": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"5ecbc7b6eb6544709da61283dfc8d3c6": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"6e770d6f5e5f4101b2c985f8182eb79e": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_af12215b0ba24c699e1111c132dee3fe",
"placeholder": "",
"style": "IPY_MODEL_908f60d17e1b49869ac3d4ee53a74a6e",
"value": ""
}
},
"74d1d431fbb84c4f948ca510e397f4e2": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"780790f42e764ec5a0b8d441a290e5d3": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"7837b9def9dd470cb57f96b38afbfc18": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"7c8cf011ed684ef0841d41da76c9bfc8": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_d2e31ba6e42448839959f523cc56dbcb",
"max": 28881,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_5ecbc7b6eb6544709da61283dfc8d3c6",
"value": 28881
}
},
"7c90cc0c4c3b4cd89dbdfffdb78d463d": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"84bda6e8c38b4b669e2825c962e9dbfb": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"898cdc388a2d428ca742a9ba2365df20": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"908f60d17e1b49869ac3d4ee53a74a6e": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"90f550d2f7344eb894208d0373ac6f8d": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"91798827c9254ffc969228862c8ee37f": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_3d3af4241b714b9a87253ac87b2e31b1",
"placeholder": "",
"style": "IPY_MODEL_84bda6e8c38b4b669e2825c962e9dbfb",
"value": ""
}
},
"9ee4f57a58c14bb1880789196ff7ef63": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"9f7585b1023d4e20ba3649efcfcdb881": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"a1dd0d1a7fb14235a271992bea0d233c": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"af12215b0ba24c699e1111c132dee3fe": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"b3105726704240439adcf7f13bd48cca": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"b9fa1995789a49fdb6a0c3e8505733de": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"c52318a206de4565a725388bc50bdfc0": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_d054c2e9330c433ab6bb900c6fa7dac6",
"IPY_MODEL_7c8cf011ed684ef0841d41da76c9bfc8",
"IPY_MODEL_14d009d293864f9f85bd16b5e1b6b381"
],
"layout": "IPY_MODEL_553cf4ae8edb4b48a58ab4446036ae87"
}
},
"c75452bfc2264392bf3d842bfbd4eeee": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"c8c6df38201b470bb06bffc677873f93": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_90f550d2f7344eb894208d0373ac6f8d",
"max": 1648877,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_7837b9def9dd470cb57f96b38afbfc18",
"value": 1648877
}
},
"cdbc8de66e8c4fa79100dcac1fcfed4d": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_780790f42e764ec5a0b8d441a290e5d3",
"max": 4542,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_47b04fa0a67c4fbbbcdc3d1062430259",
"value": 4542
}
},
"d054c2e9330c433ab6bb900c6fa7dac6": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_a1dd0d1a7fb14235a271992bea0d233c",
"placeholder": "",
"style": "IPY_MODEL_b9fa1995789a49fdb6a0c3e8505733de",
"value": ""
}
},
"d2e31ba6e42448839959f523cc56dbcb": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"d784c3bdbaf543a299447b17b500c2a8": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_f6af65953a9841e1907a47fb27c01b96",
"max": 9912422,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_44f9d2dde4424058a337a1e3d585b5fe",
"value": 9912422
}
},
"f1ae475d1e48411ab7cb9f49f6f673c9": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_6e770d6f5e5f4101b2c985f8182eb79e",
"IPY_MODEL_d784c3bdbaf543a299447b17b500c2a8",
"IPY_MODEL_51567646aecc4126a3b9cb96f97d5be5"
],
"layout": "IPY_MODEL_f65a4ec103c34c7186fb76d3e3507795"
}
},
"f51f06fbe28746dea09e075256e29451": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_b3105726704240439adcf7f13bd48cca",
"placeholder": "",
"style": "IPY_MODEL_54ce0f87fc1f46d2b9f9850a393b46cf",
"value": " 1649664/? [00:00<00:00, 2122626.57it/s]"
}
},
"f65a4ec103c34c7186fb76d3e3507795": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"f6af65953a9841e1907a47fb27c01b96": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
}
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
CNN : https://en.wikipedia.org/wiki/Convolutional_neural_network
Pytorch : https://pytorch.ac.cn/tutorials/beginner/basics/intro.html
首先使用 DataLoader 和 torchvision 提供的数据集导入训练和测试 MNIST 数据集。您可以将训练和测试批次大小设置为您认为最合适的任何值。
import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.dataloader import DataLoader
import matplotlib.pyplot as plt
import numpy as np
# Downloading MNIST dataset from Pytorch
dataset = datasets.MNIST(
root="./data",
download=True,
train=True,
transform=transforms.ToTensor(),
)
# Splitting the dataset into training and testing set
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [50000, 10000])
# Visualizing a sample from dataset
plt.imshow(train_dataset.dataset.data[10])
plt.title("Label : " + str(train_dataset.dataset.train_labels[10].item()))
/opt/anaconda3/lib/python3.8/site-packages/torchvision/datasets/mnist.py:52: UserWarning: train_labels has been renamed targets
warnings.warn("train_labels has been renamed targets")
Text(0.5, 1.0, 'Label : 3')
# Creating a DataLoader for training and testing
train = DataLoader(train_dataset, batch_size=32, shuffle=True)
test = DataLoader(test_dataset, batch_size=1, shuffle=True)
定义具有以下架构的网络
Conv2d (输入通道=1,输出通道 = 15,内核大小 = 5) → MaxPool (内核大小 = 2) → ReLU → Conv2d (输入通道=15,输出通道 = 30,内核大小 = 5) → Dropout2d (p = 0.5) → MaxPool (内核大小 = 2) → ReLU → Linear(输入维度 = 480,隐藏单元 = 64) → ReLU → Dropout (p=0.5) → Linear(输入维度 = 64,隐藏单元 = 10) → LogSoftMax
class CNN(nn.Module):
def __init__(self):
super().__init__()
self.cnn = nn.Sequential(
nn.Conv2d(1, 15, kernel_size=5),
nn.MaxPool2d(2, 2),
nn.ReLU(),
nn.Conv2d(15, 30, kernel_size=5),
nn.Dropout2d(0.5),
nn.MaxPool2d(2, 2),
nn.ReLU(),
nn.Flatten(),
nn.Linear(480, 64),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(64, 10),
nn.LogSoftmax(dim=1),
)
def forward(self, x):
return self.cnn(x)
在 MNIST 上训练您在上一个问题中定义的网络,使用您认为合适的优化器和训练轮数。使用交叉熵损失。每个 epoch 在测试数据集上测试您的模型,并打印您获得的准确率值。
batch_size = 32
model = CNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Running the model on GPU if available
## Refer pytorch documentation for more details about copying model and data onto the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
cost = []
epochs = 10
# Training the model
for epoch in range(epochs):
loss_epoch = []
train_acc = []
for x, y in train:
# Predicting the output
y_pred = model(x.to(device))
# Converting the predicted output from one hot encoding to a single number
_, t_preds = torch.max(y_pred, dim=1)
# Calculating the training accuracy
train_acc.append(
torch.tensor(torch.sum(t_preds == y.to(device)).item() / len(t_preds))
)
# Calculating the loss
loss = F.cross_entropy(y_pred, y.type(torch.LongTensor).to(device))
# Backpropagation
# Zeroing the gradients
optimizer.zero_grad()
# Calculating the gradients
loss.backward()
# Updating the weights
optimizer.step()
# Appending the loss of each batch to the epoch loss
loss_epoch.append(loss.item())
# Calculating test accuracy
with torch.no_grad():
if epoch % 1 == 0:
test_acc = []
for x, y in test:
y_pred = model(x.to(device))
_, t_preds = torch.max(y_pred, dim=1)
test_acc.append(
torch.tensor(
torch.sum(t_preds == y.to(device)).item() / len(t_preds)
)
)
print(
"Epoch :{} Loss : {} Train Accuracy:{} Test Accuracy : {}".format(
epoch,
sum(loss_epoch) / len(loss_epoch),
sum(train_acc) / len(train_acc),
sum(test_acc) / len(test_acc),
)
)
cost.append(sum(loss_epoch) / len(loss_epoch))
/opt/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at ../c10/core/TensorImpl.h:1156.)
return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
Epoch :0 Loss : 0.44618519723548844 Train Accuracy:0.8599048256874084 Test Accuracy : 0.9340000152587891
Epoch :1 Loss : 0.20212347584169404 Train Accuracy:0.9430182576179504 Test Accuracy : 0.9490000009536743
Epoch :2 Loss : 0.16394136824853056 Train Accuracy:0.952215313911438 Test Accuracy : 0.9531000256538391
Epoch :3 Loss : 0.14257464571382256 Train Accuracy:0.9595929384231567 Test Accuracy : 0.9588000178337097
Epoch :4 Loss : 0.12598471959617277 Train Accuracy:0.9644113779067993 Test Accuracy : 0.9648000001907349
Epoch :5 Loss : 0.11733871379403024 Train Accuracy:0.9659308791160583 Test Accuracy : 0.9660000205039978
Epoch :6 Loss : 0.11000267015220401 Train Accuracy:0.9681901931762695 Test Accuracy : 0.9664999842643738
Epoch :7 Loss : 0.10582590816269605 Train Accuracy:0.9684301018714905 Test Accuracy : 0.9631999731063843
Epoch :8 Loss : 0.09624670598793283 Train Accuracy:0.9711892008781433 Test Accuracy : 0.9706000089645386
Epoch :9 Loss : 0.09422491162643551 Train Accuracy:0.9726687669754028 Test Accuracy : 0.9679999947547913
### Plotting the cost vs epochs
fig, ax = plt.subplots(figsize=(15, 5))
plt.plot(np.array(cost))
[<matplotlib.lines.Line2D at 0x7fa94052aa00>]