Title: | Transformer Models in Torch |
---|---|
Description: | Work with transformer models in R using torch. |
Authors: | Jonathan Bratt [aut, cre] , Jon Harmon [aut] , Bedford Freeman & Worth Pub Grp LLC DBA Macmillan Learning [cph] |
Maintainer: | Jonathan Bratt <[email protected]> |
License: | Apache License (>= 2) |
Version: | 0.0.0.9600 |
Built: | 2024-10-31 21:09:39 UTC |
Source: | https://github.com/macmillancontentscience/torchtransformers |
Takes in an input tensor (e.g. sequence of token embeddings), applies an attention layer and layer-norms the result. Returns both the attention weights and the output embeddings.
attention_bert(embedding_size, n_head, attention_dropout = 0.1)
embedding_size |
Integer; the dimension of the embedding vectors. |
n_head |
Integer; the number of attention heads per layer. |
attention_dropout |
Numeric; the dropout probability to apply in attention. |
Inputs:
input: an embedding tensor for the input tokens (of shape batch_size x sequence_length x embedding_size).
optional mask: a mask tensor indicating which tokens should be attended to.
Output:
embeddings: the attention output embeddings (the same shape as the input embeddings).
weights: the attention weights.
emb_size <- 4L
seq_len <- 3L
n_head <- 2L
batch_size <- 2L

model <- attention_bert(
  embedding_size = emb_size,
  n_head = n_head
)

# get random values for input
input <- array(
  sample(
    -10:10,
    size = batch_size * seq_len * emb_size,
    replace = TRUE
  ) / 10,
  dim = c(batch_size, seq_len, emb_size)
)
input <- torch::torch_tensor(input)

model(input)
List the BERT models that are defined for this package.
available_berts()
Note that some of the models listed here are repeats registered under different names. For example, "bert_L2H128_uncased" and "bert_tiny_uncased" point to the same underlying weights. In general, models with the same hyperparameter values (accessible via config_bert()) are identical. The one exception is "bert_base_uncased" and "bert_L12H768_uncased": these have the same hyperparameters and training regime but are distinct models with different weights, presumably because of different random seeds. (One way to compare two models' configurations is sketched after the example below.)
A character vector of BERT types.
available_berts()
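The following sketch illustrates the point above by comparing the stored hyperparameters of two of the aliases mentioned; it assumes only the parameter names documented for config_bert() below.
params <- c("embedding_size", "n_layer", "n_head", "max_tokens", "vocab_size")
# Look up each hyperparameter for both model names.
tiny <- sapply(params, function(p) config_bert("bert_tiny_uncased", p))
l2h128 <- sapply(params, function(p) config_bert("bert_L2H128_uncased", p))
# The two aliases point to the same weights, so the configurations should match.
identical(tiny, l2h128)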
Several parameters define a BERT model. This function can be used to easily load them.
config_bert(
  bert_type,
  parameter = c("embedding_size", "n_layer", "n_head", "max_tokens",
    "vocab_size", "tokenizer_scheme")
)
bert_type |
Character scalar; the name of a known BERT model. |
parameter |
Character scalar; the desired parameter. |
Integer scalar; the value of that parameter for that model.
config_bert("bert_medium_uncased", "n_head")
config_bert("bert_medium_uncased", "n_head")
Prepare a dataset for BERT-like models.
dataset_bert(x, y = NULL, tokenizer = tokenize_bert, n_tokens = 128L)
x |
A data.frame with one or more character predictor columns. |
y |
A factor of outcomes, or a data.frame with a single factor column. Can be NULL (default). |
tokenizer |
A tokenization function (signature compatible with tokenize_bert()). |
n_tokens |
Integer scalar; the number of tokens expected for each example. |
An initialized torch::dataset().
initialize
Initialize this dataset. This method is called when the dataset is first created.
.getitem
Fetch an individual predictor (and, if available, the associated outcome). This function is called automatically by {luz} during the fitting process.
.length
Determine the length of the dataset (the number of rows of predictors). Generally superseded by instead calling length().
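A minimal usage sketch; the column names, example text, and outcome labels below are invented for illustration.
# Two character predictor columns; as documented for tokenize_bert(), parallel
# vectors are combined pairwise and separated with the sep_token.
df_x <- data.frame(
  premise = c("The cat sat on the mat.", "It rained all day."),
  hypothesis = c("There is a cat.", "The weather was dry."),
  stringsAsFactors = FALSE
)
df_y <- factor(c("entailment", "contradiction"))

ds <- dataset_bert(x = df_x, y = df_y, n_tokens = 32L)
length(ds)      # number of examples
ds$.getitem(1)  # tokenized predictors (and outcome) for the first example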
Prepare a dataset for pretrained BERT models.
dataset_bert_pretrained( x, y = NULL, bert_type = NULL, tokenizer_scheme = NULL, n_tokens = NULL )
x |
A data.frame with one or more character predictor columns, or a list, matrix, or character vector that can be coerced to such a data.frame. |
y |
A factor of outcomes, or a data.frame with a single factor column. Can be NULL (default). |
bert_type |
A bert_type from available_berts(). |
tokenizer_scheme |
A character scalar that indicates vocabulary + tokenizer. |
n_tokens |
An integer scalar indicating the number of tokens in the output. |
An initialized torch::dataset(). If it is not yet tokenized, the tokenize() method must be called before the dataset will be usable.
input_data (private)
The input predictors (x) standardized to a data.frame of character columns, and the outcome (y) standardized to a factor or NULL.
tokenizer_metadata (private)
A list indicating the tokenizer_scheme and n_tokens that have been or will be used to tokenize the predictors (x).
tokenized (private)
A single logical value indicating whether the data has been tokenized.
initialize
Initialize this dataset. This method is called when the dataset is first created.
tokenize
Tokenize this dataset.
untokenize
Remove any tokenization from this dataset.
.tokenize_for_model
Tokenize this dataset for a particular model. Generally superseded by instead calling luz_callback_bert_tokenize().
.getitem
Fetch an individual predictor (and, if available, the associated outcome). Generally superseded by instead calling .getbatch() (or by letting the luz modeling process fit automatically).
.getbatch
Fetch specific predictors (and, if available, the associated outcomes). This function is called automatically by {luz} during the fitting process.
.length
Determine the length of the dataset (the number of rows of predictors). Generally superseded by instead calling length().
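A construction sketch; the example text and labels are invented, and "bert_tiny_uncased" is one of the names returned by available_berts().
# Supplying bert_type ties the tokenizer scheme and n_tokens to the
# pretrained model's configuration.
ds <- dataset_bert_pretrained(
  x = c("First example text.", "Second example text."),
  y = factor(c("a", "b")),
  bert_type = "bert_tiny_uncased"
)
# In a luz workflow, luz_callback_bert_tokenize() triggers tokenization as
# needed; otherwise call the tokenize() method before using the dataset.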
There are three components which are added together to give the input embeddings in a BERT model: the embedding of the tokens themselves, the segment ("token type") embedding, and the position (token index) embedding. This function sets up the embedding layer for all three of these.
embeddings_bert(
  embedding_size,
  max_position_embeddings,
  vocab_size,
  token_type_vocab_size = 2L,
  hidden_dropout = 0.1
)
embedding_size |
Integer; the dimension of the embedding vectors. |
max_position_embeddings |
Integer; maximum number of tokens in each input sequence. |
vocab_size |
Integer; number of tokens in vocabulary. |
token_type_vocab_size |
Integer; number of input segments that the model will recognize. (Two for BERT models.) |
hidden_dropout |
Numeric; the dropout probability to apply to dense layers. |
With sequence_length <= max_position_embeddings:
Inputs:
token_ids: a (batch_size, sequence_length) matrix or tensor of integer token ids.
token_type_ids: a matrix or tensor of integer segment ids, the same shape as token_ids.
Output:
The combined token, segment, and position embeddings for each input token.
emb_size <- 3L
mpe <- 5L
vs <- 7L
n_inputs <- 2L

# get random "ids" for input
t_ids <- matrix(
  sample(2:vs, size = mpe * n_inputs, replace = TRUE),
  nrow = n_inputs,
  ncol = mpe
)
ttype_ids <- matrix(rep(1L, mpe * n_inputs), nrow = n_inputs, ncol = mpe)

model <- embeddings_bert(
  embedding_size = emb_size,
  max_position_embeddings = mpe,
  vocab_size = vs
)

model(
  torch::torch_tensor(t_ids),
  torch::torch_tensor(ttype_ids)
)
The torch R package uses the R standard of starting counts at 1. Many tokenizers use the Python standard of starting counts at 0. This function converts a list of token ids provided by such a tokenizer to torch-friendly values (by adding 1 to each id).
increment_list_index(list_of_integers)
The list of integers, with 1 added to each integer.
increment_list_index(list(1:5, 2:6, 3:7))
Data used in pretrained BERT models must be tokenized in the way the model expects. This luz_callback checks that the incoming data is tokenized properly, and triggers tokenization if necessary. This function should be passed to luz::fit.luz_module_generator() or luz::predict.luz_module_fitted() via the callbacks argument, not called directly.
luz_callback_bert_tokenize(
  submodel_name = NULL,
  n_tokens = NULL,
  verbose = TRUE
)
submodel_name |
An optional character scalar identifying a BERT model nested inside the main model being fitted. |
n_tokens |
An optional integer scalar indicating the number of tokens to which the data should be tokenized. If present it must be less than or equal to the maximum number of tokens supported by the model. |
verbose |
A logical scalar indicating whether the callback should report its progress (default TRUE). |
if (rlang::is_installed("luz")) {
  luz_callback_bert_tokenize()
  luz_callback_bert_tokenize(n_tokens = 32L)
}
BERT models are the family of transformer models popularized by Google's BERT (Bidirectional Encoder Representations from Transformers). They include any model with the same general structure.
model_bert(
  embedding_size,
  intermediate_size = 4 * embedding_size,
  n_layer,
  n_head,
  hidden_dropout = 0.1,
  attention_dropout = 0.1,
  max_position_embeddings,
  vocab_size,
  token_type_vocab_size = 2L
)
embedding_size |
Integer; the dimension of the embedding vectors. |
intermediate_size |
Integer; size of dense layers applied after attention mechanism. |
n_layer |
Integer; the number of attention layers. |
n_head |
Integer; the number of attention heads per layer. |
hidden_dropout |
Numeric; the dropout probability to apply to dense layers. |
attention_dropout |
Numeric; the dropout probability to apply in attention. |
max_position_embeddings |
Integer; maximum number of tokens in each input sequence. |
vocab_size |
Integer; number of tokens in vocabulary. |
token_type_vocab_size |
Integer; number of input segments that the model will recognize. (Two for BERT models.) |
Inputs:
With sequence_length <= max_position_embeddings:
token_ids: a (batch_size, sequence_length) matrix or tensor of integer token ids.
token_type_ids: a matrix or tensor of integer segment ids, the same shape as token_ids.
Output:
initial_embeddings: the combined token, segment, and position embeddings before any transformer layers.
output_embeddings: a list of output embeddings, one for each transformer layer.
attention_weights: a list of attention weights, one for each transformer layer.
emb_size <- 128L
mpe <- 512L
n_head <- 4L
n_layer <- 6L
vocab_size <- 30522L

model <- model_bert(
  embedding_size = emb_size,
  n_layer = n_layer,
  n_head = n_head,
  max_position_embeddings = mpe,
  vocab_size = vocab_size
)

n_inputs <- 2
n_token_max <- 128L

# get random "ids" for input
t_ids <- matrix(
  sample(
    2:vocab_size,
    size = n_token_max * n_inputs,
    replace = TRUE
  ),
  nrow = n_inputs,
  ncol = n_token_max
)
ttype_ids <- matrix(
  rep(1L, n_token_max * n_inputs),
  nrow = n_inputs,
  ncol = n_token_max
)

model(
  torch::torch_tensor(t_ids),
  torch::torch_tensor(ttype_ids)
)
Construct a BERT model (using model_bert()) and load pretrained weights.
model_bert_pretrained(bert_type = "bert_tiny_uncased", redownload = FALSE)
bert_type |
Character; which flavor of BERT to use. See available_berts() for the known models. |
redownload |
Logical; should the weights be downloaded fresh even if they're cached? |
The model with pretrained weights loaded.
initialize
Initialize this model. This method is called when the model is first created.
forward
Use this model. This method is called during training, and also during prediction. x is a list of torch::torch_tensor() values for token_ids and token_type_ids.
.get_tokenizer_metadata
Look up the tokenizer metadata for this model. This method is called automatically when luz_callback_bert_tokenize() validates that a dataset is tokenized properly for this model.
.load_weights
Load the pretrained weights for this model. This method is called automatically during initialization of this model.
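A brief usage sketch. Note that this downloads the pretrained weights the first time it is run (they are cached afterwards unless redownload = TRUE).
# Build a small BERT and load its pretrained weights.
model <- model_bert_pretrained("bert_tiny_uncased")
# As described above, the forward method expects a list of token_ids and
# token_type_ids tensors, e.g. built from tokenize_bert() output via
# torch::torch_tensor().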
Position embeddings are how BERT-like language models represent the order of input tokens. Each token gets a position embedding vector that is completely determined by its position index. Because these embeddings don't depend on the actual input, the layer is implemented simply as a matrix of learned weights.
position_embedding(embedding_size, max_position_embeddings)
embedding_size |
Integer; the dimension of the embedding vectors. |
max_position_embeddings |
Integer; maximum number of tokens in each input sequence. |
Inputs:
No input tensors. An optional seq_len_cap parameter limits the number of positions (tokens) considered.
Output:
The position embedding vectors, one per position (up to max_position_embeddings, or seq_len_cap if supplied).
emb_size <- 3L
mpe <- 2L

model <- position_embedding(
  embedding_size = emb_size,
  max_position_embeddings = mpe
)

model(seq_len_cap = 1)
model()
Takes in two tensors, an "input" and a "residual". Applies a linear projection to the input (changing its size to match the residual), applies dropout, adds the result to the residual, then applies layer normalization to the sum.
proj_add_norm(input_size, output_size, hidden_dropout = 0.1)
input_size |
Integer; the size of input tensor. |
output_size |
Integer; the size of output tensor (must match residual). |
hidden_dropout |
Numeric; dropout probability applied after projection. |
Inputs:
input: a tensor of size input_size (in its last dimension).
residual: a tensor of size output_size (in its last dimension).
Output:
A tensor of size output_size: the layer-normalized sum of the residual and the projected input.
in_size <- 4L
out_size <- 3L

model <- proj_add_norm(input_size = in_size, output_size = out_size)

input <- torch::torch_randn(in_size)
residual <- torch::torch_randn(out_size)

model(input, residual)
BERT-like models expect a matrix of tokens for each example. This function converts a list of equal-length vectors (such as a padded list of tokens) into such a matrix.
simplify_bert_token_list(token_list)
token_list |
A list of vectors. Each vector should have the same length. |
A matrix of tokens. Rows are text sequences, and columns are tokens.
simplify_bert_token_list(list(1:5, 2:6, 3:7))
To be used in a BERT-style model, text must be tokenized. In addition, text is optionally preceded by a cls_token, and segments are ended with a sep_token. Finally, each example must be padded with a pad_token, or truncated if necessary (preserving the wrapper tokens). Many use cases use a matrix of tokens x examples, which can be extracted directly with the simplify argument.
tokenize_bert(
  ...,
  n_tokens = 64L,
  increment_index = TRUE,
  pad_token = "[PAD]",
  cls_token = "[CLS]",
  sep_token = "[SEP]",
  tokenizer = wordpiece::wordpiece_tokenize,
  vocab = wordpiece.data::wordpiece_vocab(),
  tokenizer_options = NULL
)
... |
One or more character vectors or lists of character vectors. Currently we support a single character vector, two parallel character vectors, or a list of length-1 character vectors. If two vectors are supplied, they are combined pairwise and separated with the sep_token. |
n_tokens |
Integer scalar; the number of tokens expected for each example. |
increment_index |
Logical; if TRUE, add 1L to all token ids to convert from the Python-inspired 0-indexed standard to the torch 1-indexed standard. |
pad_token |
Character scalar; the token to use for padding. Must be present in the supplied vocabulary. |
cls_token |
Character scalar; the token to use at the start of each example. Must be present in the supplied vocabulary, or NULL to omit it. |
sep_token |
Character scalar; the token to use at the end of each segment within each example. Must be present in the supplied vocabulary, or NULL to omit it. |
tokenizer |
The tokenizer function to use to break up the text. It must have a vocab argument. |
vocab |
The vocabulary to use to tokenize the text. This vocabulary must include the pad_token, cls_token, and sep_token. |
tokenizer_options |
A named list of additional arguments to pass on to the tokenizer. |
An object of class "bert_tokens", which is a list containing a matrix of token ids, a matrix of token type ids, and a matrix of token names.
tokenize_bert( c("The first premise.", "The second premise."), c("The first hypothesis.", "The second hypothesis.") )
tokenize_bert( c("The first premise.", "The second premise."), c("The first hypothesis.", "The second hypothesis.") )
Build a BERT-style multi-layer attention-based transformer.
transformer_encoder_bert(
  embedding_size,
  intermediate_size = 4 * embedding_size,
  n_layer,
  n_head,
  hidden_dropout = 0.1,
  attention_dropout = 0.1
)
embedding_size |
Integer; the dimension of the embedding vectors. |
intermediate_size |
Integer; size of dense layers applied after attention mechanism. |
n_layer |
Integer; the number of attention layers. |
n_head |
Integer; the number of attention heads per layer. |
hidden_dropout |
Numeric; the dropout probability to apply to dense layers. |
attention_dropout |
Numeric; the dropout probability to apply in attention. |
Inputs:
With each input sequence of length sequence_length:
input: an embedding tensor for the input tokens.
optional mask: a mask tensor indicating which tokens should be attended to.
Output:
embeddings: a list of output embeddings, one for each transformer layer.
weights: a list of attention weights, one for each transformer layer.
emb_size <- 4L
seq_len <- 3L
n_head <- 2L
n_layer <- 5L
batch_size <- 2L

model <- transformer_encoder_bert(
  embedding_size = emb_size,
  n_head = n_head,
  n_layer = n_layer
)

# get random values for input
input <- array(
  sample(
    -10:10,
    size = batch_size * seq_len * emb_size,
    replace = TRUE
  ) / 10,
  dim = c(batch_size, seq_len, emb_size)
)
input <- torch::torch_tensor(input)

model(input)
Build a single layer of a BERT-style attention-based transformer.
transformer_encoder_single_bert(
  embedding_size,
  intermediate_size = 4 * embedding_size,
  n_head,
  hidden_dropout = 0.1,
  attention_dropout = 0.1
)
embedding_size |
Integer; the dimension of the embedding vectors. |
intermediate_size |
Integer; size of dense layers applied after attention mechanism. |
n_head |
Integer; the number of attention heads per layer. |
hidden_dropout |
Numeric; the dropout probability to apply to dense layers. |
attention_dropout |
Numeric; the dropout probability to apply in attention. |
Inputs:
input: an embedding tensor for the input tokens.
optional mask: a mask tensor indicating which tokens should be attended to.
Output:
embeddings: the output embeddings of this layer (the same shape as the input embeddings).
weights: the attention weights for this layer.
emb_size <- 4L
seq_len <- 3L
n_head <- 2L
batch_size <- 2L

model <- transformer_encoder_single_bert(
  embedding_size = emb_size,
  n_head = n_head
)

# get random values for input
input <- array(
  sample(
    -10:10,
    size = batch_size * seq_len * emb_size,
    replace = TRUE
  ) / 10,
  dim = c(batch_size, seq_len, emb_size)
)
input <- torch::torch_tensor(input)

model(input)