
Text Classification using BERT


# Text Classification using BERT

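This is a basic implementation of a text classification pipeline using BERT. The pretrained `bert-base-uncased` model is taken from [huggingface](https://hugging-face.cn/transformers/) and fine-tuned with the `Trainer` API. The dataset is a custom dataset with two classes (labelled 0 and 1). It is publicly available [here](https://raw.githubusercontent.com/prateekjoshi565/Fine-Tuning-BERT/master/spamdata_v2.csv).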

```python
!pip install transformers
```

```
Collecting transformers
  Downloading transformers-4.10.0-py3-none-any.whl (2.8 MB)
     |████████████████████████████████| 2.8 MB 4.0 MB/s
Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers) (3.0.12)
Collecting huggingface-hub>=0.0.12
  Downloading huggingface_hub-0.0.16-py3-none-any.whl (50 kB)
     |████████████████████████████████| 50 kB 6.8 MB/s
Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (1.19.5)
Collecting pyyaml>=5.1
  Downloading PyYAML-5.4.1-cp37-cp37m-manylinux1_x86_64.whl (636 kB)
     |████████████████████████████████| 636 kB 51.9 MB/s
Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from transformers) (21.0)
Collecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
     |████████████████████████████████| 3.3 MB 21.4 MB/s
Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers) (4.62.0)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (2019.12.20)
Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers) (2.23.0)
Collecting sacremoses
  Downloading sacremoses-0.0.45-py3-none-any.whl (895 kB)
     |████████████████████████████████| 895 kB 45.3 MB/s
Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers) (4.6.4)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from huggingface-hub>=0.0.12->transformers) (3.7.4.3)
Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->transformers) (2.4.7)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers) (3.5.0)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (3.0.4)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2021.5.30)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (1.24.3)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2.10)
Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (1.15.0)
Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (7.1.2)
Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (1.0.1)
Installing collected packages: tokenizers, sacremoses, pyyaml, huggingface-hub, transformers
  Attempting uninstall: pyyaml
    Found existing installation: PyYAML 3.13
    Uninstalling PyYAML-3.13:
      Successfully uninstalled PyYAML-3.13
Successfully installed huggingface-hub-0.0.16 pyyaml-5.4.1 sacremoses-0.0.45 tokenizers-0.10.3 transformers-4.10.0
```
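```python
import pandas as pd

# Read the CSV from the public URL given above
# (or pass a local path such as "spamdata_v2.csv" instead).
DATA_URL = "https://raw.githubusercontent.com/prateekjoshi565/Fine-Tuning-BERT/master/spamdata_v2.csv"
train = pd.read_csv(DATA_URL)
```

```python
print(len(train))
train.head()
```

```
5572
   label                                               text
0      0  Go until jurong point, crazy.. Available only ...
1      0                      Ok lar... Joking wif u oni...
2      1  Free entry in 2 a wkly comp to win FA Cup fina...
3      0  U dun say so early hor... U c already then say...
4      0  Nah I don't think he goes to usf, he lives aro...
```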
```python
num_classes = 2
```

```python
from sklearn.model_selection import train_test_split

# Hold out 5% of the rows for validation.
train_split, val_split = train_test_split(train, test_size=0.05)
```

```python
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
```
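You can check the split sizes by hand. The sketch below assumes scikit-learn's rule for a float `test_size`: the test share is rounded up and the remainder goes to training. Under that rule, 5572 rows yield the 5293 training and 279 validation examples shown in the training and evaluation logs. `split_sizes` is a hypothetical helper, not part of scikit-learn.

```python
import math

def split_sizes(n_samples, test_size):
    """Train/validation sizes for a float test_size (hypothetical helper).

    Assumes scikit-learn's convention: the test share is rounded up
    (ceil) and the training set receives the remaining rows.
    """
    n_test = math.ceil(test_size * n_samples)
    n_train = n_samples - n_test
    return n_train, n_test

n_train, n_test = split_sizes(5572, 0.05)
print(n_train, n_test)  # 5293 279
```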
```python
import torch

class Dataset(torch.utils.data.Dataset):
    """Wraps a DataFrame with `text` and `label` columns for the Trainer."""

    def __init__(self, df, tokenizer, max_length=128):
        self.df = df
        self.text = df.text.values
        self.labels = df.label.values
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __getitem__(self, idx):
        # Use the tokenizer stored on the instance rather than the global one.
        tokens = self.tokenizer.tokenize(self.text[idx])
        # Truncate so that [CLS] + tokens + [SEP] fits in max_length.
        tokens = ["[CLS]"] + tokens[:self.max_length - 2] + ["[SEP]"]
        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        attention_mask = [1] * len(input_ids)
        # Right-pad: [PAD] ids for the inputs, zeros for the attention mask.
        pad_len = self.max_length - len(input_ids)
        input_ids += [self.tokenizer.pad_token_id] * pad_len
        attention_mask += [0] * pad_len
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(self.labels[idx], dtype=torch.long),
        }

    def __len__(self):
        return len(self.df)

train_dataset = Dataset(train_split.fillna(""), tokenizer)
val_dataset = Dataset(val_split.fillna(""), tokenizer)
```
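The following pure-Python sketch makes the encoding in `__getitem__` concrete. It repeats the same steps: truncate, add special tokens, then pad. The `encode` helper and the toy vocabulary are hypothetical stand-ins. The real class gets WordPiece tokens and ids from `bert-base-uncased`, and the sketch assumes the tokens arrive already split.

```python
def encode(tokens, vocab, max_length):
    """Truncate, wrap in [CLS]/[SEP], and right-pad (hypothetical helper)."""
    # Leave two slots for the special tokens.
    tokens = ["[CLS]"] + tokens[: max_length - 2] + ["[SEP]"]
    input_ids = [vocab[t] for t in tokens]
    attention_mask = [1] * len(input_ids)
    # Pad the ids with the [PAD] id and the mask with zeros.
    pad = max_length - len(input_ids)
    input_ids += [vocab["[PAD]"]] * pad
    attention_mask += [0] * pad
    return input_ids, attention_mask

# Toy vocabulary (hypothetical ids, not the real BERT vocabulary).
vocab = {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2, "free": 3, "entry": 4, "to": 5, "win": 6}

ids, mask = encode(["free", "entry", "to", "win"], vocab, max_length=8)
print(ids)   # [1, 3, 4, 5, 6, 2, 0, 0]
print(mask)  # [1, 1, 1, 1, 1, 1, 0, 0]

ids, mask = encode(["free", "entry", "to", "win"], vocab, max_length=4)
print(ids)   # [1, 3, 4, 2]  (truncated to two real tokens)
```

A possible simplification is to call the tokenizer directly with `padding="max_length"`, `truncation=True` and `max_length=128`. That call should produce equivalent `input_ids` and `attention_mask` (plus `token_type_ids`) without the manual bookkeeping.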
```python
from transformers import BertForSequenceClassification, Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=50,             # total number of training epochs
    per_device_train_batch_size=64,  # batch size per device during training
    per_device_eval_batch_size=64,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for the learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=100,
    dataloader_num_workers=2,
    report_to="tensorboard",
    label_smoothing_factor=0.1,
    evaluation_strategy="steps",     # newer transformers releases name this eval_strategy
    eval_steps=500,                  # evaluate every 500 steps
    save_steps=500,                  # save every 500 steps, matching eval_steps
    save_total_limit=3,              # keep at most 3 checkpoints; older ones are deleted
    load_best_model_at_end=True,     # reload the checkpoint with the best eval loss at the end
)

# num_labels sizes the classification head (768-dim pooled output -> num_classes logits)
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=num_classes)
```

```
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
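`warmup_steps=500` shapes the start of the learning-rate schedule. The sketch below assumes the `Trainer` defaults that this notebook does not override: a base `learning_rate` of 5e-5 and a linear schedule. It also uses the 4150 total steps reported when training starts. The rate rises linearly to its peak at step 500, then decays linearly to zero at the final step. `linear_warmup_lr` is a hypothetical helper, not a library function.

```python
def linear_warmup_lr(step, base_lr=5e-5, warmup_steps=500, total_steps=4150):
    """Learning rate at an optimizer step (hypothetical helper).

    Assumes a linear ramp from 0 to base_lr over warmup_steps, followed
    by a linear decay to 0 at total_steps.
    """
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = (total_steps - step) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, remaining)

for step in (0, 250, 500, 1500, 4150):
    print(step, linear_warmup_lr(step))
```

With these settings the first evaluation at step 500 lands exactly at the peak learning rate.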
```python
trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=val_dataset             # evaluation dataset
)
trainer.train()
```

```
***** Running training *****
  Num examples = 5293
  Num Epochs = 50
  Instantaneous batch size per device = 64
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 1
  Total optimization steps = 4150
```
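The step count in this log follows from the split size. The last partial batch is kept, so each epoch takes ceil(5293 / 64) = 83 batches, and 83 × 50 epochs = 4150 steps. The hypothetical helper below sketches that arithmetic. It assumes `drop_last=False` and one optimizer step per `grad_accum` batches, which is 1 here, as logged. The same numbers show that `eval_steps=500` places about six epochs between evaluations.

```python
import math

def total_optimization_steps(n_examples, batch_size, epochs, grad_accum=1):
    """Optimizer steps for a full training run (hypothetical helper).

    Assumes the last partial batch is kept (drop_last=False) and one
    optimizer step per grad_accum batches.
    """
    batches_per_epoch = math.ceil(n_examples / batch_size)
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    return steps_per_epoch * epochs

steps_per_epoch = math.ceil(5293 / 64)
print(steps_per_epoch)                         # 83
print(total_optimization_steps(5293, 64, 50))  # 4150
print(round(500 / steps_per_epoch, 2))         # 6.02 epochs between evaluations
```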
[1507/4150 1:01:13 < 1:47:31, 0.41 it/s, Epoch 18.14/50]

| Step | Training Loss | Validation Loss |
|------|---------------|-----------------|
| 500  | 0.243900      | 0.221590        |
| 1000 | 0.200800      | 0.217612        |
| 1500 | 0.198900      | 0.218175        |
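With `label_smoothing_factor=0.1` and two classes, the loss has a floor above zero. The sketch below assumes the `Trainer`'s label smoother combines the two terms as (1 − ε) · NLL of the true class + ε · mean of −log p over all classes. That is equivalent to cross-entropy against soft targets of 0.95 / 0.05. The minimum is the entropy of those targets, about 0.1985. By that reckoning, the training loss of roughly 0.199 at step 1500 already sits at the floor, so a flat loss here does not by itself mean the classifier has stopped improving. `smoothed_ce` and `loss_floor` are hypothetical helpers.

```python
import math

def smoothed_ce(probs, label, epsilon=0.1):
    """Label-smoothed cross-entropy for one example (hypothetical helper).

    Assumed form: (1 - eps) * NLL of the true class
    + eps * mean over classes of -log p_k.
    """
    nll = -math.log(probs[label])
    smooth = sum(-math.log(p) for p in probs) / len(probs)
    return (1 - epsilon) * nll + epsilon * smooth

def loss_floor(num_classes, epsilon=0.1):
    """Minimum smoothed loss: the entropy of the smoothed target distribution."""
    p_true = 1 - epsilon + epsilon / num_classes
    p_other = epsilon / num_classes
    return -(p_true * math.log(p_true)
             + (num_classes - 1) * p_other * math.log(p_other))

print(round(loss_floor(2, 0.1), 4))           # 0.1985
print(round(smoothed_ce([0.999, 0.001], 0), 4))  # a very confident prediction scores worse
```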

```
***** Running Evaluation *****
  Num examples = 279
  Batch size = 64
Saving model checkpoint to ./results/checkpoint-500
Configuration saved in ./results/checkpoint-500/config.json
Model weights saved in ./results/checkpoint-500/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 279
  Batch size = 64
Saving model checkpoint to ./results/checkpoint-1000
Configuration saved in ./results/checkpoint-1000/config.json
Model weights saved in ./results/checkpoint-1000/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 279
  Batch size = 64
Saving model checkpoint to ./results/checkpoint-1500
Configuration saved in ./results/checkpoint-1500/config.json
Model weights saved in ./results/checkpoint-1500/pytorch_model.bin
```