needs(reticulate)
use_virtualenv("_pyenv/transformer_env")
Chapter 15: NLI using BERT
Natural Language Inference (NLI), sometimes called textual entailment, represents a fundamental task in Natural Language Processing. The goal is to determine the logical relationship between pairs of sentences – a premise and a hypothesis. Given these two pieces of text, models must decide whether the hypothesis is logically entailed by the premise, contradicts it, or is unrelated (neutral).
For example, if our premise is “The cat is sleeping on the couch” and our hypothesis is “The cat is resting”, this would be considered entailment since sleeping logically implies resting. However, if the hypothesis were “The cat is playing,” this would be a contradiction. A hypothesis like “The cat has brown fur” would be neutral – we cannot determine its truth value from the premise alone.
This task serves as an important benchmark for evaluating language understanding in AI systems. Unlike simpler tasks that can be solved through pattern matching or keyword detection, NLI requires deeper semantic comprehension. A model must understand context, implications, and common sense relationships between concepts.
Pre-trained language models like BERT have shown remarkable capabilities in NLI tasks. Through their pre-training on vast amounts of text, they develop representations that capture subtle semantic relationships. We can fine-tune these models on NLI datasets to create powerful classifiers for textual relationships. Laurer et al. (2024) have suggested using NLI models for classification tasks altogether, as they might outperform conventionally fine-tuned classifiers.
In the following sections, we will implement an NLI system using BERT. We’ll see how to prepare sentence pairs for the model, create an appropriate architecture for the task, and train the system to recognize these logical relationships. The same principles we explored in simpler text classification tasks will apply here, but with additional complexity in both the data preparation and model architecture. I aimed to prepare everything in such a way that you can copy and paste the code. In your own scripts, you will have to define the same classes, and your annotated data should follow the same structure, i.e., have two text columns named “premise” and “hypothesis” and another column named “label.” The label column should contain the strings “contradiction,” “neutral,” or “entailment.” I highly suggest going with string labels to avoid confusion about the predictions.
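To make the expected data layout concrete, here is a minimal sketch that stores the three cat examples from above as rows with the keys premise, hypothesis, and label, and checks that every label is one of the three allowed strings. The helper name check_nli_rows is hypothetical and only serves this illustration. A list of dictionaries like this can be turned into a data frame with pd.DataFrame(rows).

```python
# Allowed string labels, as recommended in the text.
VALID_LABELS = {"contradiction", "neutral", "entailment"}

# The three example pairs from the introduction, one row per sentence pair.
rows = [
    {"premise": "The cat is sleeping on the couch",
     "hypothesis": "The cat is resting",
     "label": "entailment"},
    {"premise": "The cat is sleeping on the couch",
     "hypothesis": "The cat is playing",
     "label": "contradiction"},
    {"premise": "The cat is sleeping on the couch",
     "hypothesis": "The cat has brown fur",
     "label": "neutral"},
]


def check_nli_rows(rows):
    """Hypothetical helper: return the indices of rows that break the expected structure."""
    bad = []
    for i, row in enumerate(rows):
        has_columns = {"premise", "hypothesis", "label"} <= set(row)
        if not has_columns or row["label"] not in VALID_LABELS:
            bad.append(i)
    return bad


# An empty list means every row is well formed.
print(check_nli_rows(rows))
```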
Set up Python and classes
In the first chunk, we set up our Python–R connection through the reticulate package and specify the virtual environment that we set up before. It is important that it contains all the relevant packages. For reference on how to set up the venv, check chapter 14 on BERT.
Then, we import all necessary libraries and set up our device configuration. The device setup is particularly important as it allows our code to run efficiently on different hardware configurations, whether that’s an Apple Silicon Mac using MPS (in my case), a machine with a CUDA-enabled GPU, or a regular CPU. Depending on this, you might have to install the respective torch packages. We use pandas for data manipulation, torch (PyTorch) for deep learning operations, and the transformers library for access to pre-trained BERT models. Furthermore, we use an array of sklearn (scikit-learn) functions for train-test split creation and subsequent model evaluation.
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
import torch.nn as nn
from torch.optim import AdamW
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from tqdm.auto import tqdm
# First set up the device
def get_device():
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    return device

device = get_device()
print(f"Using device: {device}")
Using device: mps
Next, we define our custom NLIDataset class that handles the specific requirements of Natural Language Inference data. Classes in Python provide a way to bundle data and functionality together. They act as blueprints for creating objects that share similar characteristics and behaviors. Think of a class like a template: it defines what type of data the object can store (through attributes) and what operations it can perform (through methods).
When we create a data set from this class, several things happen. First, Python creates a new object based on the NLIDataset template we define, and the __init__ method is used to construct it. Then, the object gets its own copy of the attributes. Finally, we can call methods on this object, e.g., len(dataset) or dataset[0], which invoke __len__ and __getitem__. Think of it as preprocessing.
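The mechanics of attributes and methods can be tried out on a toy class before looking at the real one. The following sketch uses a made-up class name, ToyPairs, and plain Python lists. It is only meant to show how __init__, __len__, and __getitem__ are triggered, not how the actual data set class works internally.

```python
class ToyPairs:
    """Toy stand-in for a data set class (hypothetical, for illustration only)."""

    def __init__(self, premises, hypotheses):
        # Attributes: data stored on this particular object.
        self.premises = premises
        self.hypotheses = hypotheses

    def __len__(self):
        # Called when we write len(obj).
        return len(self.premises)

    def __getitem__(self, idx):
        # Called when we write obj[idx].
        return {"premise": self.premises[idx], "hypothesis": self.hypotheses[idx]}


# Creating the object runs __init__ with the two lists.
pairs = ToyPairs(["The cat is sleeping on the couch"], ["The cat is resting"])

print(len(pairs))              # uses __len__
print(pairs[0]["hypothesis"])  # uses __getitem__
```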
This class processes pairs of sentences (premises and hypotheses) along with their labels. It accepts labels either as numbers from 0 to 2 or as the strings “contradiction,” “neutral,” and “entailment,” raises an error for anything else, and converts all inputs into the tensor format required by PyTorch. The max_length parameter controls sequence length, with longer sequences being truncated and shorter ones padded.
# Define the Dataset class
class NLIDataset(Dataset):
    def __init__(self, df, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.premises = df['premise'].tolist()
        self.hypotheses = df['hypothesis'].tolist()
        self.max_length = max_length
        try:
            # Try to convert to numeric and verify range
            numeric_labels = pd.to_numeric(df['label'])
            if not all(numeric_labels.between(0, 2)):
                raise ValueError("Numeric labels must be in range [0,2]")
            self.labels = numeric_labels.tolist()
        except ValueError:
            # If conversion to numeric fails, treat as string labels
            self.label_map = {'contradiction': 0, 'neutral': 1, 'entailment': 2}
            try:
                self.labels = [self.label_map[label] for label in df['label']]
            except KeyError as e:
                raise ValueError(f"Invalid label found. Labels must be either numbers [0,2] or one of {list(self.label_map.keys())}")

    def __len__(self):
        return len(self.premises)

    def __getitem__(self, idx):
        premise = str(self.premises[idx])
        hypothesis = str(self.hypotheses[idx])
        label = self.labels[idx]

        encoding = self.tokenizer(
            premise,
            hypothesis,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'token_type_ids': encoding['token_type_ids'].squeeze(),
            'labels': torch.tensor(label)
        }
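To get a feeling for what the tokenizer returns for a sentence pair, here is a toy sketch in plain Python. It is not the real BERT tokenizer: the function encode_pair is hypothetical, whitespace splitting stands in for WordPiece, and truncation simply cuts from the end, whereas the real tokenizer keeps the special tokens intact when it truncates. What the sketch does show faithfully is the overall layout: [CLS] premise [SEP] hypothesis [SEP], token type IDs of 0 for the first segment and 1 for the second, and padding positions that receive an attention mask of 0.

```python
def encode_pair(premise, hypothesis, max_length=12):
    """Toy illustration of sentence-pair encoding; NOT the real BERT tokenizer."""
    # Whitespace "tokenization" stands in for WordPiece (simplifying assumption).
    first = ["[CLS]"] + premise.lower().split() + ["[SEP]"]
    second = hypothesis.lower().split() + ["[SEP]"]

    tokens = first + second
    # Segment 0 marks the premise part, segment 1 the hypothesis part.
    token_type_ids = [0] * len(first) + [1] * len(second)

    # Simplified truncation: cut everything beyond max_length.
    tokens = tokens[:max_length]
    token_type_ids = token_type_ids[:max_length]
    # Real tokens get attention, padding does not.
    attention_mask = [1] * len(tokens)

    # Pad shorter sequences up to max_length.
    n_pad = max_length - len(tokens)
    tokens = tokens + ["[PAD]"] * n_pad
    token_type_ids = token_type_ids + [0] * n_pad
    attention_mask = attention_mask + [0] * n_pad
    return tokens, token_type_ids, attention_mask


tokens, token_type_ids, attention_mask = encode_pair(
    "The cat sleeps", "The cat rests", max_length=12
)
print(tokens)
print(token_type_ids)
print(attention_mask)
```

In the real class, these three lists correspond to input_ids (as vocabulary indices rather than strings), token_type_ids, and attention_mask.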
Then we create the prepare_nli_data function, which streamlines the data preparation process. It splits the data into training and validation sets (the validation set is used for intermediate evaluation during training, which helps to detect overfitting), creates Dataset objects, and wraps them in DataLoader instances. The DataLoaders handle batching and, for the training set, shuffling, with customizable batch size and split ratio (test_size). This gives us an organized data pipeline for efficient training and validation.
# Define the data preparation function
def prepare_nli_data(df, tokenizer, test_size=0.2, batch_size=16):
    train_df, val_df = train_test_split(df, test_size=test_size, random_state=42)

    train_dataset = NLIDataset(train_df, tokenizer)
    val_dataset = NLIDataset(val_df, tokenizer)

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size
    )

    return train_dataloader, val_dataloader
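The effect of the split ratio and the batch size on the number of training steps can be worked out with a small toy version of the pipeline. The sketch below uses only the standard library. The function split_and_batch is hypothetical, operates on row indices instead of real data, and simplifies the rounding and shuffling that scikit-learn and PyTorch perform. For a hypothetical data frame of 2,000 rows and the default arguments, it yields 1,600 training rows in 100 batches and 400 validation rows in 25 batches.

```python
import random


def split_and_batch(n_rows, test_size=0.2, batch_size=16, seed=42):
    """Toy illustration of a train/validation split followed by batching."""
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)  # shuffle before splitting

    # Simplified rounding of the validation share (an assumption of this sketch).
    n_val = round(n_rows * test_size)
    val_idx = indices[:n_val]
    train_idx = indices[n_val:]

    def batches(idx):
        # Consecutive slices of at most batch_size elements.
        return [idx[i:i + batch_size] for i in range(0, len(idx), batch_size)]

    return batches(train_idx), batches(val_idx)


train_batches, val_batches = split_and_batch(2000)
print(len(train_batches), len(val_batches))
```

The number of training batches is what the progress bar counts during one epoch.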
The BertForNLI class defines our model architecture for Natural Language Inference. It builds upon the pre-trained BERT model, adding a dropout layer for regularization and a linear classification head for our three-class prediction task. The forward method passes the token IDs, the attention mask, and the token type IDs through BERT, takes the pooled output (768 dimensions for bert-base-uncased), and feeds it through dropout and the classifier to produce one logit per class.
# Define the model class
class BertForNLI(nn.Module):
    def __init__(self, num_labels=3):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(768, num_labels)

    def forward(self, input_ids, attention_mask, token_type_ids):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )

        pooled_output = outputs.pooler_output
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        return logits
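What the classification head does with the pooled output can be illustrated without any deep learning library. The sketch below is a toy version under clearly simplified assumptions: the pooled vector has 4 dimensions instead of 768, the weights are hand-picked rather than learned, and the helper names linear_head and softmax are my own. It computes one logit per class as a weighted sum plus bias, which is the operation a linear layer performs, and then turns the logits into probabilities.

```python
import math


def linear_head(pooled, weights, bias):
    """One logit per class: logits[c] = sum_j weights[c][j] * pooled[j] + bias[c]."""
    return [sum(w * x for w, x in zip(row, pooled)) + b
            for row, b in zip(weights, bias)]


def softmax(logits):
    """Convert logits into probabilities that sum to 1 (numerically stable)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Toy pooled output with 4 dimensions (bert-base-uncased would give 768).
pooled = [0.5, -1.0, 2.0, 0.0]
# Hand-picked weights: one row per class (3 classes), one column per dimension.
weights = [[1.0, 0.0, 0.0, 0.0],
           [0.0, 1.0, 0.0, 0.0],
           [0.0, 0.0, 1.0, 0.0]]
bias = [0.0, 0.0, 0.0]

logits = linear_head(pooled, weights, bias)
probs = softmax(logits)
print(logits)
print(probs)
```

Dropout is left out of the sketch because it only randomly zeroes inputs during training and is inactive in evaluation mode.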
In the training function chunk, we implement the complete training loop with progress tracking, wrapped in the function train_nli_model. This function handles the optimization process using AdamW, monitors loss during training, and evaluates performance on the validation set after each epoch. The progress bar provides real-time feedback during training, making it easier to monitor long training runs.
# Define the training function
def train_nli_model(model, train_dataloader, val_dataloader, epochs=3):
    optimizer = AdamW(model.parameters(), lr=2e-5)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        progress_bar = tqdm(train_dataloader, desc=f'Epoch {epoch+1}')

        for batch in progress_bar:
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids
            )

            loss = F.cross_entropy(outputs, labels)
            total_loss += loss.item()

            loss.backward()
            optimizer.step()
            progress_bar.set_postfix({'loss': loss.item()})

        val_accuracy = evaluate_nli_model(model, val_dataloader)
        print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_dataloader):.4f}, Val Accuracy: {val_accuracy:.4f}")
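To interpret the loss values reported during training, it helps to compute the cross-entropy for a single example by hand. The sketch below is a plain-Python illustration with a helper of my own, cross_entropy, which returns the negative log of the softmax probability assigned to the true class. F.cross_entropy computes the same quantity per example and, by default, averages it over the batch. A model that assigns equal scores to all three classes gets a loss of ln(3), roughly 1.0986, which is a useful reference point: losses around that value mean the model is not yet doing better than guessing.

```python
import math


def cross_entropy(logits, label):
    """Cross-entropy for one example: -log(softmax(logits)[label])."""
    m = max(logits)
    # log-sum-exp trick for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


# Equal scores for all three classes: the model is effectively guessing.
uninformed = cross_entropy([0.0, 0.0, 0.0], label=2)
# High score on the correct class: low loss.
confident = cross_entropy([0.0, 0.0, 4.0], label=2)
# High score on a wrong class: high loss.
wrong = cross_entropy([0.0, 0.0, 4.0], label=0)

print(round(uninformed, 4), round(confident, 4), round(wrong, 4))
```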
Finally, the evaluation chunk contains two crucial functions. The evaluate_nli_model function calculates model accuracy on a given data set, while predict_nli handles predictions for new sentence pairs. These functions include proper model state management (evaluation mode) and device handling, ensuring consistent results whether running on CPU or GPU.
def evaluate_nli_model(model, dataloader):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids
            )

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total

# predict label for new examples
def predict_nli(model, tokenizer, premise, hypothesis):
    model.eval()
    encoding = tokenizer(
        premise,
        hypothesis,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt'
    )

    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    token_type_ids = encoding['token_type_ids'].to(device)

    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        prediction = torch.argmax(outputs, dim=1)

    label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
    return label_map[prediction.item()]
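The bookkeeping in these two functions can be followed on made-up numbers. The sketch below is a toy illustration in plain Python: the logits and labels are invented, and the helpers argmax and accuracy are my own. It accumulates correct predictions across batches in the same way the evaluation function does, and decodes a predicted class index into its string label using the same index-to-label mapping as in the prediction function.

```python
# Index-to-label mapping, same as in the prediction function.
ID_TO_LABEL = {0: "contradiction", 1: "neutral", 2: "entailment"}


def argmax(values):
    """Index of the largest value."""
    return max(range(len(values)), key=lambda i: values[i])


def accuracy(batches):
    """Share of correct predictions, accumulated over (logits, labels) batches."""
    correct = 0
    total = 0
    for logits_batch, labels_batch in batches:
        predicted = [argmax(row) for row in logits_batch]
        total += len(labels_batch)
        correct += sum(p == y for p, y in zip(predicted, labels_batch))
    return correct / total


# Invented logits (one row per sentence pair) and true labels.
batches = [
    ([[2.0, 0.1, -1.0], [0.0, 0.3, 0.2]], [0, 2]),  # first hit, second miss
    ([[-1.0, 0.0, 3.0]], [2]),                      # hit
]

print(accuracy(batches))
print(ID_TO_LABEL[argmax([-1.0, 0.0, 3.0])])
```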
Training and Evaluating on our own Data
Finally, we have set everything up so that we can work on our own data. First, we read in and prepare our actual data, a data set I downloaded from Kaggle: we read it from a TSV file and replace the numeric labels with strings to avoid confusion. We then take the first 2,000 rows for training and the next 1,000 rows as a test set. This structured approach to data organization is crucial for maintaining consistency throughout the training and evaluation process.
df = pd.read_csv('./files/pair-class_dev.tsv', sep='\t')
label_map = {
    0: 'entailment',
    1: 'neutral',
    2: 'contradiction'
}
df['label'] = df['label'].replace(label_map)
train = df[0:2000]
test = df[2000:3000].copy()
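The reason for going through string labels becomes clear when the two numeric codings are put side by side. In the file, 0 stands for entailment and 2 for contradiction (see the label_map above), whereas the NLIDataset class maps contradiction to 0 and entailment to 2. Since the class also accepts numbers from 0 to 2 as they are, passing the file's numeric labels directly would silently swap the two classes. The following sketch, with dictionary and function names of my own choosing, makes the translation explicit.

```python
# Numeric coding used in the downloaded file (taken from the label_map above).
FILE_ID_TO_NAME = {0: "entailment", 1: "neutral", 2: "contradiction"}

# Numeric coding used inside the NLIDataset class.
NAME_TO_MODEL_ID = {"contradiction": 0, "neutral": 1, "entailment": 2}


def file_id_to_model_id(file_id):
    """Translate a label ID from the file into the ID the model is trained on."""
    return NAME_TO_MODEL_ID[FILE_ID_TO_NAME[file_id]]


for file_id in (0, 1, 2):
    print(file_id, FILE_ID_TO_NAME[file_id], file_id_to_model_id(file_id))
```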
Then we create our tokenizer and model instances. The tokenizer is loaded from the pre-trained BERT model, and our custom NLI model is moved to the appropriate device (i.e., GPU, CPU, or, in my case, MPS) for computation. This setup forms the foundation for our training process.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForNLI()
model = model.to(device)
The training execution chunk brings everything together, creating our data loaders and running the training process. We specify two epochs for training, though this parameter can be adjusted based on the specific requirements of the task and the observed learning dynamics (additional epochs can help, but only as long as the validation accuracy keeps improving).
train_dataloader, val_dataloader = prepare_nli_data(train, tokenizer)
train_nli_model(model, train_dataloader, val_dataloader, epochs=2)
Epoch 1: 100%|##########| 100/100 [00:51<00:00, 1.96it/s, loss=0.659]
Epoch 1, Loss: 0.9956, Val Accuracy: 0.5800
Epoch 2: 66%|######6 | 66/100 [00:33<00:17, 1.98it/s, loss=0.628]
Epoch 2: 66%|######6 | 66/100 [00:33<00:17, 1.98it/s, loss=0.795]
Epoch 2: 67%|######7 | 67/100 [00:33<00:16, 1.99it/s, loss=0.795]
Epoch 2: 67%|######7 | 67/100 [00:34<00:16, 1.99it/s, loss=0.726]
Epoch 2: 68%|######8 | 68/100 [00:34<00:16, 2.00it/s, loss=0.726]
Epoch 2: 68%|######8 | 68/100 [00:34<00:16, 2.00it/s, loss=0.631]
Epoch 2: 69%|######9 | 69/100 [00:34<00:15, 2.00it/s, loss=0.631]
Epoch 2: 69%|######9 | 69/100 [00:35<00:15, 2.00it/s, loss=0.564]
Epoch 2: 70%|####### | 70/100 [00:35<00:15, 2.00it/s, loss=0.564]
Epoch 2: 70%|####### | 70/100 [00:35<00:15, 2.00it/s, loss=0.616]
Epoch 2: 71%|#######1 | 71/100 [00:35<00:14, 2.00it/s, loss=0.616]
Epoch 2: 71%|#######1 | 71/100 [00:36<00:14, 2.00it/s, loss=0.695]
Epoch 2: 72%|#######2 | 72/100 [00:36<00:14, 2.00it/s, loss=0.695]
Epoch 2: 72%|#######2 | 72/100 [00:36<00:14, 2.00it/s, loss=0.526]
Epoch 2: 73%|#######3 | 73/100 [00:36<00:13, 2.00it/s, loss=0.526]
Epoch 2: 73%|#######3 | 73/100 [00:37<00:13, 2.00it/s, loss=0.785]
Epoch 2: 74%|#######4 | 74/100 [00:37<00:13, 1.99it/s, loss=0.785]
Epoch 2: 74%|#######4 | 74/100 [00:37<00:13, 1.99it/s, loss=0.611]
Epoch 2: 75%|#######5 | 75/100 [00:37<00:12, 2.00it/s, loss=0.611]
Epoch 2: 75%|#######5 | 75/100 [00:38<00:12, 2.00it/s, loss=0.53]
Epoch 2: 76%|#######6 | 76/100 [00:38<00:12, 2.00it/s, loss=0.53]
Epoch 2: 76%|#######6 | 76/100 [00:38<00:12, 2.00it/s, loss=0.918]
Epoch 2: 77%|#######7 | 77/100 [00:38<00:11, 2.00it/s, loss=0.918]
Epoch 2: 77%|#######7 | 77/100 [00:39<00:11, 2.00it/s, loss=0.577]
Epoch 2: 78%|#######8 | 78/100 [00:39<00:10, 2.00it/s, loss=0.577]
Epoch 2: 78%|#######8 | 78/100 [00:39<00:10, 2.00it/s, loss=0.763]
Epoch 2: 79%|#######9 | 79/100 [00:39<00:10, 2.00it/s, loss=0.763]
Epoch 2: 79%|#######9 | 79/100 [00:40<00:10, 2.00it/s, loss=0.688]
Epoch 2: 80%|######## | 80/100 [00:40<00:09, 2.01it/s, loss=0.688]
Epoch 2: 80%|######## | 80/100 [00:40<00:09, 2.01it/s, loss=0.582]
Epoch 2: 81%|########1 | 81/100 [00:40<00:09, 2.00it/s, loss=0.582]
Epoch 2: 81%|########1 | 81/100 [00:41<00:09, 2.00it/s, loss=0.532]
Epoch 2: 82%|########2 | 82/100 [00:41<00:08, 2.01it/s, loss=0.532]
Epoch 2: 82%|########2 | 82/100 [00:41<00:08, 2.01it/s, loss=0.648]
Epoch 2: 83%|########2 | 83/100 [00:41<00:08, 2.00it/s, loss=0.648]
Epoch 2: 83%|########2 | 83/100 [00:42<00:08, 2.00it/s, loss=0.506]
Epoch 2: 84%|########4 | 84/100 [00:42<00:07, 2.00it/s, loss=0.506]
Epoch 2: 84%|########4 | 84/100 [00:42<00:07, 2.00it/s, loss=0.757]
Epoch 2: 85%|########5 | 85/100 [00:42<00:07, 2.00it/s, loss=0.757]
Epoch 2: 85%|########5 | 85/100 [00:43<00:07, 2.00it/s, loss=0.709]
Epoch 2: 86%|########6 | 86/100 [00:43<00:06, 2.00it/s, loss=0.709]
Epoch 2: 86%|########6 | 86/100 [00:43<00:06, 2.00it/s, loss=0.494]
Epoch 2: 87%|########7 | 87/100 [00:43<00:06, 2.00it/s, loss=0.494]
Epoch 2: 87%|########7 | 87/100 [00:44<00:06, 2.00it/s, loss=0.887]
Epoch 2: 88%|########8 | 88/100 [00:44<00:05, 2.00it/s, loss=0.887]
Epoch 2: 88%|########8 | 88/100 [00:44<00:05, 2.00it/s, loss=0.692]
Epoch 2: 89%|########9 | 89/100 [00:44<00:05, 2.00it/s, loss=0.692]
Epoch 2: 89%|########9 | 89/100 [00:45<00:05, 2.00it/s, loss=0.847]
Epoch 2: 90%|######### | 90/100 [00:45<00:05, 2.00it/s, loss=0.847]
Epoch 2: 90%|######### | 90/100 [00:45<00:05, 2.00it/s, loss=0.563]
Epoch 2: 91%|#########1| 91/100 [00:45<00:04, 2.00it/s, loss=0.563]
Epoch 2: 91%|#########1| 91/100 [00:46<00:04, 2.00it/s, loss=0.437]
Epoch 2: 92%|#########2| 92/100 [00:46<00:04, 1.99it/s, loss=0.437]
Epoch 2: 92%|#########2| 92/100 [00:46<00:04, 1.99it/s, loss=0.593]
Epoch 2: 93%|#########3| 93/100 [00:46<00:03, 1.96it/s, loss=0.593]
Epoch 2: 93%|#########3| 93/100 [00:47<00:03, 1.96it/s, loss=0.648]
Epoch 2: 94%|#########3| 94/100 [00:47<00:03, 1.97it/s, loss=0.648]
Epoch 2: 94%|#########3| 94/100 [00:47<00:03, 1.97it/s, loss=0.428]
Epoch 2: 95%|#########5| 95/100 [00:47<00:02, 1.97it/s, loss=0.428]
Epoch 2: 95%|#########5| 95/100 [00:48<00:02, 1.97it/s, loss=0.693]
Epoch 2: 96%|#########6| 96/100 [00:48<00:02, 1.98it/s, loss=0.693]
Epoch 2: 96%|#########6| 96/100 [00:48<00:02, 1.98it/s, loss=0.468]
Epoch 2: 97%|#########7| 97/100 [00:48<00:01, 1.98it/s, loss=0.468]
Epoch 2: 97%|#########7| 97/100 [00:49<00:01, 1.98it/s, loss=0.71]
Epoch 2: 98%|#########8| 98/100 [00:49<00:01, 1.99it/s, loss=0.71]
Epoch 2: 98%|#########8| 98/100 [00:49<00:01, 1.99it/s, loss=0.546]
Epoch 2: 99%|#########9| 99/100 [00:49<00:00, 2.00it/s, loss=0.546]
Epoch 2: 99%|#########9| 99/100 [00:50<00:00, 2.00it/s, loss=0.715]
Epoch 2: 100%|##########| 100/100 [00:50<00:00, 2.00it/s, loss=0.715]
Epoch 2: 100%|##########| 100/100 [00:50<00:00, 1.99it/s, loss=0.715]
Epoch 2, Loss: 0.7027, Val Accuracy: 0.7625
Once the training has finished, we can eyeball the results using a simple example. The example uses a clear semantic relationship between sleeping and resting to showcase the model’s capability to recognize entailment.
premise = "The cat is sleeping"
hypothesis = "The cat is resting."
prediction = predict_nli(model, tokenizer, premise, hypothesis)
print(f"Prediction: {prediction}")
Prediction: entailment
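A single example only tells us so much. As a minimal sketch, we could wrap the spot check in a small helper that runs several sentence pairs at once and records whether each prediction matches what we expect. The helper name `spot_check` and its `predict_fn` argument are hypothetical and not part of the chapter's code; the three sentence pairs are the ones from the chapter introduction.

```python
def spot_check(predict_fn, examples):
    """Run predict_fn on (premise, hypothesis, expected) triples.

    predict_fn is assumed to be any callable that takes a premise and a
    hypothesis and returns a label string (hypothetical interface).
    """
    results = []
    for premise, hypothesis, expected in examples:
        predicted = predict_fn(premise, hypothesis)
        results.append({
            "premise": premise,
            "hypothesis": hypothesis,
            "expected": expected,
            "predicted": predicted,
            "match": predicted == expected,
        })
    return results


# The three sentence pairs from the chapter introduction
examples = [
    ("The cat is sleeping on the couch", "The cat is resting", "entailment"),
    ("The cat is sleeping on the couch", "The cat is playing", "contradiction"),
    ("The cat is sleeping on the couch", "The cat has brown fur", "neutral"),
]

# With the trained model from this chapter, the call could look like this:
# results = spot_check(lambda p, h: predict_nli(model, tokenizer, p, h), examples)
# for r in results:
#     print(f"{r['hypothesis']}: {r['predicted']} (expected {r['expected']})")
```

Passing the prediction function in as an argument keeps the helper independent of the model object, so the same check can be reused after retraining or with a different checkpoint.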
In our final evaluation chunk, we perform comprehensive testing on our held-out test set. We process all test examples through our model and calculate various performance metrics using scikit-learn’s classification report. This gives us a detailed view of the model’s performance across different relationship types, including precision, recall, and F1-score for each class.
predictions = []
for idx, row in test.iterrows():
    pred = predict_nli(model, tokenizer, row['premise'], row['hypothesis'])
    predictions.append(pred)
test['predicted'] = predictions
print(classification_report(test['label'], test['predicted']))
               precision    recall  f1-score   support

contradiction       0.76      0.69      0.72       328
   entailment       0.77      0.86      0.81       342
      neutral       0.66      0.65      0.65       330

     accuracy                           0.73      1000
    macro avg       0.73      0.73      0.73      1000
 weighted avg       0.73      0.73      0.73      1000
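To make the numbers in the report less of a black box, here is a small sketch that computes per-class precision, recall, and F1 by hand on a toy example and then reproduces the two average rows from the per-class values reported above. The function `per_class_metrics` and the toy labels are made up for illustration; scikit-learn's own implementation handles more edge cases.

```python
from collections import Counter


def per_class_metrics(y_true, y_pred):
    """Compute precision, recall, F1, and support for each label."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for true, pred in zip(y_true, y_pred):
        if true == pred:
            tp[true] += 1   # correct prediction for this class
        else:
            fp[pred] += 1   # predicted class was wrongly assigned
            fn[true] += 1   # true class was missed
    metrics = {}
    for label in sorted(set(y_true) | set(y_pred)):
        predicted_total = tp[label] + fp[label]
        actual_total = tp[label] + fn[label]
        precision = tp[label] / predicted_total if predicted_total else 0.0
        recall = tp[label] / actual_total if actual_total else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        metrics[label] = {"precision": precision, "recall": recall,
                          "f1": f1, "support": actual_total}
    return metrics


# Toy labels (made up for illustration)
y_true = ["entailment", "entailment", "neutral", "neutral",
          "contradiction", "contradiction"]
y_pred = ["entailment", "entailment", "entailment", "neutral",
          "contradiction", "neutral"]
toy = per_class_metrics(y_true, y_pred)

# Per-class F1 and support as printed in the report above
reported = {
    "contradiction": {"f1": 0.72, "support": 328},
    "entailment": {"f1": 0.81, "support": 342},
    "neutral": {"f1": 0.65, "support": 330},
}
total = sum(v["support"] for v in reported.values())
# Macro average: every class counts equally
macro_f1 = sum(v["f1"] for v in reported.values()) / len(reported)
# Weighted average: every class counts in proportion to its support
weighted_f1 = sum(v["f1"] * v["support"] for v in reported.values()) / total

print(f"macro F1: {macro_f1:.2f}, weighted F1: {weighted_f1:.2f}")
```

Because the three classes have almost the same support (328, 342, and 330), the macro and weighted averages land on the same rounded value. With imbalanced classes the two can differ noticeably, and the macro average is usually the more honest summary of how the weaker classes perform.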