---
title: "Textual Entailment"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Textual Entailment}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  eval = FALSE
)
```

```{r setup}
library(torch)
library(torchtransformers)
library(luz)
library(dlr)
library(dplyr)
library(yardstick)
```

Textual entailment is a common NLP task, and is included in the [GLUE](https://gluebenchmark.com/) and [SuperGLUE](https://super.gluebenchmark.com/) NLP benchmarks.
The task consists of two pieces of text, a `premise` and a `hypothesis`.
For example, this is a premise/hypothesis pair from the [MultiNLI dataset](https://cims.nyu.edu/~sbowman/multinli/) (MNLI, described in more detail below):

- **Premise:** "Here you'll see a shrunken head, a two-headed goat, and a statue of Marilyn Monroe made of shredded money, among other curiosities."
- **Hypothesis:** "One of the curiosities is a two-headed goat."

In this case, the premise *entails* the hypothesis.
This means that the hypothesis follows from the premise.

In contrast, this is another premise/hypothesis pair from the same dataset:

- **Premise:** "There is also an archaeological museum that displays older relics, including examples of Mycenaean pottery."
- **Hypothesis:** "The museum is completely empty and doesn't have anything in it."

In this case, the premise *contradicts* the hypothesis.
The premise lists things that are displayed in the museum, while the hypothesis asserts that the museum is empty.

Finally, this is another premise/hypothesis pair from MNLI:

- **Premise:** "At the heart of the sanctuary, a small granite shrine once held the sacred barque of Horus himself."
- **Hypothesis:** "Horus is a god."

While Horus was an Egyptian god, the premise doesn't mention that, so the premise neither entails nor contradicts the hypothesis.
This pair is said to be **neutral.**

In this vignette, we'll use the MNLI dataset to fine-tune a BERT model for an entailment task.

## The MNLI Dataset

The Multi-Genre Natural Language Inference (MultiNLI or MNLI) corpus was described in [A Broad-Coverage Challenge Corpus for Sentence Understanding through Inference](https://aclanthology.org/N18-1101) (Williams et al., NAACL 2018).
It includes 433k premise-hypothesis pairs, annotated with entailment information.

The premises are divided into 10 genres.
Five of the genres ("fiction", "government", "slate", "telephone", and "travel") are included in the training dataset, and the other five genres ("facetoface", "letters", "nineeleven", "oup", and "verbatim") are not.

The data are subdivided into five datasets:

- train.tsv (392,702 observations from five genres)
- dev_matched.tsv (9,815 observations from the five training genres)
- dev_mismatched.tsv (9,832 observations from the other five genres)
- test_matched.tsv (9,796 observations from the five training genres, no labels)
- test_mismatched.tsv (9,847 observations from the other five genres, no labels)

The test sets are for scoring your model on Kaggle, so we'll skip those.
We'll train our model using `train.tsv`, and test it using `dev_matched.tsv` and `dev_mismatched.tsv`.

```{r download}
# Set up a processor function for {dlr} to load the data.
process_mnli <- function(source_file) {
  dataset_names <- c(
    "train",
    "dev_matched",
    "dev_mismatched",
    "test_matched",
    "test_mismatched"
  )
  # Also make those the names so purrr uses them.
  names(dataset_names) <- dataset_names

  mnli_tibbles <- purrr::map(
    dataset_names,
    function(this_dataset) {
      # We specify column types to make sure things come in as we expect.
      column_spec <- dplyr::case_when(
        stringr::str_starts(this_dataset, "dev_") ~ "iicccccccccccccc",
        stringr::str_starts(this_dataset, "test_") ~ "iiiccccccc",
        TRUE ~ "iicccccccccc"
      )
      raw_tibble <- readr::read_tsv(
        unz(source_file, fs::path("MNLI", this_dataset, ext = "tsv")),
        col_types = column_spec,
        # There are a couple lines that screw up if we include a quote
        # character.
        quote = ""
      )

      # If there are labels, standardize them, to make sure the factor levels
      # are always the same.
      if ("gold_label" %in% colnames(raw_tibble)) {
        raw_tibble$gold_label <- factor(
          raw_tibble$gold_label,
          levels = c("entailment", "neutral", "contradiction")
        )
      }
      return(
        dplyr::select(
          raw_tibble,
          -index,
          -promptID,
          -pairID,
          -dplyr::ends_with("_parse"),
          -dplyr::starts_with("label")
        )
      )
    }
  )
  return(mnli_tibbles)
}

# By default downloading large files often fails. Increase the timeout.
old_timeout <- options(timeout = 1000)

data_url <- "https://dl.fbaipublicfiles.com/glue/data/MNLI.zip"
mnli_tibbles <- dlr::read_or_cache(
  source_path = data_url,
  appname = "torchtransformers",
  process_f = process_mnli
)

# Restore the timeout.
options(old_timeout)
```

We need to set these datasets up for use with {luz}.
We can use `dataset_bert_pretrained()` to process the train, matched, and mismatched datasets.

```{r datasets}
train_ds <- dataset_bert_pretrained(
  x = dplyr::select(
    mnli_tibbles$train,
    sentence1,
    sentence2
  ),
  y = mnli_tibbles$train$gold_label
)

test_matched_ds <- dataset_bert_pretrained(
  x = dplyr::select(
    mnli_tibbles$dev_matched,
    sentence1,
    sentence2
  ),
  y = mnli_tibbles$dev_matched$gold_label
)

test_mismatched_ds <- dataset_bert_pretrained(
  x = dplyr::select(
    mnli_tibbles$dev_mismatched,
    sentence1,
    sentence2
  ),
  y = mnli_tibbles$dev_mismatched$gold_label
)
```

Note that we do *not* tokenize the data at this point.
We'll let the model trigger tokenization to make sure the data is in the format the model expects.
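
The explicit `levels` used when standardizing `gold_label` matter later, because predicted class indices are mapped back to labels by position. The following base R sketch illustrates why; the label vectors `labels_a` and `labels_b` are made up for illustration and are not taken from MNLI.

```r
# The level order used for gold_label in this vignette.
label_levels <- c("entailment", "neutral", "contradiction")

# Hypothetical label columns (illustration only, not MNLI data).
# The second one never contains "neutral".
labels_a <- c("neutral", "entailment", "contradiction", "entailment")
labels_b <- c("contradiction", "entailment")

# Without explicit levels, factor() sorts the values it happens to see,
# so the two factors disagree about which integer means which label.
default_a <- factor(labels_a)
default_b <- factor(labels_b)
levels(default_a) # "contradiction" "entailment" "neutral"
levels(default_b) # "contradiction" "entailment"

# With explicit levels, the integer codes are the same everywhere:
# entailment = 1, neutral = 2, contradiction = 3.
fixed_a <- factor(labels_a, levels = label_levels)
fixed_b <- factor(labels_b, levels = label_levels)
as.integer(fixed_a) # 2 1 3 1
as.integer(fixed_b) # 3 1
```

With the levels fixed this way, class index 1 always means "entailment" no matter which file a label came from.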
## Model

We'll construct a model based on BERT, with a linear layer to score the input on the three label dimensions.

```{r model}
entailment_classifier <- torch::nn_module(
  "entailment_classifier",
  initialize = function(bert_type = "bert_tiny_uncased") {
    embedding_size <- config_bert(bert_type, "embedding_size")
    self$bert <- model_bert_pretrained(bert_type)
    # After pooled bert output, do a final dense layer.
    self$linear <- torch::nn_linear(
      in_features = embedding_size,
      out_features = 3L # 3 possible labels
    )
  },
  forward = function(x) {
    output <- self$bert(x)

    # Take the output embeddings from the last layer.
    output <- output$output_embeddings
    output <- output[[length(output)]]
    # Take the [CLS] token embedding for classification.
    output <- output[ , 1, ]
    # Apply the last dense layer to the pooled output.
    output <- self$linear(output)
    return(output)
  }
)
```

We fit the model using `{luz}`.
We only fit for one epoch as a proof of concept.

```{r fit}
torch::torch_manual_seed(123456)
fitted <- entailment_classifier %>%
  luz::setup(
    loss = torch::nn_cross_entropy_loss(),
    optimizer = torch::optim_adam,
    metrics = list(
      luz::luz_metric_accuracy()
    )
  ) %>%
  fit(
    train_ds,
    epochs = 1,
    callbacks = list(
      luz_callback_bert_tokenize(
        submodel_name = "bert",
        n_tokens = 128L # We don't want the full 512 for this example.
      )
    ),
    valid_data = 0.1,
    dataloader_options = list(batch_size = 256L)
  )
```

## Results

We predict the two test datasets, and measure the results.

```{r predict}
predictions_matched <- fitted %>%
  predict(
    test_matched_ds,
    callbacks = list(
      luz_callback_bert_tokenize(
        submodel_name = "bert",
        n_tokens = 128L
      )
    )
  ) %>%
  torch::nnf_softmax(2) %>%
  torch::torch_argmax(2)

predictions_matched <- predictions_matched$to(device = "cpu") %>%
  torch::as_array()

dev_matched <- mnli_tibbles$dev_matched %>%
  dplyr::mutate(
    .pred = factor(
      predictions_matched,
      levels = 1:3,
      labels = levels(gold_label)
    )
  )

yardstick::accuracy(dev_matched, gold_label, .pred)
#> # A tibble: 1 × 3
#>   .metric  .estimator .estimate
#>   <chr>    <chr>          <dbl>
#> 1 accuracy multiclass     0.659

predictions_mismatched <- fitted %>%
  predict(
    test_mismatched_ds,
    callbacks = list(
      luz_callback_bert_tokenize(
        submodel_name = "bert",
        n_tokens = 128L
      )
    )
  ) %>%
  torch::nnf_softmax(2) %>%
  torch::torch_argmax(2)

predictions_mismatched <- predictions_mismatched$to(device = "cpu") %>%
  torch::as_array()

dev_mismatched <- mnli_tibbles$dev_mismatched %>%
  dplyr::mutate(
    .pred = factor(
      predictions_mismatched,
      levels = 1:3,
      labels = levels(gold_label)
    )
  )

yardstick::accuracy(dev_mismatched, gold_label, .pred)
#> # A tibble: 1 × 3
#>   .metric  .estimator .estimate
#>   <chr>    <chr>          <dbl>
#> 1 accuracy multiclass     0.661
```

The published results for bert_tiny_uncased on these datasets are 0.72 and 0.73, so our results of 0.66 and 0.66 after a single epoch of fine-tuning are on track.
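
The decoding step in the chunk above (softmax, then argmax, then mapping the index onto the factor levels) can be traced by hand on a tiny example. This is a minimal base R sketch rather than the {torch} code path: the `logits` matrix and `gold` labels are made-up values, and `softmax_rows()` is a hypothetical helper written only for this illustration.

```r
label_levels <- c("entailment", "neutral", "contradiction")

# Hypothetical model scores: one row per example, one column per label.
logits <- matrix(
  c(
     2.0, 0.5, -1.0,
    -0.3, 0.1,  1.2,
     0.2, 0.9,  0.4,
     1.5, 1.4, -2.0
  ),
  ncol = 3,
  byrow = TRUE
)

# Row-wise softmax (hypothetical helper). Subtracting the row maximum
# first keeps exp() numerically stable without changing the result.
softmax_rows <- function(m) {
  exp_m <- exp(m - apply(m, 1, max))
  exp_m / rowSums(exp_m)
}
probs <- softmax_rows(logits)

# Softmax preserves the ordering within each row, so the argmax of the
# probabilities equals the argmax of the raw scores.
pred_index <- max.col(probs, ties.method = "first")
stopifnot(identical(pred_index, max.col(logits, ties.method = "first")))

# Map the 1-based class index back to a label, as the vignette does.
pred <- factor(pred_index, levels = 1:3, labels = label_levels)

# Hypothetical gold labels, using the same level order.
gold <- factor(
  c("entailment", "contradiction", "neutral", "neutral"),
  levels = label_levels
)

# Accuracy is the fraction of matching labels: 3 of 4 here.
accuracy <- mean(pred == gold)
accuracy # 0.75
```

Because the argmax is unchanged by softmax, the softmax step matters only when the class probabilities themselves are of interest.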