24 Text Classification How to Build a Text Classification Model Using BERT #

Hello, I’m Fang Yuan.

In Lesson 22, we learned a lot about theoretical aspects of text processing. Text classification is actually widely used in the field of machine learning.

For example, imagine you’re an NLP research and development engineer and your boss suddenly hands you a large amount of news text data. These texts may come from different fields, such as sports, politics, economics, or society. In this case, we need to classify the texts so that users can quickly find the content they are interested in, or so that we can recommend specific types of content based on user needs.

This kind of requirement is well suited to PyTorch + BERT. Why choose BERT? Because BERT is a representative deep learning model for NLP, and it is also one of the most widely used models in the industry. Next, let’s build this text classification model together. Trust me, it performs very well.

Background and Analysis of the Problem #

Before we start, let’s review the history. There are many classic solutions to text classification problems.

It all started with the simplest and most straightforward keyword statistics method. Then came the classification method based on Bayesian probability, which infers the probability of a category from the probabilities of the observed conditions and uses it as the basis for the final classification decision. Although the idea is simple, it was an important step, and Bayesian methods are still a good choice in many application scenarios today.

Then came the Support Vector Machine (SVM), which, together with its variants, dominated NLP applications for a long time.

With the improvement of computing device performance and the emergence of new algorithm theories, a large number of methods such as Random Forest, LDA Topic Model, and Neural Networks have emerged, creating a great diversity of options.

Since there are so many methods available, why do we recommend using BERT here?

Because in many cases, especially in complex text scenarios, powerful tools like BERT are needed. For example, news articles are not easy to classify due to the following issues.

  1. Multiple Categories. In news information apps, there are numerous categories of news articles. The product manager needs to design the classification system of articles based on statistical and practical principles, ensuring that all texts can be covered. Generally, there are 50 or more categories. However, to focus on the key points, let’s assume that the classification system has already been determined.

  2. Imbalanced Data. It is not difficult to understand that in news articles, categories such as society, economy, sports, and entertainment have relatively higher numbers, accounting for a large proportion. On the other hand, categories like children and medical are relatively few, sometimes having no corresponding articles for a whole day.

  3. Multilingualism. Generally speaking, besides Chinese, English is probably the only language that most people are proficient in. However, to consider the wide range of news sources, let’s assume that these texts are multilingual.

As mentioned earlier, BERT is a typical deep learning NLP model and one of the most widely used models in the industry. Once you have mastered this representative model, other attention-based models, such as GPT, will be much easier to learn and use.

To make good use of BERT, we need to first understand its characteristics.

Analysis of BERT Principles and Features #

The full name of BERT is Bidirectional Encoder Representations from Transformers, that is, a bidirectional Transformer encoder. As a model based on the Attention mechanism, it gained a lot of attention when it was released and achieved state-of-the-art results at the time on more than a dozen NLP tasks such as text classification, automatic dialogue, and semantic understanding.

In Lesson 22 (if you are not familiar, you can review it), we already learned about the basic principles of Attention. With this knowledge as a foundation, we can quickly understand the principles of BERT.

Here’s a quick recap: BERT’s theoretical framework is mainly based on the Transformer proposed in the paper “Attention Is All You Need,” and the principle of the Transformer is the Attention mechanism mentioned earlier. Its most obvious feature is that it abandons the traditional RNN and CNN logic, which lets it handle long-range dependencies in NLP much more effectively.

Image

BERT uses the encoder, which is the left side of the image. The encoder stacks N identical layers, each built around multi-head attention. Multi-head Attention splits the representation into multiple heads, forming multiple subspaces, so that the model can focus on different aspects of the information, which helps the network capture more diverse features. (For detailed principles, be sure to refer to Attention Is All You Need).

From the above image, we should note that BERT is trained with MLM, which stands for Masked Language Model. BERT uses only the encoder part of the Transformer and has no decoder part (decoder-only models are the GPT family). Because the encoder attends to the words on both sides of each position, it cannot be trained by simply predicting the next word: the answer would already be visible.
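The idea of splitting attention into heads can be sketched in plain Python. This is only an illustration under simplifying assumptions: the function names are made up for this example, and the learned projection matrices that real models apply before and after attention are left out, so queries, keys, and values are all just the input vectors.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention(queries, keys, values):
    """Scaled dot-product attention for one head (lists of vectors)."""
    scale = math.sqrt(len(keys[0]))
    output = []
    for q in queries:
        scores = [sum(a * b for a, b in zip(q, k)) / scale for k in keys]
        weights = softmax(scores)
        # Each output vector is a weighted average of the value vectors.
        output.append([sum(w * v[d] for w, v in zip(weights, values))
                       for d in range(len(values[0]))])
    return output

def multi_head_attention(x, num_heads):
    """Split every token vector into num_heads slices and attend per slice.

    Simplification: no learned projections, so Q = K = V = x.
    """
    d_model = len(x[0])
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_head = d_model // num_heads
    heads = []
    for h in range(num_heads):
        part = [tok[h * d_head:(h + 1) * d_head] for tok in x]
        heads.append(attention(part, part, part))
    # Concatenate the per-head outputs back into d_model values per token.
    return [sum((head[i] for head in heads), []) for i in range(len(x))]

tokens = [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0]]
out = multi_head_attention(tokens, num_heads=2)
print(len(out), len(out[0]))
```

Each head only sees its own slice of the vector, which is what lets different heads pick up different relationships between tokens.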

To solve this problem, the MLM method was introduced. Its idea is also very simple, which is to randomly mask (i.e., replace with a special token) a portion of the words (tokens) in the text before training, and then during the training process, use the other unmasked tokens to predict the masked tokens.
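A minimal sketch of the masking step might look like the following. The function name `mask_tokens` and its parameters are hypothetical, and the sketch is simplified: every selected token becomes `[MASK]`, whereas the original BERT recipe also leaves some selected tokens unchanged or swaps in random tokens.

```python
import random

def mask_tokens(tokens, mask_prob=0.15, seed=0):
    """Randomly replace tokens with [MASK] and remember the originals."""
    rng = random.Random(seed)
    masked, labels = [], []
    for tok in tokens:
        if rng.random() < mask_prob:
            masked.append("[MASK]")
            labels.append(tok)    # the model has to recover this token
        else:
            masked.append(tok)
            labels.append(None)   # nothing to predict at this position
    return masked, labels

sentence = "the boss handed me a large pile of news text".split()
masked, labels = mask_tokens(sentence, mask_prob=0.3, seed=1)
print(masked)
print(labels)
```

During training, the loss is computed only at the positions where `labels` holds a token, so the model learns to fill in the blanks from the surrounding context.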

Image

Those who have used Word2Vec may know that, for the same word, its vector representation is fixed. This is why we have the classic calculation “king - man + woman = queen.”

However, there is a problem: the word “apple” can refer to both a fruit and an electronic product brand. If we use the same vector representation, it may introduce bias. In contrast, in BERT, the word vector representation for the same token varies dynamically according to the context, making it more flexible.

In addition, BERT has the advantage of multilingualism. In previous algorithms, such as SVM, if you want to build a multilingual model, you would need to deal with word segmentation, keyword extraction, and other operations, which require knowledge of the language. For languages like Arabic and Japanese, which we are most likely unable to understand, this would have a significant impact on the final model performance.

BERT, on the other hand, does not have to worry about this problem. By covering tokens at different levels, such as characters, character fragments, and words, based on WordPiece, it can cover hundreds of languages. In fact, as long as you can invent a logically consistent language, BERT can handle it. For more information about WordPiece, you can explore it here.
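To make the WordPiece idea concrete, here is a small sketch of greedy longest-match-first splitting of a single word. The toy vocabulary is invented for this example (the real vocab.txt is far larger), and the function name `wordpiece` is an assumption, not the library's API.

```python
def wordpiece(word, vocab, unk="[UNK]"):
    """Split one word into the longest matching pieces found in vocab.

    Pieces that do not start the word carry the "##" prefix.
    """
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        piece = None
        # Try the longest remaining substring first, then shrink it.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk]  # no piece fits, so the whole word is unknown
        pieces.append(piece)
        start = end
    return pieces

toy_vocab = {"un", "##aff", "##able", "play", "##ing", "##ed"}
print(wordpiece("unaffable", toy_vocab))
print(wordpiece("playing", toy_vocab))
print(wordpiece("xyz", toy_vocab))
```

Because rare words fall back to smaller fragments instead of being dropped, one shared vocabulary can serve many languages without language-specific word segmentation.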

Alright, after saying so much, BERT, which combines efficiency, accuracy, flexibility, and versatility, naturally becomes our first choice. Now, let’s start building a text classification model.

Installation and Preparation #

To do a good job, one must have the right tools. Before we start building the model, we need to install the necessary tools, download the pre-trained model, and understand the data format.

Environment Setup #

Since we are using a BERT model based on PyTorch, we need to install the corresponding Python package. Here, I chose the PyTorch version of the Transformers package from Hugging Face. You can install it using the following pip command.

pip install transformers

Model Setup #

After installation, we open the GitHub page for Transformers and navigate to the following folder:

src/transformers/models/bert

From this folder, we need to find two important files: convert_bert_original_tf2_checkpoint_to_pytorch.py and modeling_bert.py.

Let’s start with the first file. Based on its name, can you guess what it is for? Correct! It is used to convert a BERT model pre-trained with TensorFlow into a PyTorch model.

Next is the modeling_bert.py file, which contains the BERT model definitions, including ready-made networks for common NLP tasks.

Now, let’s prepare the model. Open this link. In this page, you will find several pre-trained models.

Image

Considering the task in this lesson, we will choose the “BERT-Base, Multilingual Cased” version. According to the GitHub description, this checkpoint supports 104 languages. Impressive, isn’t it? Of course, if you don’t need multi-language support, you can choose other versions. The main difference between them is the network size.

After converting the model, you will find three new files in your local directory: config.json, pytorch_model.bin, and vocab.txt. Let me explain each of them.

Image

  1. config.json: As the name suggests, this file is the configuration file for the BERT model, which records the model’s hyperparameters, such as the hidden size and the number of layers.

  2. pytorch_model.bin: The actual model file.

  3. vocab.txt: The vocabulary file. Although BERT can handle more than 100 languages, it still needs a vocabulary file to identify characters, strings, or words supported by the model.
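The vocabulary file holds one token per line, and a token's id is its zero-based line number. The sketch below shows that mapping with a made-up six-line vocabulary held in memory; `load_vocab` here is a hypothetical helper written for this example.

```python
import io

def load_vocab(lines):
    """Map each token to its id, which is its zero-based line index."""
    return {line.rstrip("\n"): idx for idx, line in enumerate(lines)}

# Stand-in for open("vocab.txt", encoding="utf-8"); the tokens are illustrative.
fake_file = io.StringIO("[PAD]\n[UNK]\n[CLS]\n[SEP]\nnews\n##paper\n")
vocab = load_vocab(fake_file)
print(vocab["[CLS]"], vocab["##paper"])
```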

Format Preparation #

Now that the model is ready, let’s take a look at the input format needed by the model. BERT’s input is not complicated, but it is important to understand its structure. During training, we don’t directly input the words into the model. Instead, we convert them into the following three types of vectors.

  1. Token embeddings: Word embeddings. Please note that the first token of token embeddings must be “[CLS]”. “[CLS]” serves as the semantic representation of the entire text and is used for tasks like text classification.

  2. Segment embeddings: This vector is mainly used to distinguish between two sentences. For example, in question-answering tasks, both the question and the answer are input simultaneously, requiring an operation that can differentiate between the two sentences. However, in our classification task this time, there is only one sentence.

  3. Position embeddings: Records the position information of words.
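As a rough illustration of the three inputs behind these embeddings, the sketch below builds the token sequence, segment ids, and position ids for a single sentence and for a sentence pair. The helper names are hypothetical; inside the model, each id is looked up in its own embedding table and the three vectors are added together.

```python
def build_single_sentence_input(tokens):
    """[CLS] tokens [SEP], all in segment 0."""
    sequence = ["[CLS]"] + tokens + ["[SEP]"]
    segment_ids = [0] * len(sequence)
    position_ids = list(range(len(sequence)))
    return sequence, segment_ids, position_ids

def build_pair_input(tokens_a, tokens_b):
    """[CLS] a [SEP] b [SEP], with segment 0 for the first part and 1 for the second."""
    first = ["[CLS]"] + tokens_a + ["[SEP]"]
    second = tokens_b + ["[SEP]"]
    sequence = first + second
    segment_ids = [0] * len(first) + [1] * len(second)
    position_ids = list(range(len(sequence)))
    return sequence, segment_ids, position_ids

print(build_single_sentence_input(["news", "##paper"]))
print(build_pair_input(["who"], ["me", "too"]))
```

For the classification task in this lesson only the single-sentence form is needed, so the segment ids are all 0.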

Model Building #

The preparation work is all done, and now let’s build a text classification network model based on BERT. This includes network design, configuration, and data preparation, which is the core process.

Network Design #

From the “modeling_bert.py” file mentioned above, we can see that the author has actually provided us with networks for various NLP tasks. Let’s find “BertForSequenceClassification” among them. This classification network can be used directly, and it is the most basic process for text classification using BERT.

This process includes two parts: obtaining the embedding representation of the text using BERT and feeding the embedding into a fully connected layer to obtain the classification result. Let’s take a closer look at the code.

import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel


class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels  # Number of category labels
        self.bert = BertModel(config)
        # Do you remember what Dropout is used for? Yes, it can prevent overfitting to some extent.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # The embedding output by BERT is passed through a single linear layer for classification.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]  # Pooled representation of the first token, [CLS].

        pooled_output = self.dropout(pooled_output)  # Yes, it is to reduce overfitting and increase the robustness of the network.
        logits = self.classifier(pooled_output)  # Classification scores from the linear layer.

        # The library version goes on to compute a loss and wrap the result;
        # returning the logits is enough for this lesson.
        return logits

From the code above, we can see that after receiving the input, BERT returns an object called outputs, which contains all the results of the model calculation: not only the information of each token, but also the information of the entire text. This output specifically includes the following information.

last_hidden_state is the hidden state sequence output by the last layer of the model. The shape is (batch_size, sequence_length, hidden_size). For the base model, hidden_size=768. For each text, this is a sequence_length x 768 matrix, which records the resulting information of each token after the calculation of the entire text.

pooled_output is derived from the last-layer hidden state of the first token in the sequence, further processed by a linear layer and a tanh activation. The shape is (batch_size, hidden_size). The so-called first token is the [CLS] label mentioned earlier.

In addition to the above two pieces of information, there are hidden_states, attentions, and cross attentions. Interested friends can look up their uses.

In classification tasks, we usually use pooled_output, and we can obtain it by pooled_output = outputs[1]. The per-token information in last_hidden_state is available as outputs[0].

So far, we have obtained the text vector representation calculated by BERT, and then we input it into a linear layer for classification to obtain the final classification result. To improve the performance of the model, we often add a dropout layer before the linear layer, which can reduce the possibility of overfitting and enhance the independence of neurons.
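The dropout-plus-linear step can be tried in isolation with random tensors in place of real BERT outputs. The sketch below is an assumption-laden illustration: `ClassificationHead` is a made-up name, four categories are chosen arbitrarily, and the raw [CLS] hidden state stands in for the pooled output.

```python
import torch
import torch.nn as nn

class ClassificationHead(nn.Module):
    """Dropout followed by one linear layer, as in the classifier above."""

    def __init__(self, hidden_size=768, num_labels=4, dropout_prob=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, pooled_output):
        return self.classifier(self.dropout(pooled_output))

torch.manual_seed(0)
batch_size, seq_len, hidden_size = 2, 16, 768
# Random stand-in for BERT's last_hidden_state.
last_hidden_state = torch.randn(batch_size, seq_len, hidden_size)
cls_vector = last_hidden_state[:, 0]  # hidden state of the first token, [CLS]

head = ClassificationHead(hidden_size=hidden_size, num_labels=4)
head.eval()  # disable dropout for a deterministic forward pass
logits = head(cls_vector)
print(logits.shape)
```

One score per category comes out for each text in the batch; a softmax or sigmoid over these logits turns them into probabilities.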

Model Configuration #

After designing the network, we need to configure the model. Do you remember the config.json file mentioned earlier? It contains all the configuration information required by the BERT model. We need to adjust several fields so that the model knows what we are going to do.

Let me explain these fields.

  • id2label: This field maps each category label (a numeric id) to its category name.
  • label2id: This field maps each category name back to its category label.
  • num_labels_cate: The number of categories.
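These fields can be generated from a plain list of category names instead of being typed by hand. The following is a sketch only: `build_label_config` is a hypothetical helper, the four categories are examples taken from the news scenario, and the resulting values would still need to be merged into your own config.json.

```python
import json

def build_label_config(categories):
    """Create the label-related config fields from a list of category names."""
    id2label = {str(i): name for i, name in enumerate(categories)}  # JSON keys are strings
    label2id = {name: i for i, name in enumerate(categories)}
    return {
        "id2label": id2label,
        "label2id": label2id,
        "num_labels_cate": len(categories),
    }

config_patch = build_label_config(["sports", "politics", "economics", "society"])
print(json.dumps(config_patch, indent=2))
```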

Data Preparation #

The model network has been designed, and the configuration file has been set. Now, we are going to start with the data preparation step. Here, data preparation refers to converting the text into the three inputs that BERT expects, namely input_ids, token_type_ids, and attention_mask. The embeddings described earlier are computed from these inputs inside the model.

To generate these data, we need to find the file “src/transformers/data/processors/utils.py” in the git repository. In this file, we will need the following contents.

  1. InputExample: It is used to record the structure of the text content for a single training data.

  2. DataProcessor: Through the functions in this class, we can represent the text of the training dataset as a collection of multiple InputExamples.

  3. get_features: This is a key function for transforming the InputExample data into a data structure that BERT can understand. Let’s take a look at how each field is generated.

The input_ids record the id of each input token in vocab.txt. They are obtained through the following code:

input_ids = tokenizer.encode(
    example.text_a,
    add_special_tokens=True,
    max_length=min(max_length, tokenizer.max_len),
)

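The attention_mask marks which positions hold real tokens (1) and which are only padding (0), so that the model ignores the padding. Which sentence a token belongs to is recorded in token_type_ids instead; for our single-sentence task these are all 0. The attention_mask is obtained through the following code: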

attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
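Putting the pieces together, padding one example to a fixed length could look like the sketch below. `pad_features` is a hypothetical helper written for this illustration, the token ids are made up, and a padding id of 0 is an assumption.

```python
def pad_features(input_ids, max_length, pad_token_id=0, mask_padding_with_zero=True):
    """Pad one example to max_length and build the matching masks."""
    padding_length = max_length - len(input_ids)
    if padding_length < 0:
        raise ValueError("input is longer than max_length; truncate it first")
    # 1 marks a real token, 0 marks padding (or the reverse if the flag is off).
    attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
    attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
    # Single-sentence input: every position belongs to segment 0.
    token_type_ids = [0] * max_length
    padded_ids = input_ids + [pad_token_id] * padding_length
    return padded_ids, attention_mask, token_type_ids

ids, mask, segments = pad_features([5, 6, 7], max_length=5)
print(ids, mask, segments)
```

All three lists end up with the same length, which is what allows examples of different lengths to be stacked into one batch tensor.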

Additionally, don’t forget to record the information about the text category (label). Can you build the corresponding label information according to the declaration in the utils.py file?
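One possible way to encode the labels, assuming the multi-hot float format that a sigmoid-based loss such as BCEWithLogitsLoss expects, is sketched below. The helper name `encode_labels` and the category names are illustrative, not part of utils.py.

```python
def encode_labels(names, label2id):
    """Turn a list of category names into a multi-hot float vector."""
    vector = [0.0] * len(label2id)
    for name in names:
        vector[label2id[name]] = 1.0
    return vector

label2id = {"sports": 0, "politics": 1, "economics": 2, "society": 3}
print(encode_labels(["sports"], label2id))
print(encode_labels(["sports", "society"], label2id))
```

With exactly one category per text, each vector has a single 1.0; the same format also covers texts that belong to several categories at once.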

Model Training #

So far, we have the network structure defined (BertForSequenceClassification) and the dataset (get_features). Now we can start writing the code to implement the training process.

Selecting an Optimizer #

First, let’s select an optimizer. The code is as follows. We need to collect all the parameters in the network so that the optimizer knows which parameters to optimize, and we split them into two groups: one with weight decay, and one (biases and LayerNorm parameters) without. Then we put the parameter groups into the optimizer. BERT uses the AdamW optimizer.

from torch.optim import AdamW  # Assumption: PyTorch's AdamW; model and args are defined elsewhere

param_optimizer = list(model.named_parameters())
# Parameters whose names contain any of these strings get no weight decay.
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)

This code selects an optimizer suitable for our model, sets the weight decay for each group of parameters, and sets the learning rate for the network.
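To see how the name-based grouping behaves, the same `any(nd in n ...)` test can be run on a handful of parameter names without loading a model. The names below are written by hand in the style of a BERT classifier and are meant as illustrative examples only; `split_by_weight_decay` is a hypothetical helper.

```python
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]

def split_by_weight_decay(names, no_decay):
    """Return (names that get weight decay, names that do not)."""
    decay, skip = [], []
    for name in names:
        if any(nd in name for nd in no_decay):
            skip.append(name)
        else:
            decay.append(name)
    return decay, skip

names = [
    "bert.encoder.layer.0.attention.self.query.weight",
    "bert.encoder.layer.0.attention.self.query.bias",
    "bert.encoder.layer.0.attention.output.LayerNorm.weight",
    "bert.encoder.layer.0.attention.output.LayerNorm.bias",
    "classifier.weight",
    "classifier.bias",
]
decay, skip = split_by_weight_decay(names, no_decay)
print("weight decay:", decay)
print("no weight decay:", skip)
```

Only the weight matrices end up in the decayed group; biases and LayerNorm parameters are left alone.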

Building the Training Process Logic #

The training process logic is very simple. It only requires two for loops, representing the epoch and the batch, respectively. Then, inside the innermost loop, we add a training core statement and a gradient update statement. This is enough. As you can see, PyTorch lets us write the engineering code very concisely.

from tqdm import tqdm, trange

tr_loss = 0.0  # Running total of the training loss
for epoch in trange(0, args.num_train_epochs):
    model.train()  # Don't forget to set the model to training mode.
    for step, batch in enumerate(tqdm(train_dataLoader, desc='Iteration')):
        step_loss = training_step(batch)  # The core of training
        tr_loss += step_loss[0]
        optimizer.step()  # Apply the gradients computed in training_step
        optimizer.zero_grad()  # Clear them before the next batch

The Core of Training #

In the core of training, you need to pay attention to two parts: obtaining the predicted output through the network (logits) and calculating the loss based on the logits. The loss is the data needed for the whole model to update gradients.

from torch.nn import BCEWithLogitsLoss

def training_step(batch):
    input_ids, token_type_ids, attention_mask, labels = batch
    input_ids = input_ids.to(device)  # Send the data to the GPU
    token_type_ids = token_type_ids.to(device)
    attention_mask = attention_mask.to(device)
    labels = labels.to(device)

    # Assumes the model's forward() returns the logits tensor.
    logits = model(input_ids,
                   token_type_ids=token_type_ids,
                   attention_mask=attention_mask,
                   labels=labels)
    loss_fct = BCEWithLogitsLoss()
    loss = loss_fct(logits.view(-1, num_labels_cate), labels.view(-1, num_labels_cate).float())
    loss.backward()  # Compute the gradients for this batch
    return (loss.item(),)  # The training loop reads the loss as step_loss[0]

So far, we have quickly built all the key code for a BERT classifier. However, there are still some small details that need to be completed, such as how to obtain the device in the training_step code block. Review the content we have learned before, and I believe you can do it.
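As a hint for the missing details, the sketch below runs the same step on a toy problem. Everything in it is an assumption made for illustration: a small linear layer stands in for the BERT classifier, the data is random, and the usual `torch.cuda.is_available()` check is one common way to choose the device.

```python
import torch
import torch.nn as nn

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(0)
num_labels_cate = 3
model = nn.Linear(8, num_labels_cate).to(device)  # hypothetical stand-in for the BERT classifier
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
loss_fct = nn.BCEWithLogitsLoss()

def training_step(features, labels):
    """Forward pass, loss, and backward pass for one batch."""
    features, labels = features.to(device), labels.to(device)
    logits = model(features)
    loss = loss_fct(logits.view(-1, num_labels_cate),
                    labels.view(-1, num_labels_cate).float())
    loss.backward()
    return loss.item()

features = torch.randn(4, 8)
labels = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]])

model.train()
losses = []
for _ in range(20):
    losses.append(training_step(features, labels))
    optimizer.step()       # apply the gradients
    optimizer.zero_grad()  # clear them before the next step
print(losses[0], losses[-1])
```

The loss on this fixed batch should shrink from the first step to the last, which is a quick sanity check worth running before training the real model.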

Summary #

Congratulations on completing this lesson! Although there are already many well-encapsulated BERT code repositories on GitHub, and you can quickly implement a basic NLP algorithm flow with them, I still hope you can take some time to carefully study the model code in the Transformers library. This will greatly improve your technical skills.

In this lesson, we learned how to quickly build a basic text classification model using PyTorch. To implement this process, you need to understand how to obtain and transform the pre-trained BERT model, design the classification network, and write the training process. The whole process is not difficult, but it allows you to quickly get started and understand how PyTorch is applied in NLP.

In addition to the technical aspects, we also need to pay attention to business considerations. For example, multi-language issues in news text, data imbalance, etc. Sometimes the model cannot solve all the problems, so you also need to learn some data preprocessing techniques, which include many technical and algorithmic aspects.

No learning checklist can cover everything, so I recommend focusing on a few areas of data preprocessing. Spend some time learning how to use NumPy and Pandas so that you can handle data more proficiently. You can also learn more about common data mining algorithms (such as decision trees, KNN, and support vector machines). Deep learning is widely used, but in practice it still often relies on the support of traditional machine learning algorithms, so I also recommend that you learn more about them.

Thought Question #

BERT has a maximum length requirement (512) when processing text. So, what should be done when encountering long texts?

Feel free to leave your questions or insights in the comments section. It is also recommended to share this lesson with your friends.