Chapter 15: NLI using BERT

Natural Language Inference (NLI), sometimes called textual entailment, represents a fundamental task in Natural Language Processing. The goal is to determine the logical relationship between pairs of sentences – a premise and a hypothesis. Given these two pieces of text, models must decide whether the hypothesis is logically entailed by the premise, contradicts it, or is unrelated (neutral).

For example, if our premise is “The cat is sleeping on the couch” and our hypothesis is “The cat is resting”, this would be considered entailment since sleeping logically implies resting. However, if the hypothesis were “The cat is playing,” this would be a contradiction. A hypothesis like “The cat has brown fur” would be neutral – we cannot determine its truth value from the premise alone.

This task serves as an important benchmark for evaluating language understanding in AI systems. Unlike simpler tasks that can be solved through pattern matching or keyword detection, NLI requires deeper semantic comprehension. A model must understand context, implications, and common sense relationships between concepts.

Pre-trained language models like BERT have shown remarkable capabilities in NLI tasks. Through their pre-training on vast amounts of text, they develop representations that capture subtle semantic relationships. We can fine-tune these models on NLI datasets to create powerful classifiers for textual relationships. Laurer et al. (2024) have even suggested using NLI for classification tasks altogether, as NLI-based classifiers might outperform conventionally fine-tuned classifiers.

In the following sections, we will implement an NLI system using BERT. We’ll see how to prepare sentence pairs for the model, create an appropriate architecture for the task, and train the system to recognize these logical relationships. The same principles we explored in simpler text classification tasks will apply here, but with additional complexity in both the data preparation and model architecture. I aimed to prepare everything in such a way that you can copy-paste the code. In your own scripts, you will have to define the same classes, and your annotated data should follow the same structure, i.e., have two text columns named “premise” and “hypothesis” and another column named “label.” The label column should contain the strings “contradiction,” “neutral,” or “entailment.” I highly suggest going with string labels to avoid confusion about the predictions.
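
To make the expected structure concrete, here is a minimal sketch that checks a handful of annotated records before they go into a data frame. The helper `check_nli_records` and the constants are hypothetical names I introduce for illustration only; they are not used by the code later in this chapter.

```python
# Hypothetical helper (not part of the chapter's pipeline): checks that
# annotated records follow the premise / hypothesis / label structure.
REQUIRED_COLUMNS = ("premise", "hypothesis", "label")
ALLOWED_LABELS = {"contradiction", "neutral", "entailment"}

def check_nli_records(records):
    """Return a list of problems found in a list of dict records."""
    problems = []
    for i, row in enumerate(records):
        missing = [col for col in REQUIRED_COLUMNS if col not in row]
        if missing:
            problems.append(f"row {i}: missing {missing}")
        elif row["label"] not in ALLOWED_LABELS:
            problems.append(f"row {i}: unknown label {row['label']!r}")
    return problems

# The three cat examples from the introduction, one per class
records = [
    {"premise": "The cat is sleeping on the couch",
     "hypothesis": "The cat is resting", "label": "entailment"},
    {"premise": "The cat is sleeping on the couch",
     "hypothesis": "The cat is playing", "label": "contradiction"},
    {"premise": "The cat is sleeping on the couch",
     "hypothesis": "The cat has brown fur", "label": "neutral"},
]

problems = check_nli_records(records)
print(problems)  # an empty list means the structure is as expected
```

A list of records like this can be turned into the required data frame with `pd.DataFrame(records)`.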

Set up Python and classes

In the first chunk, we set up our Python–R connection through the reticulate package and specify our virtual environment that we have set up before. It is important that it contains all the relevant packages. For reference on how to set up the venv, check chapter 14 on BERT.

needs(reticulate)

use_virtualenv("_pyenv/transformer_env")

Then, we import all necessary libraries and set up our device configuration. The device setup is particularly important as it allows our code to run efficiently on different hardware configurations, whether that’s an Apple Silicon Mac using MPS (in my case), a machine with a CUDA-enabled GPU, or a regular CPU. Depending on this, you might have to install the respective torch packages. We use pandas for data manipulation, torch (PyTorch) for deep learning operations, and the transformers library for access to pre-trained BERT models. Furthermore, we use an array of sklearn (scikit-learn) functions for train-test split creation and subsequent model evaluation.

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
import torch.nn as nn
from torch.optim import AdamW
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from tqdm.auto import tqdm

# First set up the device
def get_device():
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    return device

device = get_device()
print(f"Using device: {device}")
Using device: mps

Next, we define our custom NLIDataset class that handles the specific requirements of Natural Language Inference data. Classes in Python provide a way to bundle data and functionality together. They act as blueprints for creating objects that share similar characteristics and behaviors. Think of a class like a template: it defines what type of data the object can store (through attributes) and what operations it can perform (through methods).

When we later create a data set with NLIDataset(df, tokenizer), several things happen. First, Python creates a new object based on the NLIDataset template we define and runs its __init__ method to construct it. Then, the object gets its own copy of the attributes (here, the premises, hypotheses, and labels). Finally, we can call methods on this object; for example, indexing with dataset[0] invokes __getitem__ and returns the tokenized first sentence pair. Think of it as preprocessing.

This class processes pairs of sentences (premises and hypotheses) along with their labels. It includes robust error handling for different label formats and converts all inputs into the tensor format required by PyTorch. The max_length parameter controls sequence length, with longer sequences being truncated and shorter ones padded.

# Define the Dataset class
class NLIDataset(Dataset):
    def __init__(self, df, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.premises = df['premise'].tolist()
        self.hypotheses = df['hypothesis'].tolist()
        self.max_length = max_length
        
        try:
            # Try to convert to numeric and verify range
            numeric_labels = pd.to_numeric(df['label'])
            if not all(numeric_labels.between(0, 2)):
                raise ValueError("Numeric labels must be in range [0,2]")
            self.labels = numeric_labels.tolist()
        except ValueError:
            # If conversion to numeric fails, treat as string labels
            self.label_map = {'contradiction': 0, 'neutral': 1, 'entailment': 2}
            try:
                self.labels = [self.label_map[label] for label in df['label']]
            except KeyError as e:
                raise ValueError(f"Invalid label found. Labels must be either numbers [0,2] or one of {list(self.label_map.keys())}")

    def __len__(self):
        return len(self.premises)

    def __getitem__(self, idx):
        premise = str(self.premises[idx])
        hypothesis = str(self.hypotheses[idx])
        label = self.labels[idx]

        encoding = self.tokenizer(
            premise,
            hypothesis,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

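        # squeeze(0) drops only the batch dimension added by return_tensors='pt';
        # labels are cast to long because F.cross_entropy expects integer class ids
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'token_type_ids': encoding['token_type_ids'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long)
        }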
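
To build some intuition for what the tokenizer call returns for a sentence pair, here is a toy sketch of the layout. It is not the real BERT tokenizer: whitespace-split words stand in for WordPiece tokens, strings stand in for integer ids, and the truncation is naive (the real tokenizer preserves the final [SEP]). The function name `encode_pair` is hypothetical.

```python
# Toy illustration only: whitespace "tokens" stand in for BERT's WordPiece
# tokens, and the real tokenizer returns integer ids rather than strings.
def encode_pair(premise, hypothesis, max_length=12):
    first = ["[CLS]"] + premise.lower().split() + ["[SEP]"]
    second = hypothesis.lower().split() + ["[SEP]"]

    tokens = (first + second)[:max_length]  # naive truncation
    # Segment ids: 0 for the premise part, 1 for the hypothesis part
    token_type_ids = ([0] * len(first) + [1] * len(second))[:max_length]
    # Attention mask: 1 for real tokens, 0 for padding
    attention_mask = [1] * len(tokens)

    # Pad all three lists up to max_length
    n_pad = max_length - len(tokens)
    tokens += ["[PAD]"] * n_pad
    token_type_ids += [0] * n_pad
    attention_mask += [0] * n_pad
    return tokens, token_type_ids, attention_mask

tokens, token_type_ids, attention_mask = encode_pair(
    "The cat is sleeping", "The cat is resting"
)
for tok, seg, mask in zip(tokens, token_type_ids, attention_mask):
    print(f"{tok:10s} segment={seg} attend={mask}")
```

The three lists correspond to the input_ids, token_type_ids, and attention_mask entries that __getitem__ returns; the segment ids are what lets BERT tell the premise from the hypothesis.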

Next, the prepare_nli_data function streamlines the data preparation process. It splits the data into training and validation sets (the validation set is used for intermediate evaluation during training, which helps us spot overfitting), creates Dataset objects, and wraps them in DataLoader instances. The DataLoaders handle batching and, for the training set, shuffling, with customizable batch size and validation split ratio. This gives us an organized data pipeline for efficient training and validation.

# Define the data preparation function
def prepare_nli_data(df, tokenizer, test_size=0.2, batch_size=16):
    train_df, val_df = train_test_split(df, test_size=test_size, random_state=42)
    
    train_dataset = NLIDataset(train_df, tokenizer)
    val_dataset = NLIDataset(val_df, tokenizer)
    
    train_dataloader = DataLoader(
        train_dataset, 
        batch_size=batch_size, 
        shuffle=True
    )
    val_dataloader = DataLoader(
        val_dataset, 
        batch_size=batch_size
    )
    
    return train_dataloader, val_dataloader
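
As a quick sanity check on the defaults, the following sketch works out how many rows and batches the split produces. It assumes, to my understanding of the libraries, that train_test_split rounds the validation share up and that a DataLoader without drop_last keeps the final partial batch; the function name `split_and_batch_counts` is hypothetical.

```python
import math

def split_and_batch_counts(n_rows, test_size=0.2, batch_size=16):
    """Rows and batches per split, under the assumptions stated above."""
    n_val = math.ceil(n_rows * test_size)      # validation share, rounded up
    n_train = n_rows - n_val                   # the rest is used for training
    return {
        "train_rows": n_train,
        "val_rows": n_val,
        "train_batches": math.ceil(n_train / batch_size),
        "val_batches": math.ceil(n_val / batch_size),
    }

counts = split_and_batch_counts(2000)
print(counts)
# {'train_rows': 1600, 'val_rows': 400, 'train_batches': 100, 'val_batches': 25}
```

With 2,000 annotated rows, this yields 100 training batches per epoch, which is the step count the progress bar reports in the training output later in this chapter.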

The BertForNLI class defines our model architecture for Natural Language Inference. It builds upon the pre-trained BERT model, adding a dropout layer for regularization, and finally a classification head for our three-class prediction task. The forward method handles the flow of data through the model, utilizing BERT’s special tokens and attention mechanisms for processing sentence pairs.

# Define the model class
class BertForNLI(nn.Module):
    def __init__(self, num_labels=3):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(768, num_labels)
        
    def forward(self, input_ids, attention_mask, token_type_ids):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        
        pooled_output = outputs.pooler_output
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        
        return logits
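
The classification head itself is small. The sketch below mimics, in plain Python, what the linear layer does with one 768-dimensional pooled vector; the random numbers are stand-ins for real BERT output and real trained weights, and `linear_head` is a hypothetical name. Dropout is left out, since it is inactive in evaluation mode anyway.

```python
import random

def linear_head(pooled, weights, bias):
    """One logit per class: logits[c] = sum_j weights[c][j] * pooled[j] + bias[c]."""
    return [
        sum(w * x for w, x in zip(row, pooled)) + b
        for row, b in zip(weights, bias)
    ]

hidden_size, num_labels = 768, 3
random.seed(0)

# Stand-in for outputs.pooler_output of a single sentence pair
pooled = [random.uniform(-1, 1) for _ in range(hidden_size)]
# Stand-in for the (untrained) weights of nn.Linear(768, 3)
weights = [[random.gauss(0, 0.02) for _ in range(hidden_size)]
           for _ in range(num_labels)]
bias = [0.0] * num_labels

logits = linear_head(pooled, weights, bias)
n_head_params = hidden_size * num_labels + num_labels
print(len(logits), n_head_params)  # 3 2307
```

So the head adds only 2,307 parameters on top of BERT; almost all of the model's capacity sits in the pre-trained encoder.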

In the training function chunk, we implement the complete training loop with progress tracking, wrapped in the function train_nli_model. This function handles the optimization process using AdamW, monitors loss during training, and evaluates performance on the validation set after each epoch. The progress bar provides real-time feedback during training, making it easier to monitor long training runs.

# Define the training function
def train_nli_model(model, train_dataloader, val_dataloader, epochs=3):
    optimizer = AdamW(model.parameters(), lr=2e-5)
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        progress_bar = tqdm(train_dataloader, desc=f'Epoch {epoch+1}')
        
        for batch in progress_bar:
            optimizer.zero_grad()
            
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids
            )
            
            loss = F.cross_entropy(outputs, labels)
            total_loss += loss.item()
            
            loss.backward()
            optimizer.step()
            
            progress_bar.set_postfix({'loss': loss.item()})
        
        val_accuracy = evaluate_nli_model(model, val_dataloader)
        print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_dataloader):.4f}, Val Accuracy: {val_accuracy:.4f}")
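
To interpret the loss values printed during training, it helps to know what cross-entropy returns for a model that has not learned anything yet. The following sketch computes the loss for a single example in plain Python (F.cross_entropy additionally averages over the batch); the function name and the example logits are made up for illustration.

```python
import math

def cross_entropy(logits, label):
    """Negative log-probability of the correct class under a softmax."""
    m = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

# Equal logits: the model spreads its belief evenly over three classes
chance_loss = cross_entropy([0.0, 0.0, 0.0], label=2)
# A large logit on the correct class yields a much smaller loss
confident_loss = cross_entropy([-2.0, 0.0, 3.0], label=2)

print(round(chance_loss, 4))  # 1.0986
print(confident_loss < chance_loss)  # True
```

With three classes, chance level corresponds to ln(3) ≈ 1.10, so a loss that stays near this value suggests the model is still guessing, and a loss clearly below it suggests it has started to learn.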

Finally, the evaluation chunk contains two crucial functions. The evaluate_nli_model function calculates model accuracy on a given data set, while predict_nli handles predictions for new sentence pairs. These functions include proper model state management (evaluation mode) and device handling, ensuring consistent results whether running on CPU or GPU.

def evaluate_nli_model(model, dataloader):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids
            )
            
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    return correct / total
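
Accuracy alone hides which classes get confused with each other. As a small illustration with made-up labels, the sketch below computes the same accuracy as evaluate_nli_model would and additionally counts (gold, predicted) pairs; `accuracy_and_confusion` is a hypothetical helper, not part of the chapter's pipeline.

```python
from collections import Counter

def accuracy_and_confusion(gold, predicted):
    """Share of correct predictions plus counts of (gold, predicted) pairs."""
    correct = sum(g == p for g, p in zip(gold, predicted))
    confusion = Counter(zip(gold, predicted))
    return correct / len(gold), confusion

# Made-up labels for five sentence pairs
gold = ["entailment", "neutral", "contradiction", "neutral", "entailment"]
predicted = ["entailment", "entailment", "contradiction", "neutral", "neutral"]

acc, confusion = accuracy_and_confusion(gold, predicted)
print(acc)  # 0.6
for (g, p), n in sorted(confusion.items()):
    print(f"gold={g:13s} predicted={p:13s} n={n}")
```

In this toy case, both errors involve the neutral and entailment classes, which an accuracy of 0.6 on its own would not reveal.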

# predict label for new examples
def predict_nli(model, tokenizer, premise, hypothesis):
    model.eval()
    encoding = tokenizer(
        premise,
        hypothesis,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt'
    )
    
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    token_type_ids = encoding['token_type_ids'].to(device)
    
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        prediction = torch.argmax(outputs, dim=1)
    
    label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
    return label_map[prediction.item()]
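
The function above returns only the most likely label. If you also want to know how confident the model is, the logits can be turned into probabilities with a softmax. The sketch below shows the idea in plain Python with made-up logits; `label_with_confidence` is a hypothetical helper, and in the PyTorch code the equivalent step would be F.softmax(outputs, dim=1).

```python
import math

LABELS = {0: "contradiction", 1: "neutral", 2: "entailment"}

def softmax(logits):
    """Turn raw logits into probabilities that sum to one."""
    m = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def label_with_confidence(logits):
    """Return the most likely label and its probability."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return LABELS[best], probs[best]

# Made-up logits for one sentence pair
label, confidence = label_with_confidence([-1.2, 0.3, 2.1])
print(label, round(confidence, 2))
```

Keep in mind that these probabilities are not necessarily well calibrated, so treat them as a rough indication rather than a guarantee.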

Training and Evaluating on our own Data

Finally, everything is set up so that we can work on our own data. First, we read in and prepare our actual data, a data set I downloaded from Kaggle: we read it from a TSV file and replace the numeric labels with strings to avoid confusion. We then use the first 2,000 rows as training data and the following 1,000 rows as a test set. This structured approach to data organization is crucial for maintaining consistency throughout the training and evaluation process.

df = pd.read_csv('./files/pair-class_dev.tsv', sep='\t')
label_map = {
    0: 'entailment',
    1: 'neutral',
    2: 'contradiction'
}
df['label'] = df['label'].replace(label_map)
train = df[0:2000]
test = df[2000:3000].copy()
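
The reason I convert to strings becomes clear when we compare the two codings: the file uses 0 for entailment, whereas NLIDataset and predict_nli use 0 for contradiction. The following sketch, with a few made-up label codes, shows what would go wrong if the numeric labels were passed through unchanged.

```python
# Numeric coding used in the TSV file (see the label_map above)
FILE_CODES = {0: "entailment", 1: "neutral", 2: "contradiction"}
# Numeric coding used inside NLIDataset and predict_nli
MODEL_CODES = {"contradiction": 0, "neutral": 1, "entailment": 2}

file_labels = [0, 1, 2, 0]  # made-up example codes as stored in the file

# Going through strings keeps the meaning of each label intact
as_strings = [FILE_CODES[code] for code in file_labels]
model_ids = [MODEL_CODES[name] for name in as_strings]

# Where the raw file code and the model's code disagree
mismatch = [f != m for f, m in zip(file_labels, model_ids)]

print(as_strings)  # ['entailment', 'neutral', 'contradiction', 'entailment']
print(model_ids)   # [2, 1, 0, 2]
print(mismatch)    # [True, False, True, True]
```

Only neutral happens to share the same code in both schemes; entailment and contradiction would be silently swapped, and the model would still train without any error message.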

Then we create our tokenizer and model instances. The tokenizer is loaded from the pre-trained BERT model, and our custom NLI model is moved to the appropriate device (i.e., GPU, CPU, or, in my case, MPS) for computation. This setup forms the foundation for our training process.

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForNLI()
model = model.to(device)

The training execution chunk brings everything together, creating our data loaders and running the training process. We specify two epochs for training, though this parameter can be adjusted based on the specific requirements of the task and the observed learning dynamics. More epochs often help at first, but keep an eye on the validation accuracy: once it stops improving while the training loss keeps falling, the model is starting to overfit.

train_dataloader, val_dataloader = prepare_nli_data(train, tokenizer)
train_nli_model(model, train_dataloader, val_dataloader, epochs = 2)

Epoch 1: 100%|##########| 100/100 [00:51<00:00,  1.96it/s, loss=0.659]
Epoch 1, Loss: 0.9956, Val Accuracy: 0.5800

Epoch 2:  40%|####      | 40/100 [00:20<00:30,  2.00it/s, loss=0.795]
Epoch 2:  40%|####      | 40/100 [00:20<00:30,  2.00it/s, loss=0.663]
Epoch 2:  41%|####1     | 41/100 [00:20<00:29,  2.00it/s, loss=0.663]
Epoch 2:  41%|####1     | 41/100 [00:21<00:29,  2.00it/s, loss=0.691]
Epoch 2:  42%|####2     | 42/100 [00:21<00:29,  2.00it/s, loss=0.691]
Epoch 2:  42%|####2     | 42/100 [00:21<00:29,  2.00it/s, loss=0.516]
Epoch 2:  43%|####3     | 43/100 [00:21<00:28,  1.99it/s, loss=0.516]
Epoch 2:  43%|####3     | 43/100 [00:22<00:28,  1.99it/s, loss=0.759]
Epoch 2:  44%|####4     | 44/100 [00:22<00:28,  2.00it/s, loss=0.759]
Epoch 2:  44%|####4     | 44/100 [00:22<00:28,  2.00it/s, loss=0.641]
Epoch 2:  45%|####5     | 45/100 [00:22<00:27,  1.99it/s, loss=0.641]
Epoch 2:  45%|####5     | 45/100 [00:23<00:27,  1.99it/s, loss=0.711]
Epoch 2:  46%|####6     | 46/100 [00:23<00:27,  1.98it/s, loss=0.711]
Epoch 2:  46%|####6     | 46/100 [00:23<00:27,  1.98it/s, loss=0.513]
Epoch 2:  47%|####6     | 47/100 [00:23<00:26,  1.98it/s, loss=0.513]
Epoch 2:  47%|####6     | 47/100 [00:24<00:26,  1.98it/s, loss=0.552]
Epoch 2:  48%|####8     | 48/100 [00:24<00:26,  1.99it/s, loss=0.552]
Epoch 2:  48%|####8     | 48/100 [00:24<00:26,  1.99it/s, loss=1.02] 
Epoch 2:  49%|####9     | 49/100 [00:24<00:25,  1.99it/s, loss=1.02]
Epoch 2:  49%|####9     | 49/100 [00:25<00:25,  1.99it/s, loss=0.525]
Epoch 2:  50%|#####     | 50/100 [00:25<00:25,  1.99it/s, loss=0.525]
Epoch 2:  50%|#####     | 50/100 [00:25<00:25,  1.99it/s, loss=0.696]
Epoch 2:  51%|#####1    | 51/100 [00:25<00:24,  1.99it/s, loss=0.696]
Epoch 2:  51%|#####1    | 51/100 [00:26<00:24,  1.99it/s, loss=0.589]
Epoch 2:  52%|#####2    | 52/100 [00:26<00:24,  1.99it/s, loss=0.589]
Epoch 2:  52%|#####2    | 52/100 [00:26<00:24,  1.99it/s, loss=0.57] 
Epoch 2:  53%|#####3    | 53/100 [00:26<00:23,  2.00it/s, loss=0.57]
Epoch 2:  53%|#####3    | 53/100 [00:27<00:23,  2.00it/s, loss=1.04]
Epoch 2:  54%|#####4    | 54/100 [00:27<00:23,  2.00it/s, loss=1.04]
Epoch 2:  54%|#####4    | 54/100 [00:27<00:23,  2.00it/s, loss=0.727]
Epoch 2:  55%|#####5    | 55/100 [00:27<00:22,  2.00it/s, loss=0.727]
Epoch 2:  55%|#####5    | 55/100 [00:28<00:22,  2.00it/s, loss=1.17] 
Epoch 2:  56%|#####6    | 56/100 [00:28<00:22,  1.99it/s, loss=1.17]
Epoch 2:  56%|#####6    | 56/100 [00:28<00:22,  1.99it/s, loss=0.547]
Epoch 2:  57%|#####6    | 57/100 [00:28<00:21,  1.99it/s, loss=0.547]
Epoch 2:  57%|#####6    | 57/100 [00:29<00:21,  1.99it/s, loss=0.492]
Epoch 2:  58%|#####8    | 58/100 [00:29<00:21,  1.99it/s, loss=0.492]
Epoch 2:  58%|#####8    | 58/100 [00:29<00:21,  1.99it/s, loss=0.748]
Epoch 2:  59%|#####8    | 59/100 [00:29<00:20,  1.99it/s, loss=0.748]
Epoch 2:  59%|#####8    | 59/100 [00:30<00:20,  1.99it/s, loss=0.589]
Epoch 2:  60%|######    | 60/100 [00:30<00:20,  1.97it/s, loss=0.589]
Epoch 2:  60%|######    | 60/100 [00:30<00:20,  1.97it/s, loss=0.661]
Epoch 2:  61%|######1   | 61/100 [00:30<00:19,  1.97it/s, loss=0.661]
Epoch 2:  61%|######1   | 61/100 [00:31<00:19,  1.97it/s, loss=0.787]
Epoch 2:  62%|######2   | 62/100 [00:31<00:19,  1.98it/s, loss=0.787]
Epoch 2:  62%|######2   | 62/100 [00:31<00:19,  1.98it/s, loss=0.827]
Epoch 2:  63%|######3   | 63/100 [00:31<00:18,  1.98it/s, loss=0.827]
Epoch 2:  63%|######3   | 63/100 [00:32<00:18,  1.98it/s, loss=0.443]
Epoch 2:  64%|######4   | 64/100 [00:32<00:18,  1.97it/s, loss=0.443]
Epoch 2:  64%|######4   | 64/100 [00:32<00:18,  1.97it/s, loss=0.483]
Epoch 2:  65%|######5   | 65/100 [00:32<00:17,  1.98it/s, loss=0.483]
Epoch 2:  65%|######5   | 65/100 [00:33<00:17,  1.98it/s, loss=0.628]
Epoch 2:  66%|######6   | 66/100 [00:33<00:17,  1.98it/s, loss=0.628]
Epoch 2:  66%|######6   | 66/100 [00:33<00:17,  1.98it/s, loss=0.795]
Epoch 2:  67%|######7   | 67/100 [00:33<00:16,  1.99it/s, loss=0.795]
Epoch 2:  67%|######7   | 67/100 [00:34<00:16,  1.99it/s, loss=0.726]
Epoch 2:  68%|######8   | 68/100 [00:34<00:16,  2.00it/s, loss=0.726]
Epoch 2:  68%|######8   | 68/100 [00:34<00:16,  2.00it/s, loss=0.631]
Epoch 2:  69%|######9   | 69/100 [00:34<00:15,  2.00it/s, loss=0.631]
Epoch 2:  69%|######9   | 69/100 [00:35<00:15,  2.00it/s, loss=0.564]
Epoch 2:  70%|#######   | 70/100 [00:35<00:15,  2.00it/s, loss=0.564]
Epoch 2:  70%|#######   | 70/100 [00:35<00:15,  2.00it/s, loss=0.616]
Epoch 2:  71%|#######1  | 71/100 [00:35<00:14,  2.00it/s, loss=0.616]
Epoch 2:  71%|#######1  | 71/100 [00:36<00:14,  2.00it/s, loss=0.695]
Epoch 2:  72%|#######2  | 72/100 [00:36<00:14,  2.00it/s, loss=0.695]
Epoch 2:  72%|#######2  | 72/100 [00:36<00:14,  2.00it/s, loss=0.526]
Epoch 2:  73%|#######3  | 73/100 [00:36<00:13,  2.00it/s, loss=0.526]
Epoch 2:  73%|#######3  | 73/100 [00:37<00:13,  2.00it/s, loss=0.785]
Epoch 2:  74%|#######4  | 74/100 [00:37<00:13,  1.99it/s, loss=0.785]
Epoch 2:  74%|#######4  | 74/100 [00:37<00:13,  1.99it/s, loss=0.611]
Epoch 2:  75%|#######5  | 75/100 [00:37<00:12,  2.00it/s, loss=0.611]
Epoch 2:  75%|#######5  | 75/100 [00:38<00:12,  2.00it/s, loss=0.53] 
Epoch 2:  76%|#######6  | 76/100 [00:38<00:12,  2.00it/s, loss=0.53]
Epoch 2:  76%|#######6  | 76/100 [00:38<00:12,  2.00it/s, loss=0.918]
Epoch 2:  77%|#######7  | 77/100 [00:38<00:11,  2.00it/s, loss=0.918]
Epoch 2:  77%|#######7  | 77/100 [00:39<00:11,  2.00it/s, loss=0.577]
Epoch 2:  78%|#######8  | 78/100 [00:39<00:10,  2.00it/s, loss=0.577]
Epoch 2:  78%|#######8  | 78/100 [00:39<00:10,  2.00it/s, loss=0.763]
Epoch 2:  79%|#######9  | 79/100 [00:39<00:10,  2.00it/s, loss=0.763]
Epoch 2:  79%|#######9  | 79/100 [00:40<00:10,  2.00it/s, loss=0.688]
Epoch 2:  80%|########  | 80/100 [00:40<00:09,  2.01it/s, loss=0.688]
Epoch 2:  80%|########  | 80/100 [00:40<00:09,  2.01it/s, loss=0.582]
Epoch 2:  81%|########1 | 81/100 [00:40<00:09,  2.00it/s, loss=0.582]
Epoch 2:  81%|########1 | 81/100 [00:41<00:09,  2.00it/s, loss=0.532]
Epoch 2:  82%|########2 | 82/100 [00:41<00:08,  2.01it/s, loss=0.532]
Epoch 2:  82%|########2 | 82/100 [00:41<00:08,  2.01it/s, loss=0.648]
Epoch 2:  83%|########2 | 83/100 [00:41<00:08,  2.00it/s, loss=0.648]
Epoch 2:  83%|########2 | 83/100 [00:42<00:08,  2.00it/s, loss=0.506]
Epoch 2:  84%|########4 | 84/100 [00:42<00:07,  2.00it/s, loss=0.506]
Epoch 2:  84%|########4 | 84/100 [00:42<00:07,  2.00it/s, loss=0.757]
Epoch 2:  85%|########5 | 85/100 [00:42<00:07,  2.00it/s, loss=0.757]
Epoch 2:  85%|########5 | 85/100 [00:43<00:07,  2.00it/s, loss=0.709]
Epoch 2:  86%|########6 | 86/100 [00:43<00:06,  2.00it/s, loss=0.709]
Epoch 2:  86%|########6 | 86/100 [00:43<00:06,  2.00it/s, loss=0.494]
Epoch 2:  87%|########7 | 87/100 [00:43<00:06,  2.00it/s, loss=0.494]
Epoch 2:  87%|########7 | 87/100 [00:44<00:06,  2.00it/s, loss=0.887]
Epoch 2:  88%|########8 | 88/100 [00:44<00:05,  2.00it/s, loss=0.887]
Epoch 2:  88%|########8 | 88/100 [00:44<00:05,  2.00it/s, loss=0.692]
Epoch 2:  89%|########9 | 89/100 [00:44<00:05,  2.00it/s, loss=0.692]
Epoch 2:  89%|########9 | 89/100 [00:45<00:05,  2.00it/s, loss=0.847]
Epoch 2:  90%|######### | 90/100 [00:45<00:05,  2.00it/s, loss=0.847]
Epoch 2:  90%|######### | 90/100 [00:45<00:05,  2.00it/s, loss=0.563]
Epoch 2:  91%|#########1| 91/100 [00:45<00:04,  2.00it/s, loss=0.563]
Epoch 2:  91%|#########1| 91/100 [00:46<00:04,  2.00it/s, loss=0.437]
Epoch 2:  92%|#########2| 92/100 [00:46<00:04,  1.99it/s, loss=0.437]
Epoch 2:  92%|#########2| 92/100 [00:46<00:04,  1.99it/s, loss=0.593]
Epoch 2:  93%|#########3| 93/100 [00:46<00:03,  1.96it/s, loss=0.593]
Epoch 2:  93%|#########3| 93/100 [00:47<00:03,  1.96it/s, loss=0.648]
Epoch 2:  94%|#########3| 94/100 [00:47<00:03,  1.97it/s, loss=0.648]
Epoch 2:  94%|#########3| 94/100 [00:47<00:03,  1.97it/s, loss=0.428]
Epoch 2:  95%|#########5| 95/100 [00:47<00:02,  1.97it/s, loss=0.428]
Epoch 2:  95%|#########5| 95/100 [00:48<00:02,  1.97it/s, loss=0.693]
Epoch 2:  96%|#########6| 96/100 [00:48<00:02,  1.98it/s, loss=0.693]
Epoch 2:  96%|#########6| 96/100 [00:48<00:02,  1.98it/s, loss=0.468]
Epoch 2:  97%|#########7| 97/100 [00:48<00:01,  1.98it/s, loss=0.468]
Epoch 2:  97%|#########7| 97/100 [00:49<00:01,  1.98it/s, loss=0.71] 
Epoch 2:  98%|#########8| 98/100 [00:49<00:01,  1.99it/s, loss=0.71]
Epoch 2:  98%|#########8| 98/100 [00:49<00:01,  1.99it/s, loss=0.546]
Epoch 2:  99%|#########9| 99/100 [00:49<00:00,  2.00it/s, loss=0.546]
Epoch 2:  99%|#########9| 99/100 [00:50<00:00,  2.00it/s, loss=0.715]
Epoch 2: 100%|##########| 100/100 [00:50<00:00,  2.00it/s, loss=0.715]
Epoch 2: 100%|##########| 100/100 [00:50<00:00,  1.99it/s, loss=0.715]
Epoch 2, Loss: 0.7027, Val Accuracy: 0.7625

Once training has finished, we can sanity-check the model with a simple example. Sleeping clearly implies resting, so a working model should label this pair as entailment.

premise = "The cat is sleeping"
hypothesis = "The cat is resting."
prediction = predict_nli(model, tokenizer, premise, hypothesis)
print(f"Prediction: {prediction}")
Prediction: entailment
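A single pair is a thin basis for judging the model, so it can help to check a handful of pairs at once. The following is a minimal sketch, not part of the original pipeline: check_pairs and dummy_predict are hypothetical names, and dummy_predict is only a stand-in that lets the code run without a trained model. The sentence pairs are the ones from the chapter introduction.

```python
def check_pairs(pairs, predict_fn):
    """Run predict_fn on each (premise, hypothesis, expected) triple and collect the results."""
    rows = []
    for premise, hypothesis, expected in pairs:
        predicted = predict_fn(premise, hypothesis)
        rows.append({
            "premise": premise,
            "hypothesis": hypothesis,
            "expected": expected,
            "predicted": predicted,
            "match": predicted == expected,
        })
    return rows


# Sentence pairs taken from the chapter introduction
pairs = [
    ("The cat is sleeping on the couch.", "The cat is resting.", "entailment"),
    ("The cat is sleeping on the couch.", "The cat is playing.", "contradiction"),
    ("The cat is sleeping on the couch.", "The cat has brown fur.", "neutral"),
]


# Hypothetical stand-in so the sketch runs without a trained model.
# With the trained model from this chapter you would pass instead:
#   lambda p, h: predict_nli(model, tokenizer, p, h)
def dummy_predict(premise, hypothesis):
    return "entailment"


results = check_pairs(pairs, dummy_predict)
for row in results:
    flag = "ok" if row["match"] else "MISMATCH"
    print(f"{row['hypothesis']:<25} expected={row['expected']:<14} "
          f"predicted={row['predicted']:<14} {flag}")
```

Passing the predictor as an argument keeps the helper independent of the model, so the same loop works for any function that maps a premise and a hypothesis to a label string.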

In the final evaluation chunk, we test the model on our held-out test set. We run every test example through the model and summarize the results with scikit-learn’s classification_report, which lists precision, recall, and F1-score for each of the three classes.

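# Predict a label for every premise-hypothesis pair in the test set
predictions = []
for _, row in test.iterrows():
    pred = predict_nli(model, tokenizer, row['premise'], row['hypothesis'])
    predictions.append(pred)

# Store the predictions next to the gold labels
test['predicted'] = predictions

# Per-class precision, recall, and F1-score
print(classification_report(test['label'], test['predicted']))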
               precision    recall  f1-score   support

contradiction       0.76      0.69      0.72       328
   entailment       0.77      0.86      0.81       342
      neutral       0.66      0.65      0.65       330

     accuracy                           0.73      1000
    macro avg       0.73      0.73      0.73      1000
 weighted avg       0.73      0.73      0.73      1000
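In the report above, entailment is recognized best (F1 of 0.81) and neutral is the hardest class (F1 of 0.65). To make clear where such numbers come from, the sketch below recomputes precision, recall, and F1 by hand from two lists of string labels. It is an illustration only: per_class_metrics is a hypothetical helper, and gold and pred are toy data, not the chapter's test set.

```python
from collections import Counter


def per_class_metrics(y_true, y_pred):
    """Compute precision, recall, F1, and support per label from two parallel lists."""
    labels = sorted(set(y_true) | set(y_pred))
    pair_counts = Counter(zip(y_true, y_pred))  # (gold, predicted) -> count
    metrics = {}
    for label in labels:
        tp = pair_counts[(label, label)]
        # How often the label was predicted, and how often it is the gold label
        predicted = sum(c for (_, p), c in pair_counts.items() if p == label)
        actual = sum(c for (g, _), c in pair_counts.items() if g == label)
        precision = tp / predicted if predicted else 0.0
        recall = tp / actual if actual else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        metrics[label] = {"precision": precision, "recall": recall,
                          "f1": f1, "support": actual}
    return metrics


# Toy labels for illustration (NOT the chapter's test set)
gold = ["entailment", "entailment", "neutral", "contradiction", "neutral", "contradiction"]
pred = ["entailment", "neutral", "neutral", "contradiction", "entailment", "contradiction"]

toy_metrics = per_class_metrics(gold, pred)
for label, m in toy_metrics.items():
    print(f"{label:<14} precision={m['precision']:.2f} recall={m['recall']:.2f} "
          f"f1={m['f1']:.2f} support={m['support']}")
```

F1 is the harmonic mean of precision and recall, which is why the contradiction row in the report (precision 0.76, recall 0.69) ends up at 0.72. The macro average is the unweighted mean of the per-class scores, while the weighted average weights each class by its support.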

References

Laurer, Moritz, Wouter Van Atteveldt, Andreu Casas, and Kasper Welbers. 2024. “Less Annotating, More Classifying: Addressing the Data Scarcity Issue of Supervised Machine Learning with Deep Transfer Learning and BERT-NLI.” Political Analysis 32(1):84–100.