NeMo RETRO Model

The Retrieval-Enhanced Transformer (RETRO) model is an autoregressive language model that takes into account document chunks retrieved from a large corpus when making predictions. The RETRO model has a similar architecture to the GPT model, but it includes an encoder that encodes the retrieved context and cross-attention layers that integrate the context to improve the model’s output. Below is a simple diagram of the RETRO model architecture.

[Figure: RETRO model architecture]

For more detailed information on the model, please refer to the RETRO paper [nlp-retro1] by DeepMind. The NeMo RETRO Model is an open-source implementation of the paper, and it has the following differences and features compared to DeepMind's proposed implementation:

  1. The NeMo RETRO Model is built on top of NeMo Megatron code, allowing for efficient training of large language models in a cluster environment.

  2. The NeMo RETRO Model uses Faiss [nlp-retro2] as the KNN search library, which can be accelerated by GPUs.

  3. The NeMo RETRO Model uses rotary position embedding (RoPE) [nlp-retro4] for relative positional encoding.

  4. The NeMo RETRO Model uses SentenceTransformers [nlp-retro3] as the retriever encoder.

  5. The NeMo RETRO Model supports mu-Transfer [nlp-retro5], allowing for scalable training of the RETRO model via zero-shot hyperparameter transfer.

Quick start

The following steps demonstrate how to train and evaluate a NeMo RETRO model.

Data pre-processing

Step 1: Collect training data

The RETRO model uses two types of data: training data, which typically consists of 64-token chunks, and retrieval data, which typically consists of 128-token chunks. The training data is used to train the model, while the retrieval data is used to supplement the language model. It’s possible to use the same data for both training and retrieval, as long as duplicates are removed properly, as described below. Both types of data are stored in a loose JSON format, with each line containing a single text sample. For example:

The name of the text field in the JSON can be changed with the --json-key flag of the preprocessing script. The other metadata fields are optional and are not used in training.
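As an illustration of the loose JSON layout (one JSON object per line), here is a minimal standard-library sketch that writes and reads such a file. The sample records and the "id" field are hypothetical, and the `text_key` parameter only mirrors the role of the --json-key flag; this is not the preprocessing script's own code.

```python
import io
import json


def write_loose_json(samples, f):
    """Write one JSON object per line (loose JSON)."""
    for sample in samples:
        f.write(json.dumps(sample) + "\n")


def read_texts(f, text_key="text"):
    """Yield the text field of each non-empty line; other fields are ignored."""
    for line in f:
        line = line.strip()
        if line:
            yield json.loads(line)[text_key]


# Hypothetical samples: only "text" is used; "id" is optional metadata.
samples = [
    {"id": "doc-0", "text": "Retrieval helps language models."},
    {"id": "doc-1", "text": "Faiss searches dense vectors."},
]
buf = io.StringIO()
write_loose_json(samples, buf)
buf.seek(0)
texts = list(read_texts(buf))
print(texts)  # ['Retrieval helps language models.', 'Faiss searches dense vectors.']
```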

Step 2: Convert training data into memory map format

The loose JSON is then converted into a binary memory map (mmap) format, with a cached index file, for training and retrieval. Set the --dataset-impl flag to retmmap, which is the memory map format dedicated to the RETRO model.

An example script to prepare data for RETRO training is:

The RETRO model processes chunked documents using 64 tokens as the default chunk size. The RETRO memory map dataset adds padding tokens to the end of each document so that its length is a multiple of 64. The --need-pad-id argument adds a padding token to the tokenizer if it doesn't already have one. The --append-eod argument controls whether end-of-document tokens are added to the preprocessed data, and the --retrieval-db argument indicates whether to create a retrieval database from the preprocessed data. If --retrieval-db is used, an additional 64 padding tokens are added at the end of each document. The --chunk_size and --workers arguments control the size of the data chunks and the number of worker processes to use, respectively.
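The padding rule can be sketched as follows. This is an illustrative helper, not the NeMo implementation: the function name and the pad ID value are hypothetical. It pads a tokenized document to a multiple of the chunk size, optionally appends one extra chunk of padding for a retrieval database, and returns a recorded length that excludes that extra chunk, matching how the index stores token counts for retrieval data.

```python
def pad_document(tokens, pad_id, chunk_size=64, retrieval_db=False):
    """Return (padded_tokens, recorded_length) for one tokenized document.

    recorded_length includes the padding up to a multiple of chunk_size
    but excludes the extra chunk appended for a retrieval database.
    """
    remainder = len(tokens) % chunk_size
    pad_len = 0 if remainder == 0 else chunk_size - remainder
    padded = list(tokens) + [pad_id] * pad_len
    recorded_length = len(padded)
    if retrieval_db:
        # Retrieval data gets one additional chunk of padding at the end.
        padded += [pad_id] * chunk_size
    return padded, recorded_length


doc = list(range(1, 101))  # 100 hypothetical token IDs
train_tokens, train_len = pad_document(doc, pad_id=0)
db_tokens, db_len = pad_document(doc, pad_id=0, retrieval_db=True)
print(train_len, len(db_tokens), db_len)  # 128 192 128
```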

The RETRO memory map index file has the following format:

  1. 'MMIDRET\x00\x00' (header, 9 bytes)
  2. 1 (version, 8 bytes)
  3. dtype code [1] (1 byte)
  4. sentence count (8 bytes)
  5. chunk size (8 bytes)
  6. chunk count (8 bytes)
  7. retrieved db flag [2] (1 byte)
  8. number of tokens for each sentence (int32 array)
  9. start address of each sentence, in bytes (int64 array)
  10. start chunk ID of each sentence (int64 array)
  11. address of each chunk ID, in bytes (int64 array)

[1] dtype codes: 1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float, 7: np.double, 8: np.uint16.

[2] When building the indexed dataset, we pad each sentence with the tokenizer's pad_id so that its length is a multiple of chunk_size. The number of tokens recorded for each sentence includes these padded token IDs. For retrieval data, an extra chunk_size padding tokens are appended at the end of each sentence, and the retrieved_db flag is set to True. However, the number of tokens recorded for each sentence excludes this extra chunk_size padding.

The RETRO memory map binary data file has the following format:

  1. token ID arrays for sentences 0, 1, 2, … (array of dtype [3])

[3] np.uint16 if vocab_size < 65500, else np.int32.
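To make the index header layout concrete, here is a sketch that writes and reads it with Python's struct module. Little-endian byte order and the absence of alignment padding between fields are assumptions, the function names are hypothetical, and only the fixed-size header (not the arrays that follow it) is handled.

```python
import io
import struct

MAGIC = b"MMIDRET\x00\x00"  # 9-byte header
# Assumed layout: little-endian, fields packed back to back with no alignment:
# version (8 B), dtype code (1 B), sentence count (8 B), chunk size (8 B),
# chunk count (8 B), retrieved-db flag (1 B).
FIELDS = "<QBQQQB"

DTYPE_CODES = {1: "uint8", 2: "int8", 3: "int16", 4: "int32",
               5: "int64", 6: "float", 7: "double", 8: "uint16"}


def token_dtype_code(vocab_size):
    """Token IDs use uint16 (code 8) for small vocabularies, else int32 (code 4)."""
    return 8 if vocab_size < 65500 else 4


def write_header(f, dtype_code, sentence_count, chunk_size, chunk_count, retrieval_db):
    f.write(MAGIC)
    f.write(struct.pack(FIELDS, 1, dtype_code, sentence_count,
                        chunk_size, chunk_count, int(retrieval_db)))


def read_header(f):
    if f.read(len(MAGIC)) != MAGIC:
        raise ValueError("not a RETRO memory map index file")
    version, code, sentences, chunk_size, chunks, db = struct.unpack(
        FIELDS, f.read(struct.calcsize(FIELDS)))
    return {"version": version, "dtype": DTYPE_CODES[code],
            "sentence_count": sentences, "chunk_size": chunk_size,
            "chunk_count": chunks, "retrieval_db": bool(db)}


buf = io.BytesIO()
write_header(buf, token_dtype_code(50257), sentence_count=3,
             chunk_size=64, chunk_count=10, retrieval_db=True)
header_size = buf.tell()
buf.seek(0)
header = read_header(buf)
print(header_size, header["dtype"], header["retrieval_db"])  # 43 uint16 True
```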

Step 3: Create Faiss index for retrieval data

After creating the memory map binary and index files for the retrieval data, we can build a Faiss index that quickly finds the chunk IDs of the K nearest neighbors of a query embedding vector. Because the retrieval data is typically very large, we break this process down into three steps.

Step 3.1: Train the Faiss index structure

This step uses a subset of the retrieval data to train an empty Faiss index. An example script is:

This command builds an empty Faiss index using 2,000,000 training data points from pubmed_train_text_document. The all-mpnet-base-v2 sentence transformer model is used to encode the chunk tokens into embedding vectors. The index is saved in the result directory as pubmed_faiss_learn.index. The command uses 8 GPUs to train the Faiss index.
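Faiss itself is not reproduced here, but the idea of training an empty index can be illustrated. IVF-style Faiss indexes learn a set of centroids (one per inverted list) from a training sample before any data is added. The plain-Python sketch below assumes such an IVF-style index and runs k-means on a small sample of hypothetical embedding vectors; it is a stand-in for Faiss training, not the NeMo script's code, and the function names are hypothetical.

```python
def squared_l2(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def nearest_list(vec, centroids):
    """Index of the centroid (inverted list) closest to vec."""
    return min(range(len(centroids)), key=lambda c: squared_l2(vec, centroids[c]))


def train_centroids(sample, num_lists, iterations=10):
    """Learn num_lists centroids from a training sample with plain k-means."""
    if not 0 < num_lists <= len(sample):
        raise ValueError("need 1 <= num_lists <= len(sample)")
    # Deterministic initialization: evenly spaced samples from the training set.
    step = len(sample) // num_lists
    centroids = [list(sample[i * step]) for i in range(num_lists)]
    for _ in range(iterations):
        buckets = [[] for _ in centroids]
        for vec in sample:
            buckets[nearest_list(vec, centroids)].append(vec)
        for c, members in enumerate(buckets):
            if members:  # keep the previous centroid if its list is empty
                centroids[c] = [sum(dim) / len(members) for dim in zip(*members)]
    return centroids


# Hypothetical 2-D "embeddings" forming two well-separated groups.
sample = [(0, 0), (0, 1), (1, 0), (10, 10), (10, 11), (11, 10)]
centroids = train_centroids(sample, num_lists=2)
print([nearest_list(v, centroids) for v in [(0.5, 0.5), (10.5, 10.5)]])  # [0, 1]
```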

Step 3.2: Add retrieval data into sharding index

This step adds all the retrieval data to the empty Faiss index created in the previous step. An example script is:

This command breaks the retrieval data into total_shards shards and adds the data in the shard specified by shard_id. The result is saved to a file specified by output_file. In the example above, 10 sharding indexes are created.
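The sharding described above can be sketched as a split of the chunk ID range into total_shards contiguous, near-equal pieces, with each job handling the piece for its shard_id. The contiguous split and the helper name are assumptions made for illustration; the actual script may partition the data differently.

```python
def shard_range(num_chunks, total_shards, shard_id):
    """Return the half-open chunk ID range [start, end) handled by shard_id."""
    if not 0 <= shard_id < total_shards:
        raise ValueError("shard_id must be in [0, total_shards)")
    base, extra = divmod(num_chunks, total_shards)
    # The first `extra` shards receive one additional chunk each.
    start = shard_id * base + min(shard_id, extra)
    end = start + base + (1 if shard_id < extra else 0)
    return start, end


ranges = [shard_range(1003, 10, i) for i in range(10)]
print(ranges[0], ranges[-1])  # (0, 101) (903, 1003)
```

Because consecutive ranges share their boundaries, the shards together cover every chunk exactly once.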

Step 3.3: Merge the sharding indexes into final Faiss index

This step merges all the sharding indexes created in the previous step into the final Faiss index. An example script is:
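Conceptually, when every shard is added to a copy of the same trained index, merging amounts to concatenating the entries of each inverted list across shards. The sketch below models an IVF-style index as a mapping from list number to the chunk IDs stored in that list; this representation and the function name are assumptions for illustration, whereas real Faiss indexes also store encoded vectors and provide their own merge routines.

```python
from collections import defaultdict


def merge_shards(shards):
    """Merge per-shard inverted lists {list_no: [chunk_ids]} into one index."""
    merged = defaultdict(list)
    for shard in shards:
        for list_no, chunk_ids in shard.items():
            merged[list_no].extend(chunk_ids)
    all_ids = [cid for ids in merged.values() for cid in ids]
    # Shards hold disjoint chunk IDs, so a duplicate signals a bad merge.
    if len(all_ids) != len(set(all_ids)):
        raise ValueError("a chunk ID appears in more than one shard")
    return dict(merged)


# Two hypothetical shards built from the same trained centroids.
shard_0 = {0: [0, 3], 1: [1, 2]}
shard_1 = {0: [5], 1: [4], 2: [6, 7]}
merged = merge_shards([shard_0, shard_1])
print(merged)  # {0: [0, 3, 5], 1: [1, 2, 4], 2: [6, 7]}
```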

Step 4: Build KNN index

During training, it is inefficient to run a query to find the K-nearest neighbor chunk IDs for each training data point. This can be pre-calculated by building a KNN index before training. The KNN index maps the training data chunk IDs to the K-nearest neighbor chunk IDs in the retrieval data. As with building the Faiss index, this process is divided into two steps.

The KNN index file has the following format:

  1. 'KNNRETM\x00\x00' (header, 9 bytes)
  2. 1 (version, 8 bytes)
  3. K, the number of neighbors (8 bytes)
  4. number of chunks (8 bytes)
  5. map to K retrieval data chunk IDs, shape (number_chunks, K) (int64 array)
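A sketch of this KNN index layout using Python's struct module follows. The little-endian byte order, the row-by-row storage of signed int64 neighbor IDs, and the function names are assumptions for illustration.

```python
import io
import struct

KNN_MAGIC = b"KNNRETM\x00\x00"  # 9-byte header


def write_knn_index(f, neighbors):
    """Write a (number_chunks, K) table of retrieval chunk IDs."""
    k = len(neighbors[0]) if neighbors else 0
    f.write(KNN_MAGIC)
    # Assumed little-endian: version, K, and number of chunks, 8 bytes each.
    f.write(struct.pack("<QQQ", 1, k, len(neighbors)))
    for row in neighbors:
        if len(row) != k:
            raise ValueError("every chunk needs exactly K neighbor IDs")
        f.write(struct.pack(f"<{k}q", *row))  # K signed int64 values per row


def read_knn_index(f):
    """Read the table back; returns (K, rows)."""
    if f.read(len(KNN_MAGIC)) != KNN_MAGIC:
        raise ValueError("not a KNN index file")
    _version, k, num_chunks = struct.unpack("<QQQ", f.read(24))
    return k, [list(struct.unpack(f"<{k}q", f.read(8 * k))) for _ in range(num_chunks)]


neighbors = [[5, 9], [2, 7], [9, 1]]  # hypothetical neighbor IDs, K = 2
buf = io.BytesIO()
write_knn_index(buf, neighbors)
size = buf.tell()
buf.seek(0)
k, rows = read_knn_index(buf)
print(size, k, rows)  # 81 2 [[5, 9], [2, 7], [9, 1]]
```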

Step 4.1: Build KNN sharding index

The KNN index is built from the memory-mapped training data created in Step 2 and the Faiss index of the retrieval data built in Step 3.

An example script is:

In this example, the training data is broken into total_shards shards, and the KNN index is calculated for the shard specified by shard_id. The result is saved to a file specified by output_file. In the example above, 10 KNN sharding indexes are created.

If the training data and the retrieval data are the same, use the remove_duplicate flag to remove neighbors that come from the same document.
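To show what a KNN shard computes and what removing duplicates means, here is a brute-force stand-in for the Faiss search: for each training chunk in a shard, it ranks all retrieval chunks by squared L2 distance and keeps the K closest. When remove_duplicate is set, neighbors from the query chunk's own document are skipped. The exhaustive search, the distance metric, the chunk-to-document mapping, and the function name are assumptions; a real run queries the Faiss index instead.

```python
def knn_for_shard(query_vecs, query_docs, db_vecs, db_docs, k, remove_duplicate=False):
    """Brute-force K nearest retrieval chunk IDs for each query chunk in a shard."""
    results = []
    for q_vec, q_doc in zip(query_vecs, query_docs):
        # Rank every retrieval chunk ID by squared L2 distance to the query.
        ranked = sorted(
            range(len(db_vecs)),
            key=lambda i: sum((a - b) ** 2 for a, b in zip(q_vec, db_vecs[i])),
        )
        if remove_duplicate:
            # Skip neighbors that come from the query chunk's own document.
            ranked = [i for i in ranked if db_docs[i] != q_doc]
        results.append(ranked[:k])  # may hold fewer than k IDs if too few remain
    return results


# Hypothetical case where training and retrieval data are the same:
# chunks 0-1 belong to document 0, chunks 2-3 to document 1.
vecs = [[0.0], [0.1], [5.0], [5.2]]
docs = [0, 0, 1, 1]
plain = knn_for_shard(vecs, docs, vecs, docs, k=2)
dedup = knn_for_shard(vecs, docs, vecs, docs, k=2, remove_duplicate=True)
print(plain[0], dedup[0])  # [0, 1] [2, 3]
```

Without the flag, each chunk's nearest neighbor is itself, which is why deduplication matters when the two datasets coincide.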

Step 4.2: Merge KNN sharding index into final KNN index

An example script is:
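Assuming each KNN shard covers a contiguous block of training chunk IDs, merging reduces to concatenating the shard rows in shard_id order and checking that every shard used the same K. The in-memory (shard_id, rows) representation and the function name below are hypothetical simplifications of the shard files.

```python
def merge_knn_shards(shards):
    """Concatenate (shard_id, rows) pairs into one (number_chunks, K) table."""
    merged, k = [], None
    for shard_id, rows in sorted(shards, key=lambda s: s[0]):
        for row in rows:
            if k is None:
                k = len(row)
            elif len(row) != k:
                raise ValueError(f"shard {shard_id} uses a different K")
            merged.append(list(row))
    return k, merged


# Shards may finish in any order; sorting by shard_id restores chunk order.
k, table = merge_knn_shards([(1, [[7, 2], [4, 0]]), (0, [[3, 5]])])
print(k, table)  # 2 [[3, 5], [7, 2], [4, 0]]
```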

Train NeMo RETRO Model

Once the training data, retrieval data, KNN index, and Faiss index are prepared, we are ready to train the RETRO model. In the NeMo implementation, the RETRO model can be pre-trained with or without the mu-Transfer [nlp-retro5] feature. Both options are described below.

The following are some of the common parameters that can be configured for model pre-training:

  1. micro batch size used for training
  2. tensor model parallel size
  3. token sequence length
  4. chunk size used for retrieval
  5. total number of encoder layers
  6. total number of decoder layers
  7. layer numbers for cross-attention in the encoder
  8. layer numbers for chunked cross-attention in the decoder
  9. whether to add the absolute position encoding
  10. model hidden size
  11. model FFN hidden size, usually 4 * hidden_size
  12. number of attention heads
  13. standard deviation of the zero-mean normal distribution used for weight initialization
  14. dropout probability for the transformer hidden states
  15. dropout probability in the attention layer
  16. dropout probability in the feed-forward layer
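Several of these parameters constrain one another. The sketch below checks a few consistency rules before launching a run: the sequence length should be a whole number of retrieval chunks, the hidden size should divide evenly across attention heads, and the attention heads should divide evenly across tensor-parallel ranks. The dataclass and its field names are hypothetical, not NeMo configuration keys, and these rules are typical Megatron-style requirements assumed here rather than an exhaustive list.

```python
from dataclasses import dataclass


@dataclass
class RetroShape:
    # Hypothetical field names; they do not correspond to NeMo config keys.
    seq_length: int
    chunk_size: int
    hidden_size: int
    num_attention_heads: int
    tensor_parallel_size: int
    ffn_hidden_size: int = 0  # 0 means "use the usual 4 * hidden_size"

    def __post_init__(self):
        if self.ffn_hidden_size == 0:
            self.ffn_hidden_size = 4 * self.hidden_size


def check(cfg):
    """Return a list of human-readable problems (empty if none are found)."""
    problems = []
    if cfg.seq_length % cfg.chunk_size:
        problems.append("seq_length is not a multiple of chunk_size")
    if cfg.hidden_size % cfg.num_attention_heads:
        problems.append("hidden_size is not divisible by num_attention_heads")
    if cfg.num_attention_heads % cfg.tensor_parallel_size:
        problems.append("num_attention_heads is not divisible by tensor_parallel_size")
    return problems


ok = RetroShape(seq_length=2048, chunk_size=64, hidden_size=768,
                num_attention_heads=12, tensor_parallel_size=2)
bad = RetroShape(seq_length=2000, chunk_size=64, hidden_size=768,
                 num_attention_heads=12, tensor_parallel_size=5)
print(ok.ffn_hidden_size, check(ok), len(check(bad)))  # 3072 [] 2
```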

Option 1: Train the NeMo RETRO model without mu-Transfer

An example RETRO pre-training script is:

During training, launch TensorBoard to monitor progress:


Weights & Biases (WandB) is also supported. To enable it, add exp_manager.create_wandb_logger=True to the model training arguments.

After training, the model's .nemo file can be found in the result checkpoint directory.

Option 2: Train the NeMo RETRO model with mu-Transfer

The mu-Transfer paper [nlp-retro5] proposed a method for zero-shot transfer of hyperparameters tuned on a small model to a larger model. In the NeMo RETRO implementation, this is done in three steps.

Step 1. Find optimal hyperparameters for a small base model

Using the pre-training code from Option 1, find a set of optimal hyperparameters for a small base RETRO model, either manually or automatically. Because the model is small, this can be done quickly and cheaply.

Step 2. Calculate the shape file used to run mu-Transfer

The shape file determines which hyperparameters will be scaled up, so that the learning rate, weight scaling factors, and similar quantities can be adjusted accordingly.

Here is an example shape file calculation script:

In this example, the base_model refers to the small base model for which an optimal set of hyperparameters has been determined. The delta_model refers to a model with certain hyperparameters that have been scaled up or down. In this case, the hidden_size and ffn_hidden_size have been changed in the delta_model, allowing these two parameters to be scaled freely later.
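The role of the shape file can be illustrated with the learning-rate rule that the mu-Transfer paper gives for Adam. A dimension counts as width-like ("infinite") if it differs between the base and delta models; weights with two such dimensions (hidden matrices) have their learning rate divided by the fan-in width multiplier, while vector-like parameters keep the base learning rate. The sketch assumes PyTorch's (out_features, in_features) weight layout and ignores initialization scaling and the special treatment of output layers; the function names and shapes are hypothetical, and this is not the muadamw implementation.

```python
def infinite_dims(base_shape, delta_shape):
    """Dimensions that change between the base and delta models are width-like."""
    return [i for i, (b, d) in enumerate(zip(base_shape, delta_shape)) if b != d]


def mu_adam_lr(base_lr, base_shape, delta_shape, target_shape):
    """Adam learning rate for one parameter under the muP rule (sketch)."""
    inf = infinite_dims(base_shape, delta_shape)
    if len(inf) == 2:
        # Matrix-like weight: divide by the fan-in width multiplier.
        # Assumes (out_features, in_features) layout, so fan-in is dim 1.
        width_mult = target_shape[1] / base_shape[1]
        return base_lr / width_mult
    # Vector-like (one width-like dim) or finite parameters keep the base rate.
    return base_lr


base_lr = 1e-3
# Hypothetical FFN input projection: (ffn_hidden_size, hidden_size).
fc1 = mu_adam_lr(base_lr, (1024, 256), (2048, 512), (4096, 1024))
# Hypothetical layer-norm weight: (hidden_size,).
ln = mu_adam_lr(base_lr, (256,), (512,), (1024,))
print(fc1, ln)
```

Here the FFN weight is four times wider than in the base model along its fan-in, so its learning rate is a quarter of the base rate, while the layer-norm weight keeps the base rate.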

Step 3. Pre-train the mu-Transfer RETRO model

Once the shape file is created, we can start training a RETRO model. The model can be scaled up freely along the hyperparameters that were changed in the delta model and recorded in the shape file.

An example mu-Transfer pre-training script is:


We have chosen muadamw as the optimizer for the mu-Transfer method. Currently, only muadam and muadamw are supported.

As with pre-training in Option 1, the model's .nemo file can be found in the result checkpoint directory after training is complete.

Run NeMo RETRO Model Inference

Once the NeMo RETRO model has been trained, we can put it into inference mode and experiment with it. During inference, we are not limited to the static Faiss index that we built earlier for KNN queries; we can feed any external data to the model as retrieval context. The NeMo RETRO implementation supports a dynamic retrieval service that allows users to add, reset, and query new documents on the fly.
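To illustrate the add, reset, and query operations of a dynamic retrieval service, here is a toy in-memory version. It scores documents by cosine similarity of bag-of-words counts; this scoring is a stand-in assumption for the sentence-transformer embeddings and Faiss search used by NeMo RETRO, and the class and method names are hypothetical rather than the service's actual API.

```python
import math
import re
from collections import Counter


class ToyRetrievalService:
    """In-memory document store supporting add, reset, and top-k query."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.docs, self.vectors = [], []

    @staticmethod
    def _embed(text):
        # Bag-of-words counts as a crude stand-in for a learned embedding.
        return Counter(re.findall(r"[a-z0-9]+", text.lower()))

    def add(self, texts):
        for text in texts:
            self.docs.append(text)
            self.vectors.append(self._embed(text))

    def query(self, text, k=1):
        q = self._embed(text)
        q_norm = math.sqrt(sum(c * c for c in q.values()))

        def cosine(vec):
            dot = sum(q[t] * c for t, c in vec.items())
            norm = math.sqrt(sum(c * c for c in vec.values()))
            return dot / (q_norm * norm) if q_norm and norm else 0.0

        ranked = sorted(range(len(self.docs)),
                        key=lambda i: cosine(self.vectors[i]), reverse=True)
        return [self.docs[i] for i in ranked[:k]]


service = ToyRetrievalService()
service.add(["RETRO retrieves chunks from a corpus.", "Bananas are yellow."])
hits = service.query("which model retrieves chunks?")
print(hits)  # ['RETRO retrieves chunks from a corpus.']
service.reset()
after_reset = service.query("anything")
print(after_reset)  # []
```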

We have built a simple web client that makes it easy for users to play around with the model. Here is an example script to launch the server:

Set retro_model_file to the .nemo file generated in the pre-training step. After launching the server, copy the URL from the terminal into your browser. Log in with the specified username and password, and have fun experimenting with the RETRO model.



References

[nlp-retro1] Sebastian Borgeaud, Arthur Mensch, Jordan Hoffmann, Trevor Cai, Eliza Rutherford, Katie Millican, George van den Driessche, Jean-Baptiste Lespiau, Bogdan Damoc, Aidan Clark, and others. Improving language models by retrieving from trillions of tokens. arXiv preprint arXiv:2112.04426, 2021.

[nlp-retro2] Hervé Jégou, Matthijs Douze, Jeff Johnson, Lucas Hosseini, and Chengqi Deng. Faiss: similarity search and clustering of dense vectors library. Astrophysics Source Code Library, pages ascl–2210, 2022.

[nlp-retro3] Nils Reimers and Iryna Gurevych. Sentence-BERT: sentence embeddings using siamese BERT-networks. arXiv preprint arXiv:1908.10084, 2019.

[nlp-retro4] Jianlin Su, Yu Lu, Shengfeng Pan, Bo Wen, and Yunfeng Liu. RoFormer: enhanced transformer with rotary position embedding. arXiv preprint arXiv:2104.09864, 2021.

[nlp-retro5] Greg Yang, Edward J Hu, Igor Babuschkin, Szymon Sidor, Xiaodong Liu, David Farhi, Nick Ryder, Jakub Pachocki, Weizhu Chen, and Jianfeng Gao. Tensor Programs V: tuning large neural networks via zero-shot hyperparameter transfer. arXiv preprint arXiv:2203.03466, 2022.