Chapter 14: Supervised Classification using BERT

Supervised methods that are based on the Bag of Words hypothesis (see chapter 11) work well, but these days we can do better. In this script, I will run you through how to use BERT (Bidirectional Encoder Representations from Transformers) for text classification. We’ll be working with a sentiment analysis task using the IMDb movie reviews data set, where we’ll classify reviews as either positive or negative.

Unlike traditional bag-of-words approaches, BERT takes context into account: it represents each word in light of the words that come before and after it. This allows it to capture more complex patterns in text, which typically leads to better classification performance.
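
To make the limitation of bag-of-words concrete, here is a minimal sketch in plain Python. The whitespace tokenizer and the helper name `bag_of_words` are simplifications of mine, not part of the chapter's pipeline. It shows that two sentences with opposite meanings can end up with exactly the same bag-of-words representation, because word order is thrown away.

```python
from collections import Counter

def bag_of_words(text):
    """Count lowercase, whitespace-separated tokens; word order is ignored."""
    return Counter(text.lower().split())

a = bag_of_words("the dog bites the man")
b = bag_of_words("the man bites the dog")

print(a == b)  # True: identical counts, although the meaning differs
```

A contextual model such as BERT receives the tokens in order, so it is at least able to represent the two sentences differently.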

Setup

Here, we specifically use a virtual environment named 'transformer_env' which needs to contain all necessary Python packages. This isolation ensures reproducibility and prevents package conflicts. Make sure this virtual environment has all required packages (torch, transformers, etc.) installed – if not, install them via either the terminal or reticulate. While in the chapter on selenium we used reticulate functions to set up a (conda) environment, here we do it in the terminal (so that you can use the code on the server later).

#first, navigate to the folder you want the environment to live in, by using `cd path/to/folder`
python -m venv _pyenv/transformer_env #create environment

#activate
# for mac/linux
source _pyenv/transformer_env/bin/activate
# for windows
# _pyenv\transformer_env\Scripts\activate

# package installation
pip install --upgrade pip

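# Install PyTorch for Apple Silicon Macs (only if you have one)
# note: recent stable releases of torch also support MPS, so the plain `pip install torch` below may be all you need
pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu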

#otherwise: 
pip install torch

# install remaining required packages
pip install transformers
pip install pandas numpy scikit-learn  # the PyPI package is called scikit-learn; it is imported as sklearn
pip install tqdm
pip install seaborn matplotlib

# the next two lines are Python, not shell: run them in your Python session, not in the terminal
import seaborn as sns
import matplotlib.pyplot as plt

Then we can activate Python and the environment from within R (you can skip this step if you work in Python directly, e.g., in JupyterLab).

needs(reticulate)

use_virtualenv("_pyenv/transformer_env")

Then, we import all necessary libraries and set up our device configuration. The device setup is particularly important because it lets the same code run efficiently on different hardware: an Apple Silicon Mac using MPS (in my case), a machine with a CUDA-enabled GPU, or a regular CPU. Depending on your hardware, you might have to install the respective torch packages. We use pandas for data manipulation, torch (PyTorch) for deep learning operations, and the transformers library for access to pre-trained BERT models. Furthermore, we use an array of sklearn (scikit-learn) functions for creating the train-test split and for the subsequent model evaluation.

import torch
from torch import nn
from transformers import BertTokenizer, BertModel
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def get_device():
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    return device

device = get_device()
print(f"Using device: {device}")
Using device: mps

Similar to our previous supervised learning examples, we need to prepare our data in a format suitable for the model. Hence, this chunk defines our custom Dataset class for handling text data preparation. It converts our raw text and labels into BERT’s expected format. It handles tokenization using BERT’s specialized tokenizer, ensures all sequences are of the same length through padding or truncation (controlled by max_len parameter), and generates attention masks to properly handle variable-length inputs. All this information is converted into PyTorch tensors for model training.

class SentenceDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
        
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]
        
        encoding = self.tokenizer.encode_plus(
            text,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }
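
To see what padding, truncation, and the attention mask amount to, here is a minimal sketch in plain Python. The helper `pad_or_truncate` and the token IDs are made up for illustration. The real BERT tokenizer does more: it splits text into word pieces and keeps its special start and end tokens when truncating, which this sketch ignores.

```python
def pad_or_truncate(token_ids, max_len, pad_id=0):
    """Return (input_ids, attention_mask), both exactly max_len long."""
    ids = token_ids[:max_len]            # truncation: cut off everything beyond max_len
    mask = [1] * len(ids)                # 1 marks a real token
    n_pad = max_len - len(ids)           # number of padding positions needed
    return ids + [pad_id] * n_pad, mask + [0] * n_pad  # 0 marks padding

short_ids, short_mask = pad_or_truncate([101, 2023, 102], max_len=5)
long_ids, long_mask = pad_or_truncate([101, 1, 2, 3, 4, 5, 102], max_len=5)

print(short_ids, short_mask)  # [101, 2023, 102, 0, 0] [1, 1, 1, 0, 0]
print(long_ids, long_mask)    # [101, 1, 2, 3, 4] [1, 1, 1, 1, 1]
```

The mask is what lets the model ignore the padded positions, so short and long reviews can share one batch of equal-length sequences.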

Then we define our model architecture by extending PyTorch’s Module class. The classifier builds upon the pre-trained BERT model (transfer learning). To prevent overfitting, we add a dropout layer for regularization with a default rate of 0.1. The final linear layer performs the actual classification, converting BERT’s 768-dimensional output into our desired number of classes (i.e., 2 here).

class BertClassifier(nn.Module):
    def __init__(self, dropout=0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(768, 2)
        
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output  # pooled sequence representation (same as outputs[1])
        pooled_output = self.dropout(pooled_output)
        return self.classifier(pooled_output)
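
As a rough illustration of what the dropout layer does, here is a sketch of "inverted" dropout in plain Python. The function `dropout` below is my own toy stand-in operating on a list of floats, not PyTorch's implementation. During training each value is set to zero with probability p and the survivors are scaled by 1/(1-p); in evaluation mode the input passes through unchanged.

```python
import random

def dropout(values, p=0.1, training=True, rng=random):
    """Toy inverted dropout on a list of floats."""
    if not training or p == 0:
        return list(values)              # evaluation mode: identity
    keep = 1.0 - p
    # keep each value with probability `keep` and rescale it, otherwise zero it
    return [v / keep if rng.random() < keep else 0.0 for v in values]

rng = random.Random(42)                  # fixed seed so the sketch is repeatable
pooled = [1.0] * 10                      # stand-in for a (much longer) pooled output vector
train_out = dropout(pooled, p=0.1, training=True, rng=rng)
eval_out = dropout(pooled, p=0.1, training=False)

print(train_out)
print(eval_out == pooled)  # True
```

This is also why the code switches between model.train() and model.eval(): dropout should only be active while the parameters are being updated.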

The training function implements our training loop with both training and validation phases. It handles device placement automatically (supporting CPU, CUDA, or MPS on Apple Silicon Macs). During training, it performs forward passes through the model, calculates the loss, and updates the model’s parameters. The validation phase tracks the model’s performance on data that is not used for the parameter updates. The tqdm progress bars are switched off here (disable=True) to keep the output short; set disable=False if you want feedback while the model trains.

def train_model(model, train_loader, val_loader, epochs=3, lr=2e-5):
    device = get_device()
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        
        # Disable the progress bar but keep the iteration
        train_pbar = tqdm(train_loader, desc=f'Training Epoch {epoch+1}/{epochs}', disable=True)
        for batch in train_pbar:
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        
        model.eval()
        val_loss = 0
        correct = 0
        total = 0
        
        # Disable the progress bar but keep the iteration
        val_pbar = tqdm(val_loader, desc=f'Validating Epoch {epoch+1}/{epochs}', disable=True)
        with torch.no_grad():
            for batch in val_pbar:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)
                
                outputs = model(input_ids, attention_mask)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        print(f'Epoch {epoch+1}:')
        print(f'Train Loss: {train_loss/len(train_loader):.4f}')
        print(f'Val Loss: {val_loss/len(val_loader):.4f}')
        print(f'Val Accuracy: {100*correct/total:.2f}%\n')
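
The loss that the loop accumulates can be reproduced by hand. The sketch below is plain Python with made-up logits, and the helper names are mine. It computes the cross-entropy for a single example as the negative log-probability of the true class, and averages it over a batch, which is what nn.CrossEntropyLoss does with its default settings.

```python
import math

def cross_entropy(logits, label):
    """Negative log-probability of the true class under softmax(logits)."""
    m = max(logits)                                   # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

def batch_loss(batch_logits, batch_labels):
    """Mean cross-entropy over a batch."""
    losses = [cross_entropy(z, y) for z, y in zip(batch_logits, batch_labels)]
    return sum(losses) / len(losses)

confident_right = cross_entropy([-2.0, 2.0], label=1)  # small loss
confident_wrong = cross_entropy([-2.0, 2.0], label=0)  # large loss
undecided = cross_entropy([0.0, 0.0], label=1)

print(f"{undecided:.4f}")  # 0.6931, i.e. log(2)
print(confident_right < undecided < confident_wrong)  # True
```

A model that guesses 50/50 on two classes therefore has a loss of about 0.69, which is a useful reference point when reading the training and validation losses.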

Finally, the predict function handles inference on new texts. It manages the complete pipeline from raw text to final prediction: tokenizing the new, unseen input text using BERT’s tokenizer, moving the processed input to the appropriate device, running it through the model, and converting the model’s output into a prediction.

def predict(model, text, tokenizer):
    # use the device the model already lives on (CPU, CUDA, or MPS) so that inputs and weights match
    device = next(model.parameters()).device
    model.eval()
    encoding = tokenizer.encode_plus(
        text,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    
    with torch.no_grad():
        outputs = model(input_ids, attention_mask)
        _, predicted = torch.max(outputs, 1)
    
    return predicted.item()
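
The last step, turning the two raw scores (logits) into a prediction, can be sketched in plain Python. The logits below are hypothetical and the helper names are mine; the 0/1 coding follows the label mapping used in this chapter. Taking the argmax of the logits gives the predicted class, and a softmax turns the scores into probabilities if you also want a rough confidence value.

```python
import math

ID2LABEL = {0: "negative", 1: "positive"}  # same coding as the chapter's label mapping

def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    m = max(logits)                        # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def decode(logits):
    """Return (label name, probability) for the highest-scoring class."""
    probs = softmax(logits)
    idx = max(range(len(logits)), key=lambda i: logits[i])  # argmax over the logits
    return ID2LABEL[idx], probs[idx]

label, prob = decode([-1.2, 0.8])          # hypothetical logits for one review
print(label, round(prob, 2))               # positive 0.88
```

Because softmax preserves the ordering of the scores, the argmax of the logits and the argmax of the probabilities always agree, which is why predict can skip the softmax.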

The Full Process

First, we need to load and prepare our IMDb dataset for training, using a two-stage split. We first set aside the test data (90% of the 25,000 examples, although only 100 instances are used later for speed), then split the remaining 10% into training and validation sets. Stratified sampling keeps the class distribution balanced in each split. The data are then processed into our custom Dataset format and wrapped in DataLoader objects, which handle batching and shuffling during training. The label mapping converts our text labels into the numeric format required for classification.

imdb_reviews = pd.read_csv("files/imdb_reviews.csv")

# separate test set (90% of data)
train_val_df, test_df = train_test_split(
    imdb_reviews,
    test_size=0.9,
    stratify=imdb_reviews['sentiment'],
    random_state=42
)
# separate train and validation (80/20 split of remaining data)
train_df, val_df = train_test_split(
    train_val_df,
    test_size=0.2,
    stratify=train_val_df['sentiment'],
    random_state=42
)
    
# Create feature/label pairs
X_train = train_df['text'].tolist()
y_train = train_df['sentiment'].tolist()
    
X_val = val_df['text'].tolist()
y_val = val_df['sentiment'].tolist()

## we use a small test set in this example, only the first 100 instances
X_test = test_df['text'][0:100].tolist()
y_test = test_df['sentiment'][0:100].tolist()

# create label mapping to change labels to integers
label_map = {'negative': 0, 'positive': 1}
y_train = [label_map[label] for label in y_train]
y_val = [label_map[label] for label in y_val]
y_test = [label_map[label] for label in y_test]
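
It helps to work out how many examples end up in each split. The arithmetic below is a sketch that assumes the CSV holds exactly 25,000 rows, as stated in the text, and a batch size of 16, as used for the data loaders in this chapter.

```python
n_total = 25_000                       # dataset size stated in the text (assumption about the CSV)
n_test = n_total * 9 // 10             # first split: test_size=0.9
n_train_val = n_total - n_test         # what remains for training and validation
n_val = n_train_val * 2 // 10          # second split: test_size=0.2
n_train = n_train_val - n_val

batch_size = 16
batches_per_epoch = -(-n_train // batch_size)  # ceiling division

print(n_test, n_train, n_val, batches_per_epoch)  # 22500 2000 500 125
```

So the model is fine-tuned on only 2,000 reviews and validated on 500, which keeps training time manageable on a laptop.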

Once the data preparation is finished, we can initialize the tokenizer and model, prepare the data loaders, and start training our model.

#initialize tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertClassifier()
    
# load data
train_dataset = SentenceDataset(X_train, y_train, tokenizer)
val_dataset = SentenceDataset(X_val, y_val, tokenizer)
    
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

train_model(model, train_loader, val_loader)
Epoch 1:
Train Loss: 0.5076
Val Loss: 0.3161
Val Accuracy: 87.80%

Epoch 2:
Train Loss: 0.2887
Val Loss: 0.3226
Val Accuracy: 87.20%

Epoch 3:
Train Loss: 0.1381
Val Loss: 0.3791
Val Accuracy: 85.60%

Once the model has been trained, we can eyeball results.

predict(model, "this is a hell of a movie", tokenizer)
1
predict(model, "this movie is hell", tokenizer)
0

And do a more rigorous evaluation on the held-out test set:

df = pd.DataFrame()
df['text'] = X_test 
df['label'] = y_test
df['prediction'] = [predict(model, text, tokenizer) for text in X_test]

metrics = {
    'Accuracy': accuracy_score(df['label'], df['prediction']),
    'Precision': precision_score(df['label'], df['prediction'], average='weighted'),
    'Recall': recall_score(df['label'], df['prediction'], average='weighted'),
    'F1 Score': f1_score(df['label'], df['prediction'], average='weighted')
}

print("\nEvaluation Metrics:")
for metric, value in metrics.items():
    print(f"{metric}: {value:.3f}")

Evaluation Metrics:
Accuracy: 0.830
Precision: 0.832
Recall: 0.830
F1 Score: 0.830
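
To unpack what average='weighted' does, here is a plain-Python sketch with made-up labels. The function `weighted_scores` is my own illustration, not scikit-learn's code. Each class gets its own precision, recall, and F1, and these are then averaged with weights proportional to how often the class occurs in the true labels.

```python
from collections import Counter

def weighted_scores(y_true, y_pred):
    """Support-weighted precision, recall and F1, plus plain accuracy."""
    n = len(y_true)
    support = Counter(y_true)                      # how often each class truly occurs
    precision = recall = f1 = 0.0
    for c, n_c in support.items():
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        predicted_c = sum(p == c for p in y_pred)
        prec_c = tp / predicted_c if predicted_c else 0.0
        rec_c = tp / n_c
        f1_c = 2 * prec_c * rec_c / (prec_c + rec_c) if prec_c + rec_c else 0.0
        weight = n_c / n                           # class share in the true labels
        precision += weight * prec_c
        recall += weight * rec_c
        f1 += weight * f1_c
    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / n
    return {"Accuracy": accuracy, "Precision": precision, "Recall": recall, "F1 Score": f1}

y_true = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]            # toy labels, not the IMDb data
y_pred = [0, 0, 0, 1, 1, 1, 1, 1, 0, 0]
scores = weighted_scores(y_true, y_pred)
for name, value in scores.items():
    print(f"{name}: {value:.3f}")
```

One consequence is that support-weighted recall always equals accuracy, which explains why both are 0.830 in the output above. If you care about performance on the positive class specifically, the unweighted per-class values are more informative.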