Models#
The Duplex Speech-to-Speech (S2S) collection consists of several model architectures designed to enable conversational AI systems with speech understanding and generation capabilities. These models combine audio perception components with language models and speech synthesis to create end-to-end speech-to-speech systems.
Core Model Architectures#
The collection includes the following core model architectures:
SALM (Speech-Augmented Language Model)#
SALM is a speech-augmented language model that integrates an audio perception module with a pre-trained LLM. The model is designed to understand speech input and generate text responses. This is an implementation of the `SALM paper<https://arxiv.org/abs/2310.09424>`_.
Key components:
Audio Perception Module: Processes audio inputs and converts them into embeddings that can be understood by the LLM. This module leverages a pre-trained ASR model’s encoder.
LLM: A pre-trained large language model that receives the audio embeddings and generates appropriate text responses.
Audio Locator Tag: A special token that serves as a placeholder in the input text, which gets replaced with the audio embeddings during processing.
SALM is particularly useful for: * Speech-to-text applications where high-quality text generation is required * Speech understanding tasks that benefit from the contextual understanding of an LLM * Applications that need to handle mixed text and speech inputs
DuplexS2SModel#
The DuplexS2SModel extends the SALM architecture to enable full speech-to-speech capabilities, adding the ability to generate speech output.
Key components:
Audio Perception Module: Processes audio inputs into embeddings for the LLM.
LLM: Processes audio embeddings and generates text or audio token responses.
Audio Codec: Converts discrete audio tokens into speech waveforms.
Audio Head: Predicts audio tokens for speech generation.
Embed Audio Tokens: Embedding layers for audio token representation.
This model is particularly useful for: * Complete conversational agents that can listen and respond with speech * Voice assistants and interactive dialogue systems * Applications requiring natural-sounding spoken responses
DuplexS2SSpeechDecoderModel#
This model focuses on the speech generation aspect of the duplex system, optimizing the decoder for high-quality speech output.
Key components:
Audio Perception Module: Similar to other models, processes audio into embeddings.
Speech Decoder: A specialized component for generating high-quality speech from text or audio representations.
TransformerARSpeechDecoder: An autoregressive transformer-based decoder for speech generation.
This model is particularly useful for: * Applications focusing on high-quality speech synthesis * Systems where speech generation quality is the primary concern * Specialized voice output applications
Model Components#
Audio Perception Module#
The audio perception module is responsible for converting speech signals into embeddings that can be processed by language models. It typically consists of:
Preprocessor: Converts raw audio waveforms into spectral features
Encoder: Processes these features to create meaningful representations
Modality Adapter: Adapts the encoder outputs to be compatible with the LLM’s input space
Speech Generation#
Speech generation components convert text or token representations back into speech. The collection offers:
TransformerARSpeechDecoder: An autoregressive transformer-based speech decoder
Audio Codec Integration: Works with audio codecs to generate natural speech from discrete tokens
Implementation Details#
The DuplexS2SModel implementation contains several key methods that handle different aspects of the model’s functionality:
Model Initialization#
The constructor (__init__) initializes the following components:
Pretrained ASR encoder/perception: Loads a pretrained NeMo ASR model and adapts it for audio perception
LLM: Loads a pretrained language model using HuggingFace’s AutoModel
Audio codec: Loads a pretrained NeMo audio codec for speech generation
Token prediction heads: Adds separate heads for text token and audio token prediction
Forward Method#
The forward method:
Accepts input representations (sum of audio perception and text embedding outputs)
Runs an offline forward pass through the language model
Generates logits from both text and audio token prediction heads
Returns these logits for loss computation
Training Step#
The training_step method:
Builds input representations using the prepare_inputs method
Runs the forward method to get text and audio logits
Computes losses for both text and audio tokens
Logs training metrics (loss, learning rate, etc.)
Returns the loss for backpropagation
Prepare Inputs#
The prepare_inputs method:
Processes source audio through the perception module (with gradients enabled)
Processes target audio through the audio codec (with gradients disabled)
Truncates source/target audio and target text sequences to have the same sequence length, if needed
Performs additional truncation for sequence lengths to be divisible by tensor parallelism world size (if enabled)
Returns a dictionary with input and label tensors for training
def prepare_inputs(self, batch):
# Process source audio through perception
audio_embs, audio_emb_lens = self.perception(
input_signal=batch["source_audio"],
input_signal_length=batch["source_audio_lens"]
)
# Process target audio through codec (no gradients)
with torch.no_grad():
target_audio_tokens, target_audio_token_lens = self.audio_codec.encode(batch["target_audio"], batch["target_audio_lens"])
# Truncate sequences if needed
# ... (truncation logic)
# Embed text tokens and combine them with audio representations
# ... (text embedding logic)
# Return processed inputs and labels
return {
"audio_embeds": audio_embs,
"input_embeds": input_embeds,
"attention_mask": attention_mask,
"target_text_ids": target_text_ids,
"target_audio_ids": target_audio_ids,
}
Validation#
The validation process:
Clears GPU memory to avoid OOM issues
Loads a scoring ASR model into GPU for evaluation
Initializes metric aggregation for each dataset
Processes each validation dataset separately
Computes BLEU scores for text and ASR-decoded audio
Logs and clears metrics after validation is complete
Scaling Support#
The DuplexS2SModel includes a configure_model method that sets up model parallelism for large-scale training. This method:
Detects the parallelism strategy from the trainer’s device mesh
Applies Fully Sharded Data Parallel (FSDP) sharding to appropriate modules
Applies Tensor Parallelism (TP) and Sequence Parallelism (SP) when configured
Handles model-specific adaptations for different LLM architectures
The scaling approach supports:
Pure FSDP2 for distributing parameters across GPUs
Pure TP/SP for splitting computation across GPUs
2D parallelism combining both approaches
Pretrained Model Usage#
All models in the speechlm2 collection can be instantiated from pretrained checkpoints:
import nemo.collections.speechlm2 as slm
# Load SALM model
salm_model = slm.models.SALM.from_pretrained("path/to/checkpoint")
# Load DuplexS2SModel
duplex_model = slm.models.DuplexS2SModel.from_pretrained("path/to/checkpoint")
# Load DuplexS2SSpeechDecoderModel
decoder_model = slm.models.DuplexS2SSpeechDecoderModel.from_pretrained("path/to/checkpoint")
Model Configuration#
All models in this collection use a configuration-based approach, where a YAML configuration file specifies the model architecture, components, and training parameters. See the configurations documentation for details on these configuration files.
For information about scaling and training these models at scale, see the training and scaling documentation.