Chapter 17: Supervised Classification using BERT

Supervised methods that are based on the Bag of Words hypothesis (see chapter 11) work well, but these days we can do better. In this script, I will run you through how to use BERT (Bidirectional Encoder Representations from Transformers) for text classification. We’ll be working with a sentiment analysis task using the IMDb movie reviews data set, where we’ll classify reviews as either positive or negative.

Unlike traditional bag-of-words approaches, BERT represents each word in its full context, taking into account the words that come both before and after it. This allows it to capture more complex patterns in text, such as negation and word order, which typically leads to better classification performance.
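To see what a bag-of-words representation throws away, here is a minimal pure-Python sketch. The two example reviews are made up, and `bag_of_words` is a hypothetical helper, not a library function. Both reviews contain exactly the same words, so their word counts are identical even though one criticizes the plot and the other criticizes the acting.

```python
from collections import Counter

def bag_of_words(text):
    """Lowercase, split on whitespace, and count tokens (word order is discarded)."""
    return Counter(text.lower().split())

review_a = "good acting but not a good plot"
review_b = "not good acting but a good plot"

# Identical counts: a bag-of-words model cannot tell these two reviews apart
print(bag_of_words(review_a) == bag_of_words(review_b))  # True
```

A contextual model like BERT can take into account which phrase "not" attaches to, which is exactly the information lost above.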

Setup

Here, we use a conda environment named 'bert-env' that contains all necessary Python packages. The code below creates it (you only need to do this once) and activates it from within R via reticulate (obviously, you can skip this step when you work in Python, e.g., in JupyterLab).

needs(reticulate)

conda_create(
    "bert-env", 
    python_version = "3.9",
    channel = c("pytorch", "conda-forge"),
    packages = c("pip", "transformers", "pandas", "jpeg", "numpy", "scikit-learn", 
                 "tqdm", "seaborn", "matplotlib", "pytorch", "torchvision", "torchaudio")
)

use_condaenv("bert-env", required = TRUE)

Then, we import all necessary libraries and set up our device configuration. The device setup is particularly important because it lets the same code run efficiently on different hardware: an Apple Silicon Mac using MPS (in my case), a machine with a CUDA-enabled GPU, or a regular CPU. Depending on your hardware, you might have to install the matching PyTorch build. We use pandas for data manipulation, torch (PyTorch) for deep learning operations, and the transformers library for access to pre-trained BERT models. Furthermore, we use several sklearn (scikit-learn) functions to create the train-test split and to evaluate the model afterwards.

import torch
from torch import nn
from transformers import BertTokenizer, BertModel
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report

def get_device():
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    return device

device = get_device()
print(f"Using device: {device}")
Using device: mps

Similar to our previous supervised learning examples, we need to prepare our data in a format suitable for the model. Hence, this chunk defines a custom Dataset class that converts our raw text and labels into the format BERT expects. It tokenizes each text with BERT’s specialized tokenizer, ensures all sequences have the same length through padding or truncation (controlled by the max_len parameter), and generates attention masks so that the model can handle variable-length inputs. All this information is converted into PyTorch tensors for model training.

class SentenceDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
        
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]
        
        encoding = self.tokenizer.encode_plus(
            text,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

Understanding the components:

  • max_len=128: This is a crucial hyperparameter. It determines the maximum sequence length. Longer sequences capture more context but increase memory usage and computation time. For movie reviews, 128 tokens is often sufficient for capturing the key sentiment, though longer reviews will be truncated. If you find that important information is being lost, consider increasing this to 256 or 512 (BERT’s maximum).

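  • padding='max_length': All sequences in a batch must have the same length so they can be stacked into a single tensor. Padding appends special [PAD] tokens to shorter sequences until they reach max_len (while truncation=True cuts longer ones). Without this, we couldn’t efficiently batch our data.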

  • attention_mask: This binary mask tells BERT which tokens are real words (1) and which are padding (0). This prevents the model from attending to padding tokens, which would add noise to the representations.

  • input_ids: These are the numerical representations of our tokens. BERT’s vocabulary contains about 30,000 tokens, and each word (or subword) gets mapped to a unique integer.

The __getitem__ method is called by PyTorch’s DataLoader to retrieve individual samples during training. By implementing this as a class, we maintain clean separation between data processing and model training logic.
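The padding, truncation, and attention-mask logic can be sketched in plain Python. This is a simplified illustration, not the tokenizer’s implementation: `pad_and_mask` is a hypothetical helper, the word-piece IDs passed in are made up, and it assumes the text has already been split into word pieces. The special-token IDs (0 for [PAD], 101 for [CLS], 102 for [SEP]) are those of the bert-base-uncased vocabulary.

```python
PAD_ID, CLS_ID, SEP_ID = 0, 101, 102  # special-token IDs in bert-base-uncased

def pad_and_mask(token_ids, max_len=8):
    """Mimic padding='max_length' + truncation=True for a single sequence.

    token_ids: word-piece IDs WITHOUT special tokens (assumed already tokenized).
    Returns (input_ids, attention_mask), both exactly max_len long.
    """
    body = token_ids[: max_len - 2]            # truncate, leaving room for [CLS] and [SEP]
    input_ids = [CLS_ID] + body + [SEP_ID]
    attention_mask = [1] * len(input_ids)      # 1 = real token
    n_pad = max_len - len(input_ids)
    input_ids += [PAD_ID] * n_pad              # fill up with [PAD]
    attention_mask += [0] * n_pad              # 0 = padding, ignored by attention
    return input_ids, attention_mask

ids, mask = pad_and_mask([2000, 3000, 4000], max_len=8)
print(ids)   # [101, 2000, 3000, 4000, 102, 0, 0, 0]
print(mask)  # [1, 1, 1, 1, 1, 0, 0, 0]
```

With the real tokenizer, `encode_plus` performs the word-piece splitting and these steps in one call and returns tensors instead of lists.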

Then we define our model architecture by extending PyTorch’s Module class. The classifier builds upon the pre-trained BERT model (transfer learning). To prevent overfitting, we add a dropout layer for regularization with a default rate of 0.1. The final linear layer performs the actual classification, converting BERT’s 768-dimensional output into our desired number of classes (i.e., 2 here).

class BertClassifier(nn.Module):
    def __init__(self, dropout=0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased') #you can choose different models here
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(768, 2) #768 is the hidden size of BERT base, 2 is number of classes
        
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        return self.classifier(pooled_output)

Understanding the components:

  • bert-base-uncased: This is the base BERT model with 110 million parameters - these are the learnable weights that capture patterns in language. The “uncased” version converts all text to lowercase during preprocessing, which generally works well for sentiment analysis where “Good” and “good” should be treated the same. There’s also bert-base-cased which preserves case, useful when capitalization carries semantic meaning (e.g., “US” the country vs “us” the pronoun, or “Apple” the company vs “apple” the fruit).

  • Transfer Learning Strategy: We’re not using BERT as a frozen feature extractor: during training, all of BERT’s pre-trained weights are fine-tuned (adjusted through backpropagation) together with the new classification layer. Think of it like this: BERT has already learned general language understanding from billions of words, and we’re now teaching it the specific task of sentiment classification. This is far more efficient than training from scratch, which would require massive amounts of labeled data and computational resources.

  • The 768-dimensional output: BERT-base produces a 768-dimensional vector (an embedding) for every token in the input. outputs[1] (also available as outputs.pooler_output) is derived from the special [CLS] (classification) token that BERT places at the start of every input: it is the final hidden state of [CLS], passed through an additional dense layer with tanh activation that was pre-trained on BERT’s next-sentence-prediction task. Through this pre-training, the [CLS]-based representation encodes sequence-level information, which makes it a natural input for classification. This 768-dimensional vector captures semantic and contextual information about the entire review and is then passed to our classification head.

  • Dropout Rate: The default of 0.1 means that during training, each element of the pooled output is set to zero with a probability of 10% on every forward pass (the remaining elements are scaled by 1/0.9 so that their expected values stay the same). This regularization technique prevents the network from relying too heavily on any specific features and forces it to learn more robust representations. The dropout is applied between the BERT output and the classification layer. If your model overfits (high training accuracy but low validation accuracy), try increasing this to 0.2 or 0.3. Note that dropout is automatically disabled during inference (when model.eval() is called).

  • Classification Head: The final linear layer is remarkably simple: it’s just a weight matrix plus a bias that projects from 768 dimensions to 2 dimensions (one for each class). Mathematically, it computes logits = W · h + b, where h is the 768-dimensional pooled [CLS] representation, W is a learned 2×768 matrix (nn.Linear stores its weights with shape (out_features, in_features)), and b is a 2-dimensional bias vector. The softmax (applied implicitly inside CrossEntropyLoss, which combines log-softmax and negative log-likelihood) then converts these 2 raw scores (logits) into probabilities that sum to 1. Despite this simplicity, this classification head is often all you need, because BERT’s contextual representations are already highly informative.
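The path from pooled output to class probabilities (dropout, linear layer, softmax) can be traced in plain Python. This is a toy sketch, not PyTorch’s implementation: the helper names are hypothetical, and a made-up 4-dimensional vector stands in for BERT’s 768 dimensions. W is stored as (out_features × in_features), i.e., 2 × 4 here, matching nn.Linear’s convention.

```python
import math
import random

def dropout(values, p=0.1, training=True, rng=random):
    """Inverted dropout: during training, zero each element with probability p and
    scale survivors by 1/(1-p); at evaluation time, return the input unchanged."""
    if not training or p == 0:
        return list(values)
    return [0.0 if rng.random() < p else v / (1 - p) for v in values]

def linear(h, W, b):
    """logits = W h + b, with W stored as (out_features x in_features) like nn.Linear."""
    return [sum(w * x for w, x in zip(row, h)) + b_i for row, b_i in zip(W, b)]

def softmax(z):
    """Turn logits into probabilities; subtracting the max avoids overflow in exp()."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

# Made-up 4-dimensional "pooled output" standing in for BERT's 768 dimensions
h = [0.5, -1.0, 0.25, 2.0]
W = [[0.1, 0.2, -0.3, -0.4],    # row 0 -> score for "negative"
     [-0.2, -0.1, 0.4, 0.5]]    # row 1 -> score for "positive"
b = [0.0, 0.1]

logits = linear(dropout(h, training=False), W, b)  # eval mode: dropout is a no-op
probs = softmax(logits)
print(logits)                          # approximately [-1.025, 1.2]
print([round(p, 3) for p in probs])    # [0.098, 0.902]
```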

The training function implements our training loop with both training and validation phases. It handles device placement automatically (supporting CPU, CUDA, or MPS on Apple Silicon Macs). During training, it performs forward passes through the model, calculates the loss, and updates the model’s parameters. The validation phase tracks the model’s performance on unseen data, and the progress bars provide feedback during the training process.

def train_model(model, train_loader, val_loader, epochs=3, lr=2e-5):
    device = get_device()
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        
        
        train_pbar = tqdm(train_loader, desc=f'Training Epoch {epoch+1}/{epochs}', disable=False)
        for batch in train_pbar:
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        
        model.eval()
        val_loss = 0
        correct = 0
        total = 0
        
        
        val_pbar = tqdm(val_loader, desc=f'Validating Epoch {epoch+1}/{epochs}', disable=False)
        with torch.no_grad():
            for batch in val_pbar:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)
                
                outputs = model(input_ids, attention_mask)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        print(f'Epoch {epoch+1}:')
        print(f'Train Loss: {train_loss/len(train_loader):.4f}')
        print(f'Val Loss: {val_loss/len(val_loader):.4f}')
        print(f'Val Accuracy: {100*correct/total:.2f}%\n')

Understanding the components:

Training Loop Mechanics:

  • AdamW Optimizer: This is Adam with decoupled weight decay, which slightly shrinks the weights at every step and thus helps prevent overfitting. The learning rate (lr=2e-5) is crucial: too high and training becomes unstable, too low and training is extremely slow. For BERT fine-tuning, values between 1e-5 and 5e-5 typically work best.

  • CrossEntropyLoss - An Intuitive Explanation:

Think of the loss as a “wrongness score” that tells us how far off our model’s predictions are from the correct answers. The loss measures how confident the model is in the wrong answer (or how uncertain it is about the right answer).

Here’s how it works intuitively. For a single review, the loss is -log(p), where p is the probability the model assigns to the correct class (nn.CrossEntropyLoss then averages this over the reviews in a batch):

  • Perfect prediction: If the model is 100% confident in the correct class, the loss is -log(1) = 0 (perfect score).
  • Confident and correct: The true label is positive (class 1) and the model outputs [0.01, 0.99] (1% negative, 99% positive). The loss is -log(0.99) ≈ 0.01.
  • Uncertain: If the model is only 50% confident in the correct class, the loss is -log(0.5) ≈ 0.69.
  • Confident and wrong: If the model outputs [0.99, 0.01] for a positive review, the loss is very high: -log(0.01) ≈ 4.6.

Because the loss uses logarithms, it heavily penalizes confident mistakes: being confidently wrong is much worse than being uncertain. This encourages the model to be honest about its uncertainty rather than making bold wrong guesses.

What you’ll see during training:

  • Initial loss: often 0.6-0.9 (the model is basically guessing)
  • Well-trained loss: 0.1-0.3 (the model is confident and mostly correct)
  • Overfit model: training loss near 0, but validation loss increases (memorizing, not learning)

The goal of training is to minimize this loss by adjusting the model’s weights. Lower loss = better predictions.
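The loss values above can be reproduced in a few lines of plain Python. This is a sketch, not PyTorch’s code: `cross_entropy` and `loss_from_probs` are hypothetical helper names. The formula, however (log-softmax of the logits followed by the negative log-likelihood of the correct class), is what nn.CrossEntropyLoss computes per example before averaging over the batch.

```python
import math

def cross_entropy(logits, target):
    """Loss for one example: -log(softmax(logits)[target]), computed via log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def loss_from_probs(probs, target):
    """Shortcut when the class probabilities are already known: -log(p_correct)."""
    return -math.log(probs[target])

# True label: positive (class 1)
for probs in ([0.01, 0.99], [0.5, 0.5], [0.99, 0.01]):
    print(probs, round(loss_from_probs(probs, target=1), 2))
# [0.01, 0.99] 0.01
# [0.5, 0.5] 0.69
# [0.99, 0.01] 4.61

# Passing log-probabilities as logits gives the same loss, because softmax undoes the log
print(round(cross_entropy([math.log(0.01), math.log(0.99)], target=1), 2))  # 0.01
```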

  • model.train() vs model.eval(): These modes affect layers like dropout and batch normalization. In training mode, dropout is active; in evaluation mode, it’s disabled. Always remember to set the correct mode!

  • optimizer.zero_grad(): Gradients accumulate by default in PyTorch, so we must zero them before each backward pass. Forgetting this is a common bug that leads to bizarre training behavior.

  • loss.backward(): This computes gradients of the loss with respect to all model parameters using backpropagation. It’s the magic of automatic differentiation.

  • optimizer.step(): This updates the model parameters based on the computed gradients.

  • Validation without gradients: The with torch.no_grad(): context saves memory and computation by not tracking gradients during validation. We’re not updating parameters here, so gradients aren’t needed.
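What optimizer.step() does with AdamW can be sketched for a single scalar weight. This is a simplified illustration, not PyTorch’s source: `adamw_step` is a hypothetical helper, and the defaults (betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01) mirror those of torch.optim.AdamW. The weight decay shrinks the weight directly instead of being added to the gradient, which is what "decoupled" means.

```python
import math

def adamw_step(theta, grad, state, lr=2e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
    """One AdamW update for a single scalar parameter (decoupled weight decay)."""
    beta1, beta2 = betas
    state["t"] += 1
    t = state["t"]
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad         # running mean of gradients
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad * grad  # running mean of squared gradients
    m_hat = state["m"] / (1 - beta1 ** t)                        # bias correction
    v_hat = state["v"] / (1 - beta2 ** t)
    theta = theta * (1 - lr * weight_decay)                      # decay applied to the weight itself
    return theta - lr * m_hat / (math.sqrt(v_hat) + eps)         # adaptive gradient step

state = {"t": 0, "m": 0.0, "v": 0.0}
theta = adamw_step(1.0, grad=0.5, state=state, lr=0.1)
print(round(theta, 6))  # 0.899
```

On the first step, the update is roughly lr times the sign of the gradient, whatever the gradient’s magnitude. The learning rate therefore directly controls how far each weight moves, which is one reason it matters so much for fine-tuning.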

Monitoring Training:

Watch for these patterns in your training output:

  • Normal training: Training loss decreases, validation loss decreases, validation accuracy increases.
    • Epoch 1: Train Loss ~0.45, Val Loss ~0.35, Val Acc ~85%
    • Epoch 2: Train Loss ~0.28, Val Loss ~0.25, Val Acc ~90%
    • Epoch 3: Train Loss ~0.18, Val Loss ~0.20, Val Acc ~92%
  • Overfitting: Training loss continues decreasing, but validation loss starts increasing. Consider more dropout, fewer epochs, or more training data.
    • Epoch 1: Train Loss ~0.45, Val Loss ~0.35
    • Epoch 2: Train Loss ~0.25, Val Loss ~0.30 (validation getting worse!)
    • Epoch 3: Train Loss ~0.12, Val Loss ~0.38 (memorizing training data)
  • Underfitting: Both losses remain high. Try training longer, reducing dropout, or using a larger model.
    • Epoch 1: Train Loss ~0.65, Val Loss ~0.68
    • Epoch 2: Train Loss ~0.62, Val Loss ~0.65 (barely improving)
    • Epoch 3: Train Loss ~0.60, Val Loss ~0.63 (model can’t learn the patterns)
  • Unstable training: Loss jumps around erratically. Reduce learning rate or check for data issues.
    • Epoch 1: Train Loss ~0.45
    • Epoch 2: Train Loss ~1.23 (suddenly much worse!)
    • Epoch 3: Train Loss ~0.31 (jumping around)
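The four patterns can be turned into a rough rule-of-thumb checker for per-epoch losses. This is a hypothetical helper, not part of PyTorch, and its thresholds (high_loss=0.5, jump=0.3) are assumptions chosen to match the example numbers above, not established values.

```python
def diagnose(train_losses, val_losses=None, high_loss=0.5, jump=0.3):
    """Classify a training run as 'unstable', 'overfitting', 'underfitting', or 'normal'."""
    # Unstable: training loss jumps up by more than `jump` between two epochs
    if any(b - a > jump for a, b in zip(train_losses, train_losses[1:])):
        return "unstable"
    if val_losses:
        best_epoch = val_losses.index(min(val_losses))
        # Overfitting: validation loss rose after its best epoch while training loss kept falling
        if best_epoch < len(val_losses) - 1 and train_losses[-1] < train_losses[best_epoch]:
            return "overfitting"
    # Underfitting: training loss is still high at the end
    if train_losses[-1] > high_loss:
        return "underfitting"
    return "normal"

print(diagnose([0.45, 0.28, 0.18], [0.35, 0.25, 0.20]))  # normal
print(diagnose([0.45, 0.25, 0.12], [0.35, 0.30, 0.38]))  # overfitting
```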

Interpreting Loss Values:

  • Loss around 0.69 or higher: Model is performing at or below the level of random guessing (for balanced binary classification, always predicting 50/50 gives -log(0.5) ≈ 0.69)
  • Loss 0.3-0.6: Model is learning but has room for improvement
  • Loss 0.1-0.3: Model is performing well and making confident correct predictions
  • Loss < 0.1: Model might be overfitting (especially if validation loss is much higher)

Finally, the predict function handles inference on new texts. It manages the complete pipeline from raw text to final prediction: tokenizing the new, unseen input text using BERT’s tokenizer, moving the processed input to the appropriate device, running it through the model, and converting the model’s output into a prediction.

def predict(model, text, tokenizer, max_len=128):
    # use whatever device the model's weights live on (MPS, CUDA, or CPU)
    device = next(model.parameters()).device
    model.eval()  # disable dropout for inference
    encoding = tokenizer.encode_plus(
        text,
        max_length=max_len,  # should match the max_len used during training
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    
    with torch.no_grad():
        outputs = model(input_ids, attention_mask)  # logits, shape (1, 2)
        _, predicted = torch.max(outputs, 1)        # index of the highest-scoring class
    
    return predicted.item()

This function is what you’ll use in production after training. Note that it uses the same preprocessing pipeline (tokenization, padding, truncation) as during training; consistency here is critical. The model returns logits (raw scores) for each class, and torch.max(outputs, 1) returns both the highest score and its index along the class dimension. We keep the index, which is the predicted class (0 for negative, 1 for positive).
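Since predict returns an integer, you will usually want to map it back to a label name by inverting the label mapping used for training ({'negative': 0, 'positive': 1}). The sketch below also mimics in plain Python what torch.max(outputs, 1) returns for a batch; `max_over_classes` is a hypothetical helper and the logits are made up.

```python
def max_over_classes(logits_batch):
    """Plain-Python analogue of torch.max(outputs, 1): per row, the highest score and its index."""
    values, indices = [], []
    for row in logits_batch:
        best = max(range(len(row)), key=lambda j: row[j])
        values.append(row[best])
        indices.append(best)
    return values, indices

label_map = {'negative': 0, 'positive': 1}              # same mapping as for training
id2label = {idx: name for name, idx in label_map.items()}

logits = [[-1.3, 2.1], [0.8, -0.4]]                     # made-up logits for two reviews
_, predicted = max_over_classes(logits)
print([id2label[i] for i in predicted])                 # ['positive', 'negative']
```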

The Full Process

Now that our functions are set up, we can proceed to the actual data loading, model training, and evaluation.

First, we need to load and prepare our IMDb dataset for training. We split the data in two stages. First, we set aside 95% of the data as the test set: with 25,000 reviews in the file, this leaves 1,250 reviews for training and validation (and later, we use only the first 100 test instances for speed). Then, we split the remaining data into training and validation sets. Stratified sampling keeps the class distribution balanced across all splits. Then, the data is processed into our custom Dataset format and wrapped in DataLoader objects, which handle batching and shuffling during training. The label mapping converts our text labels into the numeric format required for classification.

imdb_reviews = pd.read_csv("files/imdb_reviews.csv")

# separate test set (95% of data)
train_val_df, test_df = train_test_split(
    imdb_reviews,
    test_size=0.95,
    stratify=imdb_reviews['sentiment'],
    random_state=1312
)
# separate train and validation (80/20 split of remaining data)
train_df, val_df = train_test_split(
    train_val_df,
    test_size=0.2,
    stratify=train_val_df['sentiment'],
    random_state=1312
)
    
# Create feature/label pairs
X_train = train_df['text'].tolist()
y_train = train_df['sentiment'].tolist()
    
X_val = val_df['text'].tolist()
y_val = val_df['sentiment'].tolist()

## we use a small test set in this example, only the first 100 instances
X_test = test_df['text'][0:100].tolist()
y_test = test_df['sentiment'][0:100].tolist()

# create label mapping to change labels to integers
label_map = {'negative': 0, 'positive': 1}
y_train = [label_map[label] for label in y_train]
y_val = [label_map[label] for label in y_val]
y_test = [label_map[label] for label in y_test]

This demonstrates a realistic but resource-constrained scenario. The full IMDb dataset contains 50,000 reviews, which is substantial but manageable; the file we load here holds 25,000 of them. In this tutorial, we use only 5% of the file for training and validation (1,000 and 250 reviews, respectively) to speed up execution.
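Working through the split sizes explains the batch counts in the training output. This sketch assumes the CSV holds 25,000 reviews (the figure given for this file above) and relies on the fact that train_test_split rounds a fractional test_size up, with the training part receiving the remainder.

```python
import math

n_reviews = 25_000                          # reviews in the CSV (assumption, see text)
n_test = math.ceil(0.95 * n_reviews)        # train_test_split rounds a float test_size up
n_train_val = n_reviews - n_test
n_val = math.ceil(0.2 * n_train_val)
n_train = n_train_val - n_val
print(n_train, n_val, n_test)               # 1000 250 23750

batch_size = 16
train_batches = math.ceil(n_train / batch_size)  # last batch holds only 8 reviews
val_batches = math.ceil(n_val / batch_size)      # last batch holds only 10 reviews
print(train_batches, val_batches)                # 63 16

# A validation accuracy of 72.80% corresponds to 182 of the 250 validation reviews
print(round(0.7280 * n_val))                     # 182
```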

The stratified sampling (stratify=imdb_reviews['sentiment']) ensures that positive and negative reviews are proportionally represented in each split. This is crucial for imbalanced datasets, and it also helps with small splits like ours, where random sampling alone can shift the class ratio. Without stratification, you might accidentally create a test set that’s 70% positive while your training set is 50/50, making evaluation misleading.
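The idea behind stratification can be sketched in plain Python: split each class separately, then combine the pieces. This is a simplified illustration, not scikit-learn’s algorithm (which distributes rounding differently); `stratified_split` is a hypothetical helper, and the 70/30 labels are made up to mimic an imbalanced dataset.

```python
import random
from collections import Counter, defaultdict

def stratified_split(labels, test_size, seed=0):
    """Return (train_idx, test_idx) so that each class keeps its share in both parts."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for i, y in enumerate(labels):
        by_class[y].append(i)
    train_idx, test_idx = [], []
    for idx in by_class.values():
        rng.shuffle(idx)                          # random order within each class
        n_test = round(len(idx) * test_size)      # same fraction taken from every class
        test_idx.extend(idx[:n_test])
        train_idx.extend(idx[n_test:])
    return sorted(train_idx), sorted(test_idx)

labels = ["positive"] * 70 + ["negative"] * 30
train_idx, test_idx = stratified_split(labels, test_size=0.2)
print(Counter(labels[i] for i in test_idx))  # Counter({'positive': 14, 'negative': 6})
```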

The 80/20 train/validation split is a common convention. The validation set serves as a sanity check during training, helping you catch overfitting early. The test set should only be used once at the very end to get an unbiased estimate of real-world performance.

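random_state=1312 ensures reproducibility: every time you run this code, you’ll get the same splits. This is essential for scientific reproducibility and debugging. Change this number to get different splits and check whether your results are stable across different data divisions. Note that random_state only fixes the data splits; the initialization of the classification layer, dropout, and the shuffling in the DataLoader depend on PyTorch’s random number generator, which you can seed with torch.manual_seed().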

Once the data preparation is finished, we can initialize the tokenizer and model, prepare the data loaders, and start training our model.

#initialize tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertClassifier()
    
# load data
train_dataset = SentenceDataset(X_train, y_train, tokenizer)
val_dataset = SentenceDataset(X_val, y_val, tokenizer)
    
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

train_model(model, train_loader, val_loader)

Training Epoch 1/3: 100%|##########| 63/63 [00:39<00:00,  1.60it/s]

Validating Epoch 1/3: 100%|##########| 16/16 [00:03<00:00,  4.83it/s]
Epoch 1:
Train Loss: 0.6698
Val Loss: 0.5987
Val Accuracy: 72.80%


Training Epoch 2/3: 100%|##########| 63/63 [00:38<00:00,  1.65it/s]

Validating Epoch 2/3: 100%|##########| 16/16 [00:03<00:00,  5.06it/s]
Epoch 2:
Train Loss: 0.3877
Val Loss: 0.3549
Val Accuracy: 86.00%


Training Epoch 3/3: 100%|##########| 63/63 [00:38<00:00,  1.65it/s]

Validating Epoch 3/3: 100%|##########| 16/16 [00:03<00:00,  5.01it/s]
Epoch 3:
Train Loss: 0.1697
Val Loss: 0.4137
Val Accuracy: 83.60%

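Training loss keeps falling (from 0.3877 after epoch 2 to 0.1697 after epoch 3), but in the last epoch the validation loss rises from 0.3549 to 0.4137 and validation accuracy drops from 86.00% to 83.60%. The model has clearly learned useful patterns, yet the diverging losses suggest that it is starting to overfit the small training set, so the third epoch does not pay off here.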
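A common way to act on this is early stopping: keep a copy of the weights from the epoch with the lowest validation loss and stop once it has not improved for a set number of epochs. The helper below is a hypothetical sketch, not part of this chapter's training code; the names `EarlyStopping`, `step`, `patience` and `min_delta` are my own. It snapshots whatever state object it is given (e.g. `model.state_dict()`) with `copy.deepcopy`, so the sketch itself runs without PyTorch.

```python
import copy
import math


class EarlyStopping:
    """Hypothetical helper: remember the best validation loss and the weights behind it."""

    def __init__(self, patience=1, min_delta=0.0):
        self.patience = patience      # epochs without improvement we tolerate
        self.min_delta = min_delta    # minimum decrease that counts as an improvement
        self.best_loss = math.inf
        self.best_epoch = None
        self.best_state = None
        self.bad_epochs = 0

    def step(self, epoch, val_loss, state):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_epoch = epoch
            # Deep copy so later optimizer updates cannot overwrite the snapshot
            self.best_state = copy.deepcopy(state)
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

In a training loop one would call `if stopper.step(epoch, val_loss, model.state_dict()): break` after each validation pass and afterwards restore the best weights with `model.load_state_dict(stopper.best_state)`. Fed only the two validation losses shown above (0.3549 for epoch 2, 0.4137 for epoch 3) with `patience=1`, it would stop after epoch 3 and keep the epoch-2 weights.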

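Once the model has been trained, we can eyeball a few predictions (1 = positive, 0 = negative). Both sentences below contain "hell", but the model reads the idiom "a hell of a movie" as praise and "this movie is hell" as criticism, the kind of context that bag-of-words models tend to miss.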

predict(model, "this is a hell of a movie", tokenizer)
1
predict(model, "this movie is hell", tokenizer)
0

We can then run a more rigorous evaluation on the held-out test set:

df = pd.DataFrame()
df['text'] = X_test 
df['label'] = y_test
df['prediction'] = [predict(model, text, tokenizer) for text in X_test]

metrics = {
    'Accuracy': accuracy_score(df['label'], df['prediction']),
    'Precision': precision_score(df['label'], df['prediction'], average='weighted'),
    'Recall': recall_score(df['label'], df['prediction'], average='weighted'),
    'F1 Score': f1_score(df['label'], df['prediction'], average='weighted')
}

print("\nEvaluation Metrics (training set size: 1250):")

Evaluation Metrics (training set size: 1250):
for metric, value in metrics.items():
    print(f"{metric}: {value:.3f}")
Accuracy: 0.860
Precision: 0.861
Recall: 0.860
F1 Score: 0.860
# or take a shortcut
print(classification_report(df['label'], df['prediction'], target_names=['negative', 'positive']))
              precision    recall  f1-score   support

    negative       0.85      0.88      0.86        50
    positive       0.88      0.84      0.86        50

    accuracy                           0.86       100
   macro avg       0.86      0.86      0.86       100
weighted avg       0.86      0.86      0.86       100
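With `average='weighted'`, scikit-learn computes precision, recall and F1 per class and averages them, weighting each class by its support (its number of true instances). Because the test set is balanced (50 negative, 50 positive), the weighted and macro averages coincide here. The plain-Python sketch below redoes this arithmetic by hand. The confusion counts are not printed in the chapter; I reconstructed them from the report (recall 0.88 and 0.84 on 50 reviews each means 44 and 42 correct predictions), so treat them as derived rather than observed. The function names are illustrative.

```python
# Confusion counts derived from the classification report above.
# Keys are (true label, predicted label).
confusion = {
    ("negative", "negative"): 44,
    ("negative", "positive"): 6,
    ("positive", "negative"): 8,
    ("positive", "positive"): 42,
}
classes = ["negative", "positive"]


def per_class_scores(confusion, classes):
    """Precision, recall, F1 and support for each class, one-vs-rest."""
    scores = {}
    for c in classes:
        tp = confusion[(c, c)]
        fn = sum(confusion[(c, o)] for o in classes if o != c)  # true c, predicted other
        fp = sum(confusion[(o, c)] for o in classes if o != c)  # true other, predicted c
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f1 = 2 * precision * recall / (precision + recall)
        scores[c] = {"precision": precision, "recall": recall, "f1": f1, "support": tp + fn}
    return scores


def weighted_average(scores, metric):
    """Average a per-class metric, weighting each class by its support."""
    total = sum(s["support"] for s in scores.values())
    return sum(s[metric] * s["support"] for s in scores.values()) / total


scores = per_class_scores(confusion, classes)
accuracy = sum(confusion[(c, c)] for c in classes) / sum(confusion.values())

print(f"Accuracy: {accuracy:.3f}")
for metric in ("precision", "recall", "f1"):
    print(f"Weighted {metric}: {weighted_average(scores, metric):.3f}")
```

Rounded to three decimals, this reproduces the weighted scores printed above (accuracy 0.860, precision 0.861, recall 0.860, F1 0.860).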

Once we’re happy with this, we can save the model for later use:

# Saving the full model object (pickled together with a reference to its class)
torch.save(model, 'bert_sentiment_model_full.pth')

# Loading
# The BertClassifier class must be defined in the session before loading.
# Full-model files are pickles, so only load files you trust. weights_only=False
# is required from PyTorch 2.6 on (where the default became True) and silences
# the FutureWarning that PyTorch 2.4 and 2.5 emit otherwise.
model = torch.load('bert_sentiment_model_full.pth', weights_only=False)
model.eval()
BertClassifier(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermediate_act_fn): GELUActivation()
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (activation): Tanh()
    )
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (classifier): Linear(in_features=768, out_features=2, bias=True)
)
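Pickling the whole model ties the saved file to the exact class definition and module path. A sturdier alternative is to save only the `state_dict` (the learned tensors) and rebuild the architecture before loading, which also lets you keep the safer `weights_only=True`. The sketch below is an assumption-laden illustration: `TinyClassifier` is a hypothetical stand-in for the chapter's `BertClassifier` so the code runs quickly, the file name is made up, and it assumes a PyTorch release recent enough to accept the `weights_only` argument. For the real model you would construct `BertClassifier` with the same arguments used during training.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for BertClassifier, small enough to run anywhere."""

    def __init__(self, in_features=8, num_classes=2):
        super().__init__()
        self.dropout = nn.Dropout(p=0.1)
        self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.classifier(self.dropout(x))


model = TinyClassifier()
path = os.path.join(tempfile.gettempdir(), "tiny_classifier_state.pth")  # hypothetical file name

# Save only the learned tensors, not the pickled Python object
torch.save(model.state_dict(), path)

# To load: rebuild the architecture first, then copy the weights into it.
# weights_only=True restricts unpickling to tensors and plain containers.
restored = TinyClassifier()
state = torch.load(path, map_location="cpu", weights_only=True)
restored.load_state_dict(state)
restored.eval()  # switch off dropout for inference
```

The trade-off is that the loading code must know how to construct the model, but the saved file no longer depends on pickling arbitrary Python objects.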