BERT Classifier using the Trainer API
In the previous example we wrote PyTorch code directly. Here we modify that code to use the Hugging Face Trainer API, because it is simpler.
Entire Code and Code explained
Below is the entire code. Below that we explain certain sections.
It is simpler because you only need to put the training parameters in a TrainingArguments object (from transformers import TrainingArguments, Trainer) and hand it to a Trainer. Then the training code just looks like these few lines:
# Training hyper-parameters. TrainingArgumentsWithMPSSupport overrides the
# device so training can use Apple's MPS backend.
# NOTE(review): recent transformers releases renamed `evaluation_strategy`
# to `eval_strategy` -- adjust this keyword to match the installed version.
args = TrainingArgumentsWithMPSSupport(
    output_dir="output",
    # Evaluate every `eval_steps` optimizer steps. With a small dataset,
    # training may end before step 500, in which case no evaluation (and no
    # early stopping) happens.
    evaluation_strategy="steps",
    eval_steps=500,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    seed=0,
    # Reload the best checkpoint when training ends; EarlyStoppingCallback
    # requires this to be enabled.
    load_best_model_at_end=True,
)

# Stop once the tracked metric (evaluation loss by default) has not
# improved for 3 consecutive evaluations.
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)
trainer.train()
Training Time and Hardware Requirements
As with the other example, this code is modified to use MPS (Metal Performance Shaders) on a Mac, which is Apple's interface to the GPU. We also trained on only a few records, because training on an entire Law Insider repository would take too much time on a laptop. Law Insider uses a large number of cloud servers and GPUs to do that.
Program Flow
The steps are to:
- Read an Avro data file downloaded from the Law Insider private repository. Law Insider support will copy that file to S3 or Azure for you. Each document in this data is labelled either as a contract or with some other label that is not a contract, such as agreement.
- The data is a label-feature dataset. We convert the text labels to numbers: 1 for contracts and 0 for any other type of document.
- Read the Avro data into a list, then create a Pandas DataFrame and give it the column names label and text.
- Use the BertTokenizer for bert-base-cased to tokenize each document into encoded tokens.
- Split the data into training and validation datasets.
- Train the model, computing metrics on the validation set.
- Read the same data and then run predictions on it. Print out the predictions.
Complete Code
import avro.schema
from avro.datafile import DataFileReader, DataFileWriter
from avro.io import DatumReader, DatumWriter
from bs4 import BeautifulSoup
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
import torch
from transformers import TrainingArguments, Trainer
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import EarlyStoppingCallback
from transformers import TrainingArguments, Trainer
class TrainingArgumentsWithMPSSupport(TrainingArguments):
    """TrainingArguments that trains on Apple's MPS backend when available.

    The device cannot be passed to TrainingArguments as a constructor
    argument, so this subclass overrides the ``device`` property instead.
    """

    @property
    def device(self) -> torch.device:
        """Return the MPS device if PyTorch can use it, else the default.

        Falling back to the parent's device choice keeps this class usable
        on machines without Apple silicon (e.g. CUDA or CPU-only hosts),
        where requesting "mps" would fail.
        """
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return super().device
class Dataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset over tokenizer output, for use with Trainer.

    Args:
        encodings: Mapping of model-input name (e.g. ``"input_ids"``) to a
            per-example sequence, as returned by the tokenizer.
        labels: Optional per-example labels (list, numpy array, Series...).
            Leave as ``None`` for unlabeled inference data.
    """

    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of tensors, plus ``"labels"`` if set."""
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # Compare with None explicitly: the truth value of a numpy array,
        # tensor or pandas Series with several elements is ambiguous and
        # raises, so `if self.labels:` breaks for those label containers.
        if self.labels is not None:
            item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Number of examples, taken from the ``"input_ids"`` column."""
        return len(self.encodings["input_ids"])
def html_to_text(html):
    """Return the visible text of an HTML document, whitespace-normalized.

    Text from adjacent elements is joined with a space so that, e.g.,
    ``<td>A</td><td>B</td>`` yields "A B" rather than the fused token "AB".
    Runs of whitespace are then collapsed to single spaces, which also
    strips leading and trailing whitespace.
    """
    soup = BeautifulSoup(html, 'html.parser')
    text = soup.get_text(separator=" ")
    return " ".join(text.split())
# Path to the Avro export from the Law Insider private repository.
# Placeholder -- replace with the real file path.
AVRO_FILE = "<avro_file>"

# Read every record, strip its HTML body to plain text, and pair it with a
# binary label: 1 if the record is labelled "contract", otherwise 0.
# The `with` block closes the file once all records have been read.
labels_with_text = []
with open(AVRO_FILE, "rb") as avro_file:
    avro_reader = DataFileReader(avro_file, DatumReader())
    for document in avro_reader:
        document_text = html_to_text(document['body'])
        label = 1 if document['labels'].get('contract') else 0
        labels_with_text.append((label, document_text))

data = pd.DataFrame(labels_with_text, columns=["label", "text"])
# Tokenizer and model come from the same checkpoint so the vocabulary and
# special tokens match what the model was pre-trained with.
model_name = "bert-base-cased"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)

X = list(data["text"])
y = list(data["label"])

# Hold out 20% of the documents for validation. A fixed random_state keeps
# the split identical across runs, in line with seed=0 in the training args.
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=0
)

# Pad each split to its longest text and truncate to 512 tokens.
X_train_tokenized = tokenizer(X_train, padding=True, truncation=True, max_length=512)
X_val_tokenized = tokenizer(X_val, padding=True, truncation=True, max_length=512)
def compute_metrics(p):
    """Compute classification metrics for the Trainer's evaluation loop.

    Args:
        p: The ``(predictions, label_ids)`` pair passed in by the Trainer;
            ``predictions`` holds one row of class scores per example.

    Returns:
        dict with float values for "accuracy", "precision", "recall" and
        "f1"; the last three are for the positive class (1 = contract).
    """
    pred, labels = p
    # The highest-scoring class in each row is the predicted label.
    pred = np.argmax(pred, axis=1)
    accuracy = accuracy_score(y_true=labels, y_pred=pred)
    # zero_division=0: when a metric is undefined (e.g. no contracts are
    # predicted yet, or the eval split contains none) report 0.0 explicitly
    # instead of emitting an UndefinedMetricWarning on every evaluation.
    recall = recall_score(y_true=labels, y_pred=pred, zero_division=0)
    precision = precision_score(y_true=labels, y_pred=pred, zero_division=0)
    f1 = f1_score(y_true=labels, y_pred=pred, zero_division=0)
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}
# Wrap the tokenized splits so the Trainer can index them.
train_dataset, val_dataset = (
    Dataset(X_train_tokenized, y_train),
    Dataset(X_val_tokenized, y_val),
)

# Training hyper-parameters. TrainingArgumentsWithMPSSupport overrides the
# device so training can use Apple's MPS backend.
# NOTE(review): recent transformers releases renamed `evaluation_strategy`
# to `eval_strategy` -- adjust this keyword to match the installed version.
args = TrainingArgumentsWithMPSSupport(
    output_dir="output",
    # Evaluate every `eval_steps` optimizer steps. With a small dataset,
    # training may end before step 500, in which case no evaluation (and no
    # early stopping) happens.
    evaluation_strategy="steps",
    eval_steps=500,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    seed=0,
    # Reload the best checkpoint when training ends; EarlyStoppingCallback
    # requires this to be enabled.
    load_best_model_at_end=True,
)

# Stop once the tracked metric (evaluation loss by default) has not
# improved for 3 consecutive evaluations.
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)
trainer.train()
# Predict on the same documents used for training (a demo only); point
# `test_data` at unseen documents for a meaningful check.
test_data = data
X_test = list(test_data["text"])
X_test_tokenized = tokenizer(X_test, padding=True, truncation=True, max_length=512)
# No labels: this dataset is used for inference only.
test_dataset = Dataset(X_test_tokenized)

# Reuse the configured trainer. A bare Trainer(model) would fall back to
# default TrainingArguments (a "tmp_trainer" output directory and no MPS
# device override).
prediction_output = trainer.predict(test_dataset)
# The highest-scoring class per document is the predicted label.
y_pred = np.argmax(prediction_output.predictions, axis=1)
print(y_pred)
Code Sections Explained
Here we subclass TrainingArguments so that we can set the torch.device to mps on a Mac. We do this because the device cannot be passed as an argument to TrainingArguments.
class TrainingArgumentsWithMPSSupport(TrainingArguments):
    """TrainingArguments that trains on Apple's MPS backend when available.

    The device cannot be passed to TrainingArguments as a constructor
    argument, so this subclass overrides the ``device`` property instead.
    """

    @property
    def device(self) -> torch.device:
        """Return the MPS device if PyTorch can use it, else the default.

        Falling back to the parent's device choice keeps this class usable
        on machines without Apple silicon (e.g. CUDA or CPU-only hosts),
        where requesting "mps" would fail.
        """
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return super().device
Trainer requires a PyTorch Dataset. As you can see, the Dataset class stores the encoded text and, optionally, the labels.
class Dataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset over tokenizer output, for use with Trainer.

    Args:
        encodings: Mapping of model-input name (e.g. ``"input_ids"``) to a
            per-example sequence, as returned by the tokenizer.
        labels: Optional per-example labels (list, numpy array, Series...).
            Leave as ``None`` for unlabeled inference data.
    """

    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of tensors, plus ``"labels"`` if set."""
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # Compare with None explicitly: the truth value of a numpy array,
        # tensor or pandas Series with several elements is ambiguous and
        # raises, so `if self.labels:` breaks for those label containers.
        if self.labels is not None:
            item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Number of examples, taken from the ``"input_ids"`` column."""
        return len(self.encodings["input_ids"])
Here we open the Avro file with a DataFileReader, which we iterate over, appending each record's label and text to a list.
# Path to the Avro export from the Law Insider private repository.
# Placeholder -- replace with the real file path.
AVRO_FILE = "<avro_file>"

labels_with_text = []
# The `with` block closes the file once all records have been read.
with open(AVRO_FILE, "rb") as avro_file:
    avro_reader = DataFileReader(avro_file, DatumReader())
    for document in avro_reader:
        document_text = html_to_text(document['body'])
        if document['labels'].get('contract'):
            labels_with_text.append((1, document_text))
        else:
            labels_with_text.append((0, document_text))
Since the text in document['body'] includes HTML, we strip out the markup using:
def html_to_text(html):
    """Return the visible text of an HTML document, whitespace-normalized.

    Text from adjacent elements is joined with a space so that, e.g.,
    ``<td>A</td><td>B</td>`` yields "A B" rather than the fused token "AB".
    Runs of whitespace are then collapsed to single spaces, which also
    strips leading and trailing whitespace.
    """
    soup = BeautifulSoup(html, 'html.parser')
    text = soup.get_text(separator=" ")
    return " ".join(text.split())
The Avro data file has already been labelled, since it is label-feature data. We turn those labels into integers: 1 (contract) and 0 (not a contract).
for document in avro_reader:
    document_text = html_to_text(document['body'])
    # 1 = labelled as a contract, 0 = any other label.
    label = 1 if document['labels'].get('contract') else 0
    labels_with_text.append((label, document_text))
Then we make that a Pandas DataFrame:
# One row per document: integer label (1 = contract, 0 = other) and text.
data = pd.DataFrame(
    labels_with_text,
    columns=["label", "text"],
)
We declare the model and tokenizer. We use a pre-trained BERT model. For background on that, read this. The tokens must be in the same format the model was trained on, which is why we use BertTokenizer.from_pretrained(model_name).
The model is bert-base-cased.
model_name = "bert-base-cased"
# Tokenizer and model come from the same checkpoint so the vocabulary and
# special tokens match what the model was pre-trained with.
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,  # binary task: 0 = not a contract, 1 = contract
)
We split the data into training and validation sets, tokenize each set, and then create the Datasets.
# Hold out 20% of the documents for validation. A fixed random_state keeps
# the split identical across runs, in line with seed=0 in the training args.
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=0
)
# Pad each split to its longest text and truncate to 512 tokens.
X_train_tokenized = tokenizer(X_train, padding=True, truncation=True, max_length=512)
X_val_tokenized = tokenizer(X_val, padding=True, truncation=True, max_length=512)
train_dataset = Dataset(X_train_tokenized, y_train)
val_dataset = Dataset(X_val_tokenized, y_val)
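We define a metrics function, which the Trainer calls at each evaluation to report accuracy, precision, recall and F1; without one, the Trainer reports only the evaluation loss. This example uses the sklearn.metrics package.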
def compute_metrics(p):
    """Compute classification metrics for the Trainer's evaluation loop.

    Args:
        p: The ``(predictions, label_ids)`` pair passed in by the Trainer;
            ``predictions`` holds one row of class scores per example.

    Returns:
        dict with float values for "accuracy", "precision", "recall" and
        "f1"; the last three are for the positive class (1 = contract).
    """
    pred, labels = p
    # The highest-scoring class in each row is the predicted label.
    pred = np.argmax(pred, axis=1)
    accuracy = accuracy_score(y_true=labels, y_pred=pred)
    # zero_division=0: when a metric is undefined (e.g. no contracts are
    # predicted yet, or the eval split contains none) report 0.0 explicitly
    # instead of emitting an UndefinedMetricWarning on every evaluation.
    recall = recall_score(y_true=labels, y_pred=pred, zero_division=0)
    precision = precision_score(y_true=labels, y_pred=pred, zero_division=0)
    f1 = f1_score(y_true=labels, y_pred=pred, zero_division=0)
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}
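Note that passing TrainingArguments to the Trainer is optional in general, since it only overrides the default training settings. Here we need it to select the MPS device and to enable load_best_model_at_end, which EarlyStoppingCallback requires.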
# Training hyper-parameters. TrainingArgumentsWithMPSSupport overrides the
# device so training can use Apple's MPS backend.
# NOTE(review): recent transformers releases renamed `evaluation_strategy`
# to `eval_strategy` -- adjust this keyword to match the installed version.
args = TrainingArgumentsWithMPSSupport(
    output_dir="output",
    # Evaluate every `eval_steps` optimizer steps. With a small dataset,
    # training may end before step 500, in which case no evaluation (and no
    # early stopping) happens.
    evaluation_strategy="steps",
    eval_steps=500,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    seed=0,
    # Reload the best checkpoint when training ends; EarlyStoppingCallback
    # requires this to be enabled.
    load_best_model_at_end=True,
)

# Stop once the tracked metric (evaluation loss by default) has not
# improved for 3 consecutive evaluations.
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)
trainer.train()
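Then we make some predictions. We feed the model text without labels, since the label is what we are predicting. For simplicity we reuse the same input, although we could have used other data, such as another contract. Because the model was trained on most of these documents, the predictions will look more accurate than they would on unseen data.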
# Predict on the same documents used for training (a demo only); point
# `test_data` at unseen documents for a meaningful check.
test_data = data
X_test = list(test_data["text"])
X_test_tokenized = tokenizer(X_test, padding=True, truncation=True, max_length=512)
# No labels: this dataset is used for inference only.
test_dataset = Dataset(X_test_tokenized)

# Reuse the configured trainer. A bare Trainer(model) would fall back to
# default TrainingArguments (a "tmp_trainer" output directory and no MPS
# device override).
prediction_output = trainer.predict(test_dataset)
# The highest-scoring class per document is the predicted label.
y_pred = np.argmax(prediction_output.predictions, axis=1)
print(y_pred)
Reference
Fine-tuning pretrained NLP models with Huggingface’s Trainer
Updated over 1 year ago