{
"nbformat": 4,
"nbformat_minor": 5,
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.10"
},
"papermill": {
"default_parameters": {},
"duration": 21629.871921,
"end_time": "2021-08-01T03:02:57.256605",
"environment_variables": {},
"exception": null,
"input_path": "__notebook__.ipynb",
"output_path": "__notebook__.ipynb",
"parameters": {},
"start_time": "2021-07-31T21:02:27.384684",
"version": "2.3.3"
},
"colab": {
"name": "text classification using BERT.ipynb",
"provenance": []
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "Lk52n6Tt0uJO"
},
"source": [
"#Text Classification using BERT\n",
"\n",
"The is a basic implementation of text classification pipeline using BERT. The BERT model used has been taken from [huggingface](https://hugging-face.cn/transformers/). The dataset used is a custom dataset with two classes (labelled as 0 and 1). It is publically available [here](https://raw.githubusercontent.com/prateekjoshi565/Fine-Tuning-BERT/master/spamdata_v2.csv)."
],
"id": "Lk52n6Tt0uJO"
},
{
"cell_type": "code",
"metadata": {
"execution": {
"iopub.execute_input": "2021-07-31T21:02:35.719217Z",
"iopub.status.busy": "2021-07-31T21:02:35.717605Z",
"iopub.status.idle": "2021-07-31T21:02:35.722099Z",
"shell.execute_reply": "2021-07-31T21:02:35.722966Z",
"shell.execute_reply.started": "2021-07-31T08:17:49.337897Z"
},
"papermill": {
"duration": 0.024617,
"end_time": "2021-07-31T21:02:35.723409",
"exception": false,
"start_time": "2021-07-31T21:02:35.698792",
"status": "completed"
},
"tags": [],
"colab": {
"base_uri": "https://127.0.0.1:8080/"
},
"id": "788d3220",
"outputId": "54b8c951-4479-4581-9f17-d3f7e23bd477"
},
"source": [
"!pip install transformers"
],
"id": "788d3220",
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting transformers\n",
" Downloading transformers-4.10.0-py3-none-any.whl (2.8 MB)\n",
"\u001b[K |████████████████████████████████| 2.8 MB 4.0 MB/s \n",
"\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers) (3.0.12)\n",
"Collecting huggingface-hub>=0.0.12\n",
" Downloading huggingface_hub-0.0.16-py3-none-any.whl (50 kB)\n",
"\u001b[K |████████████████████████████████| 50 kB 6.8 MB/s \n",
"\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (1.19.5)\n",
"Collecting pyyaml>=5.1\n",
" Downloading PyYAML-5.4.1-cp37-cp37m-manylinux1_x86_64.whl (636 kB)\n",
"\u001b[K |████████████████████████████████| 636 kB 51.9 MB/s \n",
"\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from transformers) (21.0)\n",
"Collecting tokenizers<0.11,>=0.10.1\n",
" Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)\n",
"\u001b[K |████████████████████████████████| 3.3 MB 21.4 MB/s \n",
"\u001b[?25hRequirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers) (4.62.0)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (2019.12.20)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers) (2.23.0)\n",
"Collecting sacremoses\n",
" Downloading sacremoses-0.0.45-py3-none-any.whl (895 kB)\n",
"\u001b[K |████████████████████████████████| 895 kB 45.3 MB/s \n",
"\u001b[?25hRequirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers) (4.6.4)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from huggingface-hub>=0.0.12->transformers) (3.7.4.3)\n",
"Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->transformers) (2.4.7)\n",
"Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers) (3.5.0)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (3.0.4)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2021.5.30)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (1.24.3)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2.10)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (1.15.0)\n",
"Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (7.1.2)\n",
"Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (1.0.1)\n",
"Installing collected packages: tokenizers, sacremoses, pyyaml, huggingface-hub, transformers\n",
" Attempting uninstall: pyyaml\n",
" Found existing installation: PyYAML 3.13\n",
" Uninstalling PyYAML-3.13:\n",
" Successfully uninstalled PyYAML-3.13\n",
"Successfully installed huggingface-hub-0.0.16 pyyaml-5.4.1 sacremoses-0.0.45 tokenizers-0.10.3 transformers-4.10.0\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"execution": {
"iopub.execute_input": "2021-07-31T21:05:44.436123Z",
"iopub.status.busy": "2021-07-31T21:05:44.435351Z",
"iopub.status.idle": "2021-07-31T21:06:30.493736Z",
"shell.execute_reply": "2021-07-31T21:06:30.492714Z",
"shell.execute_reply.started": "2021-07-31T08:21:08.023057Z"
},
"papermill": {
"duration": 46.206213,
"end_time": "2021-07-31T21:06:30.494123",
"exception": false,
"start_time": "2021-07-31T21:05:44.287910",
"status": "completed"
},
"tags": [],
"id": "c90ffee0"
},
"source": [
"import csv\n",
"import pickle\n",
"import pandas as pd\n",
"import numpy as np\n",
"train = pd.read_csv(\"spamdata_v2.csv\")\n"
],
"id": "c90ffee0",
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"execution": {
"iopub.execute_input": "2021-07-31T21:06:30.792497Z",
"iopub.status.busy": "2021-07-31T21:06:30.791325Z",
"iopub.status.idle": "2021-07-31T21:06:31.512114Z",
"shell.execute_reply": "2021-07-31T21:06:31.512646Z",
"shell.execute_reply.started": "2021-07-31T08:21:53.599755Z"
},
"papermill": {
"duration": 0.877222,
"end_time": "2021-07-31T21:06:31.512821",
"exception": false,
"start_time": "2021-07-31T21:06:30.635599",
"status": "completed"
},
"tags": [],
"colab": {
"base_uri": "https://127.0.0.1:8080/",
"height": 219
},
"id": "fdb5210b",
"outputId": "421a3979-dfdd-4a0f-d331-4b328ba545f3"
},
"source": [
"print(len(train))\n",
"train.head()"
],
"id": "fdb5210b",
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"5572\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>label</th>\n",
" <th>text</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>Go until jurong point, crazy.. Available only ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0</td>\n",
" <td>Ok lar... Joking wif u oni...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>Free entry in 2 a wkly comp to win FA Cup fina...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0</td>\n",
" <td>U dun say so early hor... U c already then say...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>Nah I don't think he goes to usf, he lives aro...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" label text\n",
"0 0 Go until jurong point, crazy.. Available only ...\n",
"1 0 Ok lar... Joking wif u oni...\n",
"2 1 Free entry in 2 a wkly comp to win FA Cup fina...\n",
"3 0 U dun say so early hor... U c already then say...\n",
"4 0 Nah I don't think he goes to usf, he lives aro..."
]
},
"metadata": {},
"execution_count": 2
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "mxdJUOsolsMj"
},
"source": [
"num_classes = 2"
],
"id": "mxdJUOsolsMj",
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"execution": {
"iopub.execute_input": "2021-07-31T21:06:31.796699Z",
"iopub.status.busy": "2021-07-31T21:06:31.795920Z",
"iopub.status.idle": "2021-07-31T21:06:39.859089Z",
"shell.execute_reply": "2021-07-31T21:06:39.859561Z",
"shell.execute_reply.started": "2021-07-31T08:21:54.413397Z"
},
"papermill": {
"duration": 8.208197,
"end_time": "2021-07-31T21:06:39.859775",
"exception": false,
"start_time": "2021-07-31T21:06:31.651578",
"status": "completed"
},
"tags": [],
"id": "eb38df80"
},
"source": [
"from sklearn.model_selection import train_test_split\n",
"train_split, val_split = train_test_split(train, test_size=.05)"
],
"id": "eb38df80",
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"execution": {
"iopub.execute_input": "2021-07-31T21:06:40.148699Z",
"iopub.status.busy": "2021-07-31T21:06:40.147996Z",
"iopub.status.idle": "2021-07-31T21:06:40.188820Z",
"shell.execute_reply": "2021-07-31T21:06:40.188250Z",
"shell.execute_reply.started": "2021-07-31T08:22:01.764055Z"
},
"papermill": {
"duration": 0.187948,
"end_time": "2021-07-31T21:06:40.188979",
"exception": false,
"start_time": "2021-07-31T21:06:40.001031",
"status": "completed"
},
"tags": [],
"id": "a9bbf53d"
},
"source": [
"from transformers import BertTokenizerFast\n",
"tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')"
],
"id": "a9bbf53d",
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"execution": {
"iopub.execute_input": "2021-07-31T21:06:40.486360Z",
"iopub.status.busy": "2021-07-31T21:06:40.485678Z",
"iopub.status.idle": "2021-07-31T21:06:51.481509Z",
"shell.execute_reply": "2021-07-31T21:06:51.480780Z",
"shell.execute_reply.started": "2021-07-31T08:22:01.807511Z"
},
"papermill": {
"duration": 11.150744,
"end_time": "2021-07-31T21:06:51.481655",
"exception": false,
"start_time": "2021-07-31T21:06:40.330911",
"status": "completed"
},
"tags": [],
"id": "f91e37b7"
},
"source": [
"import torch\n",
"\n",
"class Dataset(torch.utils.data.Dataset):\n",
" def __init__(self, df, tokenizer, max_length=128):\n",
" self.df = df\n",
" self.text = df.text.values\n",
" self.labels = df.label.values\n",
" self.tokenizer = tokenizer\n",
" self.max_length = max_length\n",
" \n",
" def __getitem__(self, idx):\n",
" \n",
" tokenized_data = tokenizer.tokenize(self.text[idx])\n",
" to_append = [\"[CLS]\"] + tokenized_data[:self.max_length - 2] + [\"[SEP]\"]\n",
" input_ids = tokenizer.convert_tokens_to_ids(to_append)\n",
" input_mask = [1] * len(input_ids)\n",
" padding = [0] * (self.max_length - len(input_ids))\n",
" input_ids += padding\n",
" input_mask += padding\n",
" item = {\n",
" \"input_ids\": torch.tensor(input_ids, dtype=torch.long),\n",
" \"attention_mask\": torch.tensor(input_mask, dtype=torch.long)\n",
" }\n",
" item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)\n",
" return item\n",
" \n",
" def __len__(self):\n",
" return len(self.df)\n",
"\n",
"train_dataset = Dataset(train_split.fillna(\"\"), tokenizer)\n",
"val_dataset = Dataset(val_split.fillna(\"\"), tokenizer)\n",
"# train_dataset = Dataset(train.fillna(\"\"), tokenizer, is_train=True, label_map=label_map)\n",
"# test_dataset = Dataset(test.fillna(\"\"), tokenizer, is_train=False)"
],
"id": "f91e37b7",
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"execution": {
"iopub.execute_input": "2021-07-31T21:06:51.772366Z",
"iopub.status.busy": "2021-07-31T21:06:51.771623Z",
"iopub.status.idle": "2021-07-31T21:07:23.576720Z",
"shell.execute_reply": "2021-07-31T21:07:23.576103Z",
"shell.execute_reply.started": "2021-07-31T08:22:12.696992Z"
},
"papermill": {
"duration": 31.953274,
"end_time": "2021-07-31T21:07:23.576885",
"exception": false,
"start_time": "2021-07-31T21:06:51.623611",
"status": "completed"
},
"tags": [],
"colab": {
"base_uri": "https://127.0.0.1:8080/"
},
"id": "6186779d",
"outputId": "9e0874d1-9977-4b48-9c98-3d19e5d4ccd0"
},
"source": [
"from transformers import BertForSequenceClassification, Trainer, TrainingArguments\n",
"\n",
"training_args = TrainingArguments(\n",
" output_dir='./results', # output directory\n",
" num_train_epochs=50, # total number of training epochs\n",
" per_device_train_batch_size=64, # batch size per device during training\n",
" per_device_eval_batch_size=64, # batch size for evaluation\n",
" warmup_steps=500, # number of warmup steps for learning rate scheduler\n",
" weight_decay=0.01, # strength of weight decay\n",
" logging_dir='./logs', # directory for storing logs\n",
" logging_steps=100,\n",
" dataloader_num_workers=2,\n",
" report_to=\"tensorboard\",\n",
" label_smoothing_factor=0.1,\n",
" evaluation_strategy=\"steps\",\n",
" eval_steps=500, # Evaluation and Save happens every 500 steps\n",
" save_total_limit=3, # Only last 5 models are saved. Older ones are deleted.\n",
" load_best_model_at_end=True, #best model is always saved\n",
")\n",
"\n",
"model = BertForSequenceClassification.from_pretrained(\"bert-base-uncased\")\n",
"model.classifier = torch.nn.Linear(768, num_classes)\n",
"model.num_labels = num_classes"
],
"id": "6186779d",
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']\n",
"- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"execution": {
"iopub.execute_input": "2021-07-31T21:07:23.873644Z",
"iopub.status.busy": "2021-07-31T21:07:23.872942Z",
"iopub.status.idle": "2021-08-01T03:02:53.329971Z",
"shell.execute_reply": "2021-08-01T03:02:53.330600Z"
},
"papermill": {
"duration": 21329.605812,
"end_time": "2021-08-01T03:02:53.332339",
"exception": false,
"start_time": "2021-07-31T21:07:23.726527",
"status": "completed"
},
"tags": [],
"colab": {
"base_uri": "https://127.0.0.1:8080/",
"height": 675
},
"id": "13abf542",
"outputId": "42a52c37-29f9-4f43-a70f-c285886f0c10"
},
"source": [
"trainer = Trainer(\n",
" model=model, # the instantiated 🤗 Transformers model to be trained\n",
" args=training_args, # training arguments, defined above\n",
" train_dataset=train_dataset, # training dataset\n",
" eval_dataset=val_dataset # evaluation dataset\n",
")\n",
"trainer.train()"
],
"id": "13abf542",
"execution_count": null,
"outputs": [
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"***** Running training *****\n",
" Num examples = 5293\n",
" Num Epochs = 50\n",
" Instantaneous batch size per device = 64\n",
" Total train batch size (w. parallel, distributed & accumulation) = 64\n",
" Gradient Accumulation steps = 1\n",
" Total optimization steps = 4150\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='408' max='4150' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [ 408/4150 16:23 < 2:31:08, 0.41 it/s, Epoch 4.90/50]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Step</th>\n",
" <th>Training Loss</th>\n",
" <th>Validation Loss</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='1507' max='4150' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [1507/4150 1:01:13 < 1:47:31, 0.41 it/s, Epoch 18.14/50]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Step</th>\n",
" <th>Training Loss</th>\n",
" <th>Validation Loss</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>500</td>\n",
" <td>0.243900</td>\n",
" <td>0.221590</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1000</td>\n",
" <td>0.200800</td>\n",
" <td>0.217612</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1500</td>\n",
" <td>0.198900</td>\n",
" <td>0.218175</td>\n",
" </tr>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"***** Running Evaluation *****\n",
" Num examples = 279\n",
" Batch size = 64\n",
"Saving model checkpoint to ./results/checkpoint-500\n",
"Configuration saved in ./results/checkpoint-500/config.json\n",
"Model weights saved in ./results/checkpoint-500/pytorch_model.bin\n",
"***** Running Evaluation *****\n",
" Num examples = 279\n",
" Batch size = 64\n",
"Saving model checkpoint to ./results/checkpoint-1000\n",
"Configuration saved in ./results/checkpoint-1000/config.json\n",
"Model weights saved in ./results/checkpoint-1000/pytorch_model.bin\n",
"***** Running Evaluation *****\n",
" Num examples = 279\n",
" Batch size = 64\n",
"Saving model checkpoint to ./results/checkpoint-1500\n",
"Configuration saved in ./results/checkpoint-1500/config.json\n",
"Model weights saved in ./results/checkpoint-1500/pytorch_model.bin\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "lsoxR09ymdHv"
},
"source": [
""
],
"id": "lsoxR09ymdHv",
"execution_count": null,
"outputs": []
}
]
}