library(torch)
library(torchtransformers)
library(luz)
library(dlr)
library(dplyr)
library(yardstick)
Textual entailment is a common NLP task, and is included in the GLUE and SuperGLUE NLP benchmarks.
The task consists of two pieces of text, a premise and a hypothesis. For example, this is a premise/hypothesis pair from the MultiNLI dataset (MNLI, described in more detail below):
In this case, the premise entails the hypothesis. This means that the hypothesis follows from the premise.
In contrast, this is another premise/hypothesis pair from the same dataset:
In this case, the premise contradicts the hypothesis. The premise lists things that are displayed in the museum, while the hypothesis asserts that the museum is empty.
Finally, this is another premise/hypothesis pair from MNLI:
While Horus was an Egyptian god, the premise doesn’t mention that, so the premise neither entails nor contradicts the hypothesis. This pair is said to be neutral.
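To make the three labels concrete in the format we’ll use later (columns sentence1, sentence2, and gold_label), here is a small, invented set of pairs. These sentences are placeholders written for this illustration, not rows from MNLI:
# Invented premise/hypothesis pairs (not from MNLI), one per label,
# in the column layout used throughout this vignette.
label_illustration <- dplyr::tibble(
  sentence1 = c(
    "A dog is sleeping on the porch.",
    "A dog is sleeping on the porch.",
    "A dog is sleeping on the porch."
  ),
  sentence2 = c(
    "An animal is resting.",
    "The porch is empty.",
    "The dog belongs to the neighbors."
  ),
  gold_label = factor(
    c("entailment", "contradiction", "neutral"),
    levels = c("entailment", "neutral", "contradiction")
  )
)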
In this vignette, we’ll use the MNLI dataset to fine-tune a BERT model for an entailment task.
The Multi-Genre Natural Language Inference (MultiNLI or MNLI) corpus was described in A Broad-Coverage Challenge Corpus for Sentence Understanding through Inference (Williams et al., NAACL 2018). It includes 433k premise-hypothesis pairs, annotated with entailment information. The premises are divided into 10 genres. Five of the genres (“fiction”, “government”, “slate”, “telephone”, and “travel”) are included in the training dataset, and the other five genres (“facetoface”, “letters”, “nineeleven”, “oup”, and “verbatim”) are not.
The data are subdivided into five datasets: train, dev_matched, dev_mismatched, test_matched, and test_mismatched. The “matched” sets contain the same five genres as the training data, while the “mismatched” sets contain the five held-out genres.
The test sets are for scoring your model on Kaggle, so we’ll skip those. We’ll train our model using train.tsv, and test it using dev_matched.tsv and dev_mismatched.tsv.
# Set up a processor function for {dlr} to load the data.
process_mnli <- function(source_file) {
  dataset_names <- c(
    "train",
    "dev_matched",
    "dev_mismatched",
    "test_matched",
    "test_mismatched"
  )
  # Also make those the names so purrr uses them.
  names(dataset_names) <- dataset_names
  mnli_tibbles <- purrr::map(
    dataset_names,
    function(this_dataset) {
      # We specify column types to make sure things come in as we expect.
      column_spec <- dplyr::case_when(
        stringr::str_starts(this_dataset, "dev_") ~ "iicccccccccccccc",
        stringr::str_starts(this_dataset, "test_") ~ "iiiccccccc",
        TRUE ~ "iicccccccccc"
      )
      raw_tibble <- readr::read_tsv(
        unz(source_file, fs::path("MNLI", this_dataset, ext = "tsv")),
        col_types = column_spec,
        # There are a couple lines that screw up if we include a quote
        # character.
        quote = ""
      )
      # If there are labels, standardize them, to make sure the factor levels
      # are always the same.
      if ("gold_label" %in% colnames(raw_tibble)) {
        raw_tibble$gold_label <- factor(
          raw_tibble$gold_label,
          levels = c("entailment", "neutral", "contradiction")
        )
      }
      return(
        dplyr::select(
          raw_tibble,
          -index,
          -promptID,
          -pairID,
          -dplyr::ends_with("_parse"),
          -dplyr::starts_with("label")
        )
      )
    }
  )
  return(mnli_tibbles)
}
# By default downloading large files often fails. Increase the timeout.
old_timeout <- options(timeout = 1000)
data_url <- "https://dl.fbaipublicfiles.com/glue/data/MNLI.zip"
mnli_tibbles <- dlr::read_or_cache(
source_path = data_url,
appname = "torchtransformers",
process_f = process_mnli
)
# Restore the timeout.
options(old_timeout)
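Before building datasets, it’s worth a quick sanity check that the download and processing produced what we expect. For example:
# Quick sanity check: how many training pairs, and how are the labels distributed?
nrow(mnli_tibbles$train)
dplyr::count(mnli_tibbles$train, gold_label)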
We need to set these datasets up for use with {luz}. We can use dataset_bert_pretrained() to process the train, matched, and mismatched datasets.
train_ds <- dataset_bert_pretrained(
  x = dplyr::select(
    mnli_tibbles$train,
    sentence1,
    sentence2
  ),
  y = mnli_tibbles$train$gold_label
)

test_matched_ds <- dataset_bert_pretrained(
  x = dplyr::select(
    mnli_tibbles$dev_matched,
    sentence1,
    sentence2
  ),
  y = mnli_tibbles$dev_matched$gold_label
)

test_mismatched_ds <- dataset_bert_pretrained(
  x = dplyr::select(
    mnli_tibbles$dev_mismatched,
    sentence1,
    sentence2
  ),
  y = mnli_tibbles$dev_mismatched$gold_label
)
Note that we do not tokenize the data at this point. We’ll let the model trigger tokenization to make sure the data is in the format the model expects.
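As a quick check that the (not-yet-tokenized) datasets line up with the tibbles, we can compare their sizes. This assumes dataset_bert_pretrained() exposes the usual {torch} dataset length method:
# The datasets should have one element per premise/hypothesis pair.
length(train_ds) == nrow(mnli_tibbles$train)
length(test_matched_ds) == nrow(mnli_tibbles$dev_matched)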
We’ll construct a model based on BERT, with a final linear layer that maps the pooled BERT output to scores for the three labels.
entailment_classifier <- torch::nn_module(
  "entailment_classifier",
  initialize = function(bert_type = "bert_tiny_uncased") {
    embedding_size <- config_bert(bert_type, "embedding_size")
    self$bert <- model_bert_pretrained(bert_type)
    # After pooled bert output, do a final dense layer.
    self$linear <- torch::nn_linear(
      in_features = embedding_size,
      out_features = 3L # 3 possible labels
    )
  },
  forward = function(x) {
    output <- self$bert(x)
    # Take the output embeddings from the last layer.
    output <- output$output_embeddings
    output <- output[[length(output)]]
    # Take the [CLS] token embedding for classification.
    output <- output[, 1, ]
    # Apply the last dense layer to the pooled output.
    output <- self$linear(output)
    return(output)
  }
)
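Before training, it can be useful to instantiate the classifier once and count its parameters to get a sense of the model’s size. This sketch assumes the pretrained bert_tiny_uncased weights can be downloaded:
# Instantiate the module (this fetches the pretrained BERT weights)
# and add up the number of elements in each parameter tensor.
model <- entailment_classifier()
sum(purrr::map_dbl(model$parameters, function(p) p$numel()))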
We fit the model using {luz}. We only fit for one epoch as a proof of concept.
torch::torch_manual_seed(123456)
fitted <- entailment_classifier %>%
  luz::setup(
    loss = torch::nn_cross_entropy_loss(),
    optimizer = torch::optim_adam,
    metrics = list(
      luz::luz_metric_accuracy()
    )
  ) %>%
  fit(
    train_ds,
    epochs = 1,
    callbacks = list(
      luz_callback_bert_tokenize(
        submodel_name = "bert",
        n_tokens = 128L # We don't want the full 512 for this example.
      )
    ),
    valid_data = 0.1,
    dataloader_options = list(batch_size = 256L)
  )
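Fine-tuning takes a while even for one epoch, so it’s worth saving the fitted model if you plan to reuse it. A minimal sketch (the file name is arbitrary):
# Save the fitted model so it can be reloaded later without retraining.
luz::luz_save(fitted, "mnli_entailment.luz")
# fitted <- luz::luz_load("mnli_entailment.luz")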
We generate predictions for the two test datasets and measure the results.
predictions_matched <- fitted %>%
  predict(
    test_matched_ds,
    callbacks = list(
      luz_callback_bert_tokenize(
        submodel_name = "bert",
        n_tokens = 128L
      )
    )
  ) %>%
  torch::nnf_softmax(2) %>%
  torch::torch_argmax(2)

predictions_matched <- predictions_matched$to(device = "cpu") %>%
  torch::as_array()

dev_matched <- mnli_tibbles$dev_matched %>%
  dplyr::mutate(
    .pred = factor(
      predictions_matched, levels = 1:3, labels = levels(gold_label)
    )
  )
yardstick::accuracy(dev_matched, gold_label, .pred)
#> # A tibble: 1 × 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 accuracy multiclass 0.659
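Accuracy alone doesn’t show which labels the model confuses. Since dev_matched now holds both gold_label and .pred, we can also look at a confusion matrix:
# Cross-tabulate the true labels against the predictions.
yardstick::conf_mat(dev_matched, gold_label, .pred)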
predictions_mismatched <- fitted %>%
  predict(
    test_mismatched_ds,
    callbacks = list(
      luz_callback_bert_tokenize(
        submodel_name = "bert",
        n_tokens = 128L
      )
    )
  ) %>%
  torch::nnf_softmax(2) %>%
  torch::torch_argmax(2)

predictions_mismatched <- predictions_mismatched$to(device = "cpu") %>%
  torch::as_array()

dev_mismatched <- mnli_tibbles$dev_mismatched %>%
  dplyr::mutate(
    .pred = factor(
      predictions_mismatched, levels = 1:3, labels = levels(gold_label)
    )
  )
yardstick::accuracy(dev_mismatched, gold_label, .pred)
#> # A tibble: 1 × 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 accuracy multiclass 0.661
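To compare the two evaluation sets side by side, we can collect both accuracies in a single table:
# Bind the per-split accuracy estimates into one tibble.
dplyr::bind_rows(
  matched = yardstick::accuracy(dev_matched, gold_label, .pred),
  mismatched = yardstick::accuracy(dev_mismatched, gold_label, .pred),
  .id = "split"
)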
The published results for bert_tiny_uncased on these datasets are 0.72 and 0.73, so our results of 0.66 and 0.66 after a single epoch of fine-tuning are on track.