NeMo TTS API#
Model Classes#
MagpieTTS (Codec-based TTS)#
MagpieTTS is an end-to-end TTS model that generates audio codes from transcript and optional context (audio or text). It supports multiple architectures (e.g. multi-encoder context, decoder context) and can be used for standard, long-form, and streaming inference.
- class nemo.collections.tts.models.MagpieTTSModel(
- cfg: DictConfig,
- trainer: Trainer = None,
Bases: ModelPT
Magpie-TTS model base class, used for training a TTS model that can generate audio codes from a transcript and a context audio/text.
Supports multiple model types:
multi_encoder_context_tts: Transcript and context audio go to different encoders. The transcript encoding feeds into the layers given by cfg.model.transcript_decoder_layers, and the context encoding feeds into the layers given by context_decoder_layers. Also supports text context, which is encoded by the same encoder as context audio. Only one of context audio or context text is supported.
decoder_context_tts: Text goes into the encoder; context & target audio go to the decoder. Also supports text context. Supports a fixed-size context, so context_duration_min and context_duration_max are set to the same value (5 seconds). Text context, which is usually shorter than the number of codec frames in 5 seconds of audio, is padded to the max context duration in this model.
decoder_ce: Same as decoder_context_tts except there is a small neural network between the context tensors and the decoder input.
- add_special_tokens(
- codes,
- codes_len,
- bos_id,
- eos_id,
- num_bos_tokens=1,
- num_eos_tokens=1,
- check_frame_stacking_config_validity()[source]#
Check if the configuration is compatible with frame stacking.
- clear_forbidden_logits(
- logits: Tensor,
- forbid_audio_eos: bool = False,
Sets logits of forbidden tokens to -inf so they will never be sampled. Specifically, we forbid sampling of all special tokens except AUDIO_EOS which is allowed by default.
- Parameters:
logits – (B, C, num_audio_tokens_per_codebook)
forbid_audio_eos (bool, optional) – If True, also forbid AUDIO_EOS tokens from being sampled. Default: False.
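A minimal pure-Python sketch of this masking logic follows. The token ids and the list-based layout are illustrative assumptions; the real implementation operates on a (B, C, num_audio_tokens_per_codebook) tensor and takes the special-token ids from the model configuration.

```python
import math

# Hypothetical special-token ids, for illustration only.
SPECIAL_TOKENS = {"AUDIO_BOS": 1024, "AUDIO_EOS": 1025, "AUDIO_CONTEXT_EOS": 1026}

def clear_forbidden_logits(logits, forbid_audio_eos=False):
    """Set forbidden special-token logits to -inf so they are never sampled.

    `logits` is a list of per-codebook logit lists, standing in for the
    (B, C, num_audio_tokens_per_codebook) tensor of the real model.
    """
    # Forbid every special token except AUDIO_EOS, which stays allowed
    # by default; forbid it too when forbid_audio_eos is set.
    forbidden = {tok for name, tok in SPECIAL_TOKENS.items() if name != "AUDIO_EOS"}
    if forbid_audio_eos:
        forbidden.add(SPECIAL_TOKENS["AUDIO_EOS"])
    for codebook_logits in logits:
        for tok_id in forbidden:
            if tok_id < len(codebook_logits):
                codebook_logits[tok_id] = -math.inf
    return logits

logits = [[0.0] * 1027]
clear_forbidden_logits(logits, forbid_audio_eos=False)
```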
- compute_local_transformer_logits(
- dec_out,
- audio_codes_target,
- targets_offset_by_one=False,
Predicts the logits for all codebooks using the local transformer. Used in both autoregressive (AR) and MaskGit (MG) modes. This function is used in training and validation, not inference/sampling. The sequence layout differs slightly between AR and MG modes, as shown below (using an 8-codebook setup as an example):

| seq. index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
|---|---|---|---|---|---|---|---|---|---|
| input codebook | Magpie latent | 0 or MASK | 1 or MASK | 2 or MASK | 3 or MASK | 4 or MASK | 5 or MASK | 6 or MASK | 7 or MASK |
| AR target codebook | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | none |
| MG target codebook | none | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
- Parameters:
dec_out – (B, T’, E)
audio_codes_target – (B, C, T’)
targets_offset_by_one – bool, if False, the target for index 0 is codebook 0, for index 1 is codebook 1, etc. (autoregressive) if True, the target for index 1 is codebook 0, for index 2 is codebook 1, etc. (MaskGit)
- compute_loss(
- logits,
- audio_codes,
- audio_codes_lens,
- mask_tokens_mask=None,
- frame_stacking_factor=1,
Computes the audio codebook loss. Used by:
The main Magpie-TTS transformer
The local transformer, for both autoregressive and MaskGit methods
- Parameters:
logits – (B, T’, num_codebooks * num_tokens_per_codebook)
audio_codes – (B, C, T’)
audio_codes_lens – (B,)
mask_tokens_mask – (B, C, T’) True for tokens that were replaced with the MASK_TOKEN and should therefore be the only ones included in the loss computation (for MaskGit).
frame_stacking_factor – int, the stacking factor used in the model
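The role of mask_tokens_mask can be sketched in plain Python: average the negative log-likelihood only over positions selected by the mask. This is an illustrative stand-in, not the tensorized loss the model actually computes.

```python
import math

def masked_nll(logits, targets, include_mask):
    """Average negative log-likelihood over positions selected by
    `include_mask` (e.g. only MASK_TOKEN positions, as in MaskGit).

    logits: list of per-position logit lists.
    targets: list of target token ids, one per position.
    include_mask: list of bools selecting which positions count.
    """
    total, count = 0.0, 0
    for pos_logits, target, include in zip(logits, targets, include_mask):
        if not include:
            continue
        # Cross-entropy for one position: log-sum-exp minus target logit.
        log_z = math.log(sum(math.exp(l) for l in pos_logits))
        total += log_z - pos_logits[target]
        count += 1
    return total / max(count, 1)
```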
- construct_inference_prior(
- prior_epsilon,
- cross_attention_scores,
- text_lens,
- text_time_step_attended,
- attended_timestep_counter,
- unfinished_texts,
- finished_texts_counter,
- end_indices,
- lookahead_window_size,
- batch_size,
- construct_longform_inference_prior(
- prior_epsilon: float,
- cross_attention_scores: Tensor,
- text_lens: Tensor,
- text_time_step_attended: List[int],
- attended_timestep_counter: List[Dict[int, int]],
- unfinished_texts: Dict[int, bool],
- finished_texts_counter: Dict[int, int],
- end_indices: Dict[int, int],
- chunk_end_dict: Dict[int, int],
- batch_size: int,
- left_offset: List[int] | None = None,
Construct attention prior for longform inference with chunked text.
Builds a soft attention prior that guides the decoder to attend to appropriate text positions, preventing attention drift and encouraging monotonic progression.
- Parameters:
prior_epsilon – Base probability for non-targeted positions.
cross_attention_scores – Attention scores for shape/device inference. Shape: (effective_batch, text_length).
text_lens – Length of text for each batch item. Shape: (batch_size,).
text_time_step_attended – Most attended text position (absolute) per batch item.
attended_timestep_counter – Per-batch dicts tracking attention counts per timestep.
unfinished_texts – Updated in-place. True if text still being processed.
finished_texts_counter – Updated in-place. Counts consecutive near-end timesteps.
end_indices – Batch indices that have reached end-of-sequence.
chunk_end_dict – Batch indices that have reached chunk end.
batch_size – Number of items in the batch.
left_offset – Chunk offset for each batch item. Defaults to zeros.
- Returns:
Tuple of (attention_prior, unfinished_texts, finished_texts_counter).
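The shape of such a prior can be sketched as follows: a baseline epsilon everywhere, with extra mass in a window ahead of the currently attended text position, then normalized. The window shape and normalization here are illustrative assumptions, not the model's exact formula.

```python
def build_attention_prior(text_len, center, window, prior_epsilon=1e-3):
    """Soft prior over text positions: epsilon baseline everywhere,
    uniform boost inside a lookahead window starting at `center`,
    encouraging monotonic progression through the text."""
    prior = [prior_epsilon] * text_len
    for t in range(center, min(center + window, text_len)):
        prior[t] += 1.0
    total = sum(prior)
    return [p / total for p in prior]

prior = build_attention_prior(text_len=10, center=3, window=2)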
- create_longform_chunk_state(
- batch_size: int,
Create fresh state for longform inference over a batch.
This method creates a LongformChunkState dataclass instance that tracks mutable state across multiple calls to generate_long_form_speech() when processing long text in chunks.
The returned state object should be:
1. Created once per batch by the inference runner
2. Passed to each call of generate_long_form_speech()
3. Updated in-place during generation
- Parameters:
batch_size – Number of items in the batch.
- Returns:
LongformChunkState with initialized state for the batch.
Example
>>> chunk_state = model.create_longform_chunk_state(batch_size=4)
>>> for chunk in text_chunks:
...     output = model.generate_long_form_speech(batch, chunk_state, ...)
- detect_eos(
- audio_codes_multinomial,
- audio_codes_argmax,
- eos_detection_method,
Detects EOS in the predicted codes. Returns the index of the first frame within the frame stack that triggers EOS detection, or float('inf') if no EOS is found.
- Parameters:
audio_codes_multinomial – (num_codebooks, frame_stacking_factor) Multinomial samples
audio_codes_argmax – (num_codebooks, frame_stacking_factor) Argmax samples
eos_detection_method – EOS detection method
- Returns:
index (within the frame stack) of the first frame with EOS, or float(‘inf’) if no EOS is found
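The frame-scanning logic can be sketched in plain Python (the name and list layout are illustrative; the real method works on sampled code tensors and a configurable detection method):

```python
def first_eos_frame(codes, eos_id):
    """Return the index (within the frame stack) of the first frame in
    which any codebook predicted EOS, or float('inf') if none did.

    codes: list of codebooks, each a list of frame_stacking_factor codes.
    """
    num_frames = len(codes[0])
    for frame in range(num_frames):
        if any(codebook[frame] == eos_id for codebook in codes):
            return frame
    return float('inf')
```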
- do_tts(
- transcript: str,
- language: str = 'en',
- apply_TN: bool = False,
- use_cfg: bool = True,
- speaker_index: int | None = None,
Generate speech from raw text transcript.
This is a convenience method for single-utterance text-to-speech synthesis. For batch processing, use infer_batch directly. It only supports context injection via a baked context embedding; audio conditioning and text conditioning are not supported, so custom voice generation is not possible with this method.
- Parameters:
transcript – Raw text to synthesize.
language – Language code for text normalization and tokenization. Supported values depend on model’s tokenizer configuration. Common: “en” (English), “de” (German), “es” (Spanish), etc.
apply_TN – Whether to apply text normalization to the transcript. If True, uses nemo_text_processing for normalization.
use_cfg – Whether to use classifier-free guidance.
speaker_index – Speaker index for multi-speaker baked embeddings. Valid range: [0, num_baked_speakers - 1]. If None, uses speaker 0. Only applicable for models with baked context embeddings.
- Returns:
audio: Generated audio waveform. Shape: (1, T_audio). audio_len: Length of generated audio in samples. Shape: (1,).
- Return type:
Tuple of (audio, audio_len)
- Raises:
ValueError – If model does not have a baked context embedding.
ValueError – If speaker_index is out of valid range.
ImportError – If apply_TN=True but nemo_text_processing is not installed.
Example
>>> # If text does not need to be normalized
>>> audio, audio_len = model.do_tts("Hello, how are you today?")
>>>
>>> # If text needs to be normalized
>>> audio, audio_len = model.do_tts(
...     "Hello, how are you today?",
...     apply_TN=True,
... )
>>>
>>> # Use a specific speaker (for multi-speaker models)
>>> audio, audio_len = model.do_tts(
...     "Hello!", speaker_index=2
... )
- find_eos_frame_index(
- codes,
- eos_detection_method,
Checks for EOS in the predicted codes. Returns the index of the first frame within the frame stack that contains an EOS token across any codebook, or float('inf') if no EOS is found.
- Parameters:
codes – (num_codebooks, frame_stacking_factor)
- Returns:
index (within the frame stack) of the first frame with EOS, or float(‘inf’) if no EOS is found
- forward(
- dec_input_embedded,
- dec_input_mask,
- cond,
- cond_mask,
- attn_prior,
- multi_encoder_mapping,
Forward pass through the decoder transformer, followed by a linear projection to audio codebook logits.
- Parameters:
dec_input_embedded (torch.Tensor) – Embedded decoder input of shape (B, T, C).
dec_input_mask (torch.Tensor) – Boolean mask for decoder input of shape (B, T).
cond (torch.Tensor or List[torch.Tensor]) – Conditioning tensor(s) for cross-attention.
cond_mask (torch.Tensor or List[torch.Tensor]) – Mask(s) for conditioning tensor(s).
attn_prior (torch.Tensor or None) – Prior attention weights for cross-attention.
multi_encoder_mapping (List[Optional[int]] or None) – Per-layer mapping to conditioning inputs.
- Returns:
all_code_logits (torch.Tensor): Logits of shape (B, T’, num_codebooks * num_tokens_per_codebook).
attn_probabilities (list): Attention probabilities from each decoder layer.
dec_output (torch.Tensor): Raw decoder output of shape (B, T’, d_model).
moe_routing_info (list or None): None if MoE is disabled. If MoE is enabled, a list of dicts (one per layer) each containing:
’router_logits’ (torch.Tensor): Raw router logits (B, T, num_experts).
’router_probs’ (torch.Tensor): Router probabilities (B, T, num_experts).
’expert_indices’ (torch.Tensor): Selected expert indices (B, T, top_k).
- Return type:
tuple
- generate_long_form_speech(
- batch,
- chunk_state: LongformChunkState,
- end_of_text,
- beginning_of_text,
- use_cfg=True,
- use_local_transformer_for_inference=False,
- maskgit_n_steps=3,
- maskgit_noise_scale=0.0,
- maskgit_fixed_schedule=None,
- maskgit_dynamic_cfg_scale=False,
- maskgit_sampling_type=None,
Generates speech for long-form text by progressively shifting through text tokens.
This method processes long text inputs by generating a fixed number of audio tokens per text token, then shifting to the next text token. It maintains a sliding window over text and audio histories, tracking how many audio tokens were generated for each text position. The behaviour of this function is strongly dependent on self.inference_parameters.
- Parameters:
batch (dict) – Input batch containing ‘text’ and ‘text_lens’.
chunk_state (LongformChunkState) – Mutable state object tracking history across chunks. Created via model.create_longform_chunk_state() and updated in-place.
end_of_text (List[bool]) – Whether entire text has been provided for each batch item.
beginning_of_text (bool) – Whether this is the first chunk.
use_cfg (bool) – Whether to use classifier-free guidance.
use_local_transformer_for_inference (bool) – Whether to use local transformer for sampling.
maskgit_n_steps (int) – Number of MaskGit refinement steps.
maskgit_noise_scale (float) – Noise scale for MaskGit sampling.
maskgit_fixed_schedule (Optional[List[int]]) – Fixed schedule for MaskGit.
maskgit_dynamic_cfg_scale (bool) – Whether to use dynamic CFG scale in MaskGit.
maskgit_sampling_type (Optional[str]) – Type of MaskGit sampling.
- Returns:
Contains predicted_codes, predicted_codes_lens, and empty audio fields.
- Return type:
InferBatchOutput
- get_baked_context_embeddings_batch(
- batch_size: int,
- speaker_indices: int | List[int] | Tensor | None = None,
Get baked context embeddings for a batch, with per-element speaker selection.
- Parameters:
batch_size – Number of elements in the batch.
speaker_indices – Speaker selection. Can be:
- None: Use first speaker (index 0) for all batch elements
- int: Same speaker for all batch elements
- List[int] or Tensor: One speaker index per batch element (length must match batch_size)
- Returns:
embeddings: (B, T, D) tensor
lengths: (B,) tensor with embedding lengths per batch element
- Return type:
Tuple of (embeddings, lengths)
- Raises:
ValueError – If speaker_indices length doesn’t match batch_size or indices are out of range.
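The per-element speaker selection described above can be sketched in plain Python. The helper name is hypothetical; the real method additionally gathers the baked embedding rows for the resolved indices.

```python
def normalize_speaker_indices(batch_size, num_speakers, speaker_indices=None):
    """Expand `speaker_indices` into one index per batch element,
    validating length and range as the docstring describes."""
    if speaker_indices is None:
        indices = [0] * batch_size          # default: first speaker for all
    elif isinstance(speaker_indices, int):
        indices = [speaker_indices] * batch_size  # broadcast one speaker
    else:
        indices = list(speaker_indices)     # one index per element
        if len(indices) != batch_size:
            raise ValueError("speaker_indices length must match batch_size")
    for idx in indices:
        if not 0 <= idx < num_speakers:
            raise ValueError(f"speaker index {idx} out of range")
    return indices
```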
- get_cross_attention_scores(
- attn_probs,
- filter_layers=None,
Returns the cross attention probabilities for the last audio timestep
- get_inference_attention_plots(
- cross_attention_scores_all_timesteps,
- all_heads_cross_attn_scores_all_timesteps,
- text_lens,
- predicted_codes_lens,
- batch_size,
- compute_all_heads_attn_maps,
- last_attended_timestep,
- get_lhotse_dataloader(
- dataset_cfg,
- mode='train',
- get_most_attended_text_timestep(
- alignment_attention_scores,
- last_attended_timesteps,
- text_lens,
- lookahead_window_size,
- attended_timestep_counter,
- batch_size,
- left_offset=[],
Returns the most attended timestep for each batch item
This method identifies which text token is most attended to within a lookahead window, starting from the last attended timestep. It includes logic to detect attention sinks (tokens attended to excessively) and move past them. The method also tracks how many times each timestep has been attended.
- Parameters:
alignment_attention_scores (torch.Tensor) – Attention scores between audio and text tokens. Shape: (batch_size, text_length).
last_attended_timesteps (list) – List containing the last attended timestep for each batch item. The last element [-1] should be a list/tensor of length batch_size.
text_lens (torch.Tensor) – Length of text sequence for each batch item. Shape: (batch_size,).
lookahead_window_size (int) – Size of the forward-looking window to search for the next attended timestep. Determines how far ahead from the last attended timestep to look.
attended_timestep_counter (list) – List of dictionaries (one per batch item) tracking how many times each timestep has been attended. Used to detect attention sinks.
batch_size (int) – Number of items in the batch.
left_offset (list, optional) – List of offsets to adjust timestep indices for each batch item, used in longform inference when text is provided in chunks. Relevant only in longform generation.
- Returns:
- A tuple containing:
text_time_step_attended (list): List of integers, one per batch item, indicating the most attended text timestep for that item.
attended_timestep_counter (list): Updated counter tracking attendance frequency for each timestep across all batch items.
- Return type:
tuple
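A simplified single-item sketch of this windowed argmax with an attention-sink heuristic follows. The sink threshold and exact skipping rule are illustrative assumptions.

```python
def most_attended_timestep(scores, last_attended, text_len, window,
                           counter, sink_threshold=20):
    """Pick the most attended text position within a lookahead window,
    skipping past positions attended more than `sink_threshold` times
    (a crude attention-sink heuristic), and update the counter."""
    start = last_attended
    # Move past attention sinks: positions attended too many times.
    while counter.get(start, 0) > sink_threshold and start < text_len - 1:
        start += 1
    end = min(start + window, text_len)
    best = max(range(start, end), key=lambda t: scores[t])
    counter[best] = counter.get(best, 0) + 1
    return best

counter = {}
best = most_attended_timestep([0.1, 0.9, 0.3, 0.2], 0, 4, 3, counter)
```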
- property has_baked_context_embedding: bool#
Check if the model has a baked context embedding.
- Returns:
True if baked_context_embedding is set with valid dimensions.
- infer_batch(
- batch,
- use_cfg=False,
- return_cross_attn_probs=False,
- compute_all_heads_attn_maps=False,
- use_local_transformer_for_inference=False,
- maskgit_n_steps=3,
- maskgit_noise_scale=0.0,
- maskgit_fixed_schedule=None,
- maskgit_dynamic_cfg_scale=False,
- maskgit_sampling_type=None,
The behaviour of this function is strongly dependent on self.inference_parameters
- classmethod list_available_models() List[PretrainedModelInfo][source]#
Should list all pre-trained models available via the NVIDIA NGC cloud. Note: there is no check requiring model names and aliases to be unique. In the case of a collision, whichever model (or alias) is listed first in the returned list will be instantiated.
- Returns:
A list of PretrainedModelInfo entries
- load_state_dict(state_dict, strict=True)[source]#
Modify load_state_dict so that we don’t restore weights to _speaker_verification_model and _codec_model when strict is True. When strict is False, we can call pytorch’s load_state_dict. When strict is True, we loop through all parameters and rename them to enable loading.
_speaker_verification_model is only included in older checkpoints with the older single_encoder_sv_tts model_type that is no longer supported and can likely be removed in a future version.
Also handles loading baked context embeddings. If the checkpoint contains baked_speaker_embedding.weight, context_encoder weights are not expected to be present. The embedding is stored in flattened format (N, T*D) and reconstructed to (N, T, D) at inference time using stored T and D dimensions.
- local_transformer_sample_autoregressive(
- dec_output: Tensor,
- temperature: float = 0.7,
- topk: int = 80,
- unfinished_items: Dict[int, bool] = {},
- finished_items: Dict[int, bool] = {},
- use_cfg: bool = False,
- cfg_scale: float = 1.0,
- use_kv_cache: bool = True,
- forbid_audio_eos: bool = False,
Sample audio codes autoregressively across codebooks using the local transformer. Uses multinomial sampling with temperature, top-k, and classifier-free guidance (CFG).
The sequence is initialized with the primary decoder's hidden output as its only input and is gradually extended one codebook's code at a time, appending each sampled code to the input sequence for the next step. At the last step the sequence is num_codebooks long. If frame stacking is enabled, codes for all frames in the stack are sampled as one long sequence, and the final sequence is num_codebooks * frame_stacking_factor codes long.
Special handling:
- forbids special tokens (like AUDIO_BOS, AUDIO_CONTEXT_EOS, etc.) from being sampled
- forces / forbids EOS for finished / unfinished items respectively
- optionally, globally forbids audio EOS (useful early in the generation process)
- Parameters:
dec_output (torch.Tensor) – Decoder output tensor with shape (B, E) where B is batch size and E is primary decoder’s embedding dimension.
temperature (float, optional) – Sampling temperature.
topk (int, optional) – Number of top-probability tokens to consider in sampling.
unfinished_items (dict, optional) – Dictionary containing indices of batch items that we are confident have not completed generation. For these items, audio EOS sampling is forbidden.
finished_items (dict, optional) – Dictionary containing indices of batch items that we are confident are completed. For these items, audio EOS sampling is forced.
use_cfg (bool, optional) – Whether to use classifier-free guidance. If True, expects batch size to be doubled with conditional and unconditional outputs from the primary decoder.
cfg_scale (float, optional) – Scale factor for classifier-free guidance. Only used if use_cfg=True.
use_kv_cache (bool, optional) – Whether to use key-value caching in the transformer.
forbid_audio_eos (bool, optional) – Whether to globally forbid audio EOS for the entire batch.
- Returns:
Sampled audio codes with shape (B, num_codebooks, frame_stacking_factor), where B is batch size (or actual_batch_size if use_cfg=True).
- Return type:
torch.Tensor
- local_transformer_sample_maskgit(
- dec_output: Tensor,
- temperature: float = 0.7,
- topk: int = 80,
- unfinished_items: Dict[int, bool] = {},
- finished_items: Dict[int, bool] = {},
- use_cfg: bool = False,
- cfg_scale: float = 1.0,
- n_steps: int = 3,
- noise_scale: float = 0.0,
- fixed_schedule: List[int] | None = None,
- dynamic_cfg_scale: bool = False,
- sampling_type: str | None = None,
- forbid_audio_eos: bool = False,
Sample audio codes for the current timestep using MaskGit-like iterative prediction with the local transformer. If frame-stacking is enabled, the codes for all frames in the stack are sampled, treated as one long sequence.
The MaskGit process starts with all positions masked and iteratively unmasks the most confident positions over multiple steps. By “masked” we mean that a dedicated MASK token is used (as opposed to attention masking). The LT in this case is a non-causal transformer decoder. At each step the model predicts all positions at once. Of those predictions, a subset of the most confident previously-masked positions is kept and unmasked in the next step. The number of positions that are unmasked at each step is determined by the unmasking schedule. We support a cosine schedule and a fixed schedule provided by the user.
Uses multinomial sampling with temperature, top-k, and classifier-free guidance (CFG).
Special handling:
forbids special tokens (like AUDIO_BOS, AUDIO_CONTEXT_EOS, etc.) from being sampled
forces / forbids EOS for finished / unfinished items respectively
optionally, globally forbids audio EOS for all items in the batch. This is useful early in the generation process.
supports different unmasking methods, see sampling_type argument for details.
- Parameters:
dec_output (torch.Tensor) – Decoder output tensor with shape (B, E) where B is batch size and E is primary decoder’s embedding dimension.
temperature (float, optional) – Sampling temperature
topk (int, optional) – Number of top-probability tokens to consider in sampling.
unfinished_items (dict, optional) – Dictionary containing indices of batch items that we are confident have not completed generation. For these items, audio EOS sampling is forbidden.
finished_items (dict, optional) – Dictionary containing indices of batch items that we are confident are completed. For these items, audio EOS sampling is forced.
use_cfg (bool, optional) – Whether to use classifier-free guidance. If True, expects batch size to be doubled with conditional and unconditional outputs from the primary decoder.
cfg_scale (float, optional) – Scale factor for classifier-free guidance. Only used if use_cfg=True.
n_steps (int, optional) – Number of iterative refinement steps for MaskGit sampling.
noise_scale (float, optional) – Scale factor for noise to add to confidence scores during sampling (experimental).
fixed_schedule (list, optional) – Fixed schedule for number of tokens to unmask at each step. If None, uses cosine schedule.
dynamic_cfg_scale (bool, optional) – Whether to dynamically adjust CFG scale during sampling (experimental).
sampling_type (str, optional) –
Type of sampling strategy. Options are: [“default”, “causal”, “purity_causal”, “purity_default”].
Purity refers to “purity sampling” from https://arxiv.org/abs/2304.01515. If “purity” is not specified, confidence sampling is used as in the original MaskGit paper.
”default”/”causal”: Controls the order of unmasking across frames when frame-stacking is enabled. If “causal” is specified, frames are unmasked in causal order. “default” doesn’t impose any constraints on the unmasking order.
forbid_audio_eos (bool, optional) – Whether to globally forbid audio EOS for the entire batch.
- Returns:
Sampled audio codes with shape (B, num_codebooks, frame_stacking_factor)
- Return type:
torch.Tensor
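The cosine unmasking schedule mentioned above can be sketched as follows: the fraction of positions still masked after step i of n follows cos(pi/2 * i/n), and the per-step unmask count is the difference between consecutive mask counts. The rounding choice here is an illustrative assumption.

```python
import math

def cosine_unmask_schedule(total_tokens, n_steps):
    """Number of tokens to unmask at each MaskGit step under a cosine
    schedule; the counts sum to total_tokens."""
    masked = total_tokens
    schedule = []
    for i in range(1, n_steps + 1):
        # Fraction still masked after this step.
        remaining = int(round(total_tokens * math.cos(math.pi / 2 * i / n_steps)))
        schedule.append(masked - remaining)
        masked = remaining
    return schedule

schedule = cosine_unmask_schedule(total_tokens=8, n_steps=3)
```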
- log_attention_probs(
- attention_prob_matrix,
- audio_codes_lens,
- text_lens,
- prefix='',
- dec_context_size=0,
- log_val_audio_example(
- logits,
- target_audio_codes,
- audio_codes_lens,
- context_audio_codes=None,
- context_audio_codes_lens=None,
- maskgit_create_random_mask(codes)[source]#
Creates a mask where True indicates the positions that should be replaced with a MASK_TOKEN.
- property num_baked_speakers: int#
Return number of baked speakers.
- Returns:
0 if no baked embedding, N for embedding with N speakers.
- on_validation_epoch_end()[source]#
Default DataLoader for the validation set, which automatically supports multiple data loaders via multi_validation_epoch_end.
If multi dataset support is not required, override this method entirely in base class. In such a case, there is no need to implement multi_validation_epoch_end either.
Note
If more than one data loader exists, and they all provide val_loss, only the val_loss of the first data loader will be used by default. This default can be changed by passing the special key val_dl_idx: int inside the validation_ds config.
- Parameters:
outputs – Single or nested list of tensor outputs from one or more data loaders.
- Returns:
A dictionary containing the union of all items from individual data_loaders, along with merged logs from all data loaders.
- pad_audio_codes(audio_codes: Tensor)[source]#
Pads the time dimension of the audio codes to a multiple of the frame stacking factor.
- Parameters:
audio_codes (torch.Tensor) – (B, C, T)
frame_stacking_factor (int) – The factor that frames will be stacked by.
pad_token (int) – The token ID to pad with.
- Returns:
B, C, T_padded
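The padding arithmetic can be sketched in plain Python (list-based stand-in for the (B, C, T) tensor):

```python
def padded_length(t, frame_stacking_factor):
    """Smallest length >= t that is a multiple of the stacking factor."""
    return -(-t // frame_stacking_factor) * frame_stacking_factor

def pad_codes(codes, frame_stacking_factor, pad_token):
    """Pad each codebook's time dimension up to a multiple of the
    stacking factor. codes: list of codebooks (lists of token ids)."""
    t = len(codes[0])
    target = padded_length(t, frame_stacking_factor)
    return [cb + [pad_token] * (target - t) for cb in codes]
```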
- prepare_context_tensors(
- batch: Dict[str, Tensor],
Prepare all context tensors for the decoder.
This method orchestrates text encoding, context extraction, and model-type-specific processing to prepare tensors for decoder inference or training.
- Parameters:
batch – Dictionary containing:
- 'text': Text token IDs. Shape: (B, T_text).
- 'text_lens': Text lengths. Shape: (B,).
- 'context_audio_codes' or 'context_audio': Context audio.
- 'align_prior_matrix' (optional): Beta-binomial attention prior.
- 'speaker_indices' (optional): Speaker IDs for multi-speaker models.
- Text conditioning fields if use_text_conditioning_encoder is True.
- Returns:
ContextTensorsOutput dataclass containing all prepared tensors.
- Raises:
ValueError – If model_type is not supported.
- prepare_dummy_cond_for_cfg(
- cond,
- cond_mask,
- additional_decoder_input,
- additional_dec_mask,
- sample_codes_from_logits(
- all_code_logits_t: Tensor,
- temperature: float = 0.7,
- topk: int = 80,
- unfinished_items: Dict[int, bool] = {},
- finished_items: Dict[int, bool] = {},
- forbid_audio_eos: bool = False,
Sample codes for all codebooks at a given timestep. Uses multinomial sampling with temperature and top-k. If frame stacking is on (i.e. frame_stacking_factor > 1), this function will sample across the entire frame stack.
Special handling:
- forbids special tokens (like AUDIO_BOS, AUDIO_CONTEXT_EOS, etc.) from being sampled
- forces / forbids EOS for finished / unfinished items respectively
- optionally, globally forbids audio EOS (useful early in the generation process)
- Parameters:
all_code_logits_t (torch.Tensor) – Logits at a given timestep with shape (B, num_tokens_per_codebook * num_codebooks * frame_stacking_factor)
temperature (float, optional) – Sampling temperature
topk (int, optional) – Number of top-probability tokens to consider in sampling.
unfinished_items (dict, optional) – Dictionary containing indices of batch items that we are confident have not completed generation. For these items, audio EOS sampling is forbidden.
finished_items (dict, optional) – Dictionary containing indices of batch items that we are confident are completed. For these items, audio EOS sampling is forced.
forbid_audio_eos (bool, optional) – Whether to globally forbid audio EOS for the entire batch.
- Returns:
Sampled audio codes with shape (B, num_codebooks, frame_stacking_factor).
- Return type:
torch.Tensor
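Per-codebook temperature and top-k sampling can be sketched in plain Python (this is an illustrative stand-in for the tensorized multinomial sampling in the real method):

```python
import math
import random

def sample_top_k(logits, temperature=0.7, topk=2, rng=None):
    """Multinomial sampling restricted to the top-k logits, with
    temperature scaling, as done per codebook at each timestep."""
    rng = rng or random.Random(0)
    # Keep only the top-k token ids by logit value.
    ranked = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:topk]
    scaled = [logits[i] / temperature for i in ranked]
    # Softmax over the surviving logits (max-subtracted for stability).
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]
    # Draw one sample proportionally to the weights.
    r, acc = rng.random() * sum(weights), 0.0
    for idx, w in zip(ranked, weights):
        acc += w
        if r <= acc:
            return idx
    return ranked[-1]
```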
- setup_test_data(dataset_cfg)[source]#
(Optionally) Sets up the data loader to be used in testing.
- Parameters:
test_data_layer_config – test data layer parameters.
- setup_training_data(dataset_cfg)[source]#
Sets up the data loader to be used in training.
- Parameters:
train_data_layer_config – training data layer parameters.
- setup_validation_data(dataset_cfg)[source]#
Sets up the data loader to be used in validation.
- Parameters:
val_data_layer_config – validation data layer parameters.
- state_dict(
- destination=None,
- prefix='',
- keep_vars=False,
Only used for saving checkpoints. On save, we remove _speaker_verification_model and _codec_model from the checkpoint. The codec model is saved in a separate checkpoint.
_speaker_verification_model is only included in older checkpoints with the older single_encoder_sv_tts model_type that is no longer supported and can likely be removed in a future version.
If the model has a baked context embedding, the context_encoder weights are also excluded since they are no longer needed for inference.
- test_step(batch, batch_idx)[source]#
Operates on a single batch of data from the test set. In this step you’d normally generate examples or calculate anything of interest such as accuracy.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)
- Returns:
Tensor – The loss tensor.
dict – A dictionary. Can include any keys, but must include the key 'loss'.
None – Skip to the next batch.
# if you have one test dataloader:
def test_step(self, batch, batch_idx): ...

# if you have multiple test dataloaders:
def test_step(self, batch, batch_idx, dataloader_idx=0): ...
Examples:
# CASE 1: A single test dataset
def test_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'test_loss': loss, 'test_acc': test_acc})
If you pass in multiple test dataloaders, test_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
# CASE 2: multiple test dataloaders
def test_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...
Note
If you don’t need to test you don’t need to implement this method.
Note
When test_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of the test epoch, the model goes back to training mode and gradients are enabled.
- training_step(batch, batch_idx)[source]#
Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)
- Returns:
Tensor – The loss tensor.
dict – A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.
None – In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.
In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.
Example:
def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss
To use multiple optimizers, you can switch to ‘manual optimization’ and control their stepping:
def __init__(self):
    super().__init__()
    self.automatic_optimization = False

# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...
    opt1.step()

    # do training_step with decoder
    ...
    opt2.step()
Note
When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
- validation_step(batch, batch_idx)[source]#
Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.
- Parameters:
batch – The output of your data iterable, normally a
DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)
- Returns:
Tensor – The loss tensor.
dict – A dictionary. Can include any keys, but must include the key 'loss'.
None – Skip to the next batch.
# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...

# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...
Examples:
# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})
If you pass in multiple val dataloaders,
validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...
Note
If you don’t need to validate, you don’t need to implement this method.
Note
When the
validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
Mel-Spectrogram Generators#
- class nemo.collections.tts.models.FastPitchModel(
- cfg: DictConfig,
- trainer: Trainer = None,
Bases:
SpectrogramGenerator, Exportable, FastPitchAdapterModelMixin
FastPitch model (https://arxiv.org/abs/2006.06873) that is used to generate a mel spectrogram from text.
- configure_callbacks()[source]#
Configure model-specific callbacks. When the model gets attached, e.g., when
.fit() or .test() gets called, the list or a callback returned here will be merged with the list of callbacks passed to the Trainer’s callbacks argument. If a callback returned here has the same type as one or several callbacks already present in the Trainer’s callbacks list, it will take priority and replace them. In addition, Lightning will make sure ModelCheckpoint callbacks run last.
- Returns:
A callback or a list of callbacks which will extend the list of callbacks in the Trainer.
Example:
def configure_callbacks(self):
    early_stop = EarlyStopping(monitor="val_acc", mode="max")
    checkpoint = ModelCheckpoint(monitor="val_loss")
    return [early_stop, checkpoint]
- property disabled_deployment_input_names#
Implement this method to return a set of input names disabled for export
- forward(
- *,
- text,
- durs=None,
- pitch=None,
- energy=None,
- speaker=None,
- pace=1.0,
- spec=None,
- attn_prior=None,
- mel_lens=None,
- input_lens=None,
- reference_spec=None,
- reference_spec_lens=None,
Same as
torch.nn.Module.forward().
- Parameters:
*args – Whatever you decide to pass into the forward method.
**kwargs – Keyword arguments are also possible.
- Returns:
Your model’s output
- generate_spectrogram(
- tokens: tensor,
- speaker: int | None = None,
- pace: float = 1.0,
- reference_spec: tensor | None = None,
- reference_spec_lens: tensor | None = None,
Accepts a batch of text or text_tokens and returns a batch of spectrograms
- Parameters:
tokens – A torch tensor representing the text to be generated
- Returns:
spectrograms
- input_example(max_batch=1, max_dim=44)[source]#
Generates input examples for tracing etc.
- Returns:
A tuple of input examples.
- property input_types#
Define these to enable input neural type checks
- interpolate_speaker(
- original_speaker_1,
- original_speaker_2,
- weight_speaker_1,
- weight_speaker_2,
- new_speaker_id,
This method performs speaker interpolation between two original speakers the model is trained on.
- Inputs:
original_speaker_1: Integer speaker ID of the first existing speaker in the model
original_speaker_2: Integer speaker ID of the second existing speaker in the model
weight_speaker_1: Floating point weight associated with the first speaker in the weighted combination
weight_speaker_2: Floating point weight associated with the second speaker in the weighted combination
new_speaker_id: Integer speaker ID of the new interpolated speaker in the model
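The interpolation amounts to a weighted combination of the two speakers’ embedding vectors. A minimal plain-Python sketch of that combination (the helper name and embedding values are illustrative, not the model’s actual implementation):

```python
def interpolate_speaker_embedding(emb_1, emb_2, weight_1, weight_2):
    # hypothetical helper: element-wise weighted sum of two speaker embeddings
    return [weight_1 * a + weight_2 * b for a, b in zip(emb_1, emb_2)]


# e.g. a 50/50 blend of two toy 4-dimensional speaker embeddings
blended = interpolate_speaker_embedding([1.0, 0.0, 2.0, -2.0],
                                        [0.0, 1.0, 2.0, 2.0],
                                        0.5, 0.5)
```

The resulting vector would then be stored under new_speaker_id in the model’s speaker embedding table.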
- classmethod list_available_models() List[PretrainedModelInfo][source]#
This method returns a list of pre-trained models that can be instantiated directly from NVIDIA’s NGC cloud.
- Returns:
List of available pre-trained models.
- property output_types#
Define these to enable output neural type checks
- parse(
- str_input: str,
- normalize=True,
A helper function that accepts raw python strings and turns them into a tensor. The tensor should have 2 dimensions. The first is the batch, which should be of size 1. The second should represent time. The tensor should represent either tokenized or embedded text, depending on the model.
Note that some models have a normalize parameter in this function, which will apply the normalizer if it is available.
- property parser#
- property tb_logger#
Vocoders#
- class nemo.collections.tts.models.HifiGanModel(
- cfg: DictConfig,
- trainer: Trainer = None,
Bases:
Vocoder, Exportable
HiFi-GAN model (https://arxiv.org/abs/2010.05646) that is used to generate audio from a mel spectrogram.
- configure_callbacks()[source]#
Configure model-specific callbacks. When the model gets attached, e.g., when
.fit() or .test() gets called, the list or a callback returned here will be merged with the list of callbacks passed to the Trainer’s callbacks argument. If a callback returned here has the same type as one or several callbacks already present in the Trainer’s callbacks list, it will take priority and replace them. In addition, Lightning will make sure ModelCheckpoint callbacks run last.
- Returns:
A callback or a list of callbacks which will extend the list of callbacks in the Trainer.
Example:
def configure_callbacks(self):
    early_stop = EarlyStopping(monitor="val_acc", mode="max")
    checkpoint = ModelCheckpoint(monitor="val_loss")
    return [early_stop, checkpoint]
- convert_spectrogram_to_audio(
- spec: tensor,
Accepts a batch of spectrograms and returns a batch of audio.
- Parameters:
spec – [‘B’, ‘n_freqs’, ‘T’], A torch tensor representing the spectrograms to be vocoded.
- Returns:
audio
- forward(*, spec)[source]#
Runs the generator; for inputs and outputs, see input_types and output_types.
- forward_for_export(spec)[source]#
Runs the generator; for inputs and outputs, see input_types and output_types.
- input_example(max_batch=1, max_dim=256)[source]#
Generates input examples for tracing etc.
- Returns:
A tuple of input examples.
- property input_types#
Define these to enable input neural type checks
- classmethod list_available_models() Dict[str, str] | None[source]#
This method returns a list of pre-trained models that can be instantiated directly from NVIDIA’s NGC cloud.
- Returns:
List of available pre-trained models.
- load_state_dict(state_dict, strict=True)[source]#
Copy parameters and buffers from
state_dict into this module and its descendants.
If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
Warning
If assign is True, the optimizer must be created after the call to load_state_dict unless get_swap_module_params_on_conversion() is True.
- Parameters:
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
assign (bool, optional) – When set to False, the properties of the tensors in the current module are preserved, whereas setting it to True preserves the properties of the tensors in the state dict. The only exception is the requires_grad field of Parameter, for which the value from the module is preserved. Default: False
- Returns:
missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.
unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.
- Return type:
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
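The strict key matching described above can be illustrated with a small pure-Python sketch of how missing_keys and unexpected_keys are derived (the helper name is illustrative; PyTorch computes this internally):

```python
def split_key_mismatches(module_keys, state_dict_keys):
    # keys the module expects but the checkpoint lacks
    missing = [k for k in module_keys if k not in state_dict_keys]
    # keys the checkpoint provides but the module does not expect
    unexpected = [k for k in state_dict_keys if k not in module_keys]
    return missing, unexpected


missing, unexpected = split_key_mismatches(
    ["generator.weight", "generator.bias"],
    ["generator.weight", "discriminator.weight"],
)
```

With strict=True, any non-empty missing or unexpected list raises an error; with strict=False, both lists are returned to the caller.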
- property max_steps#
- on_train_epoch_end() None[source]#
Called in the training loop at the very end of the epoch.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the
LightningModule and access them in this hook:

class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss

    def on_train_epoch_end(self):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(self.training_step_outputs).mean()
        self.log("training_epoch_mean", epoch_mean)
        # free up the memory
        self.training_step_outputs.clear()
- property output_types#
Define these to enable output neural type checks
Codecs#
- class nemo.collections.tts.models.AudioCodecModel(
- cfg: DictConfig,
- trainer: Trainer = None,
Bases:
ModelPT
- property codebook_size#
- configure_callbacks()[source]#
Configure model-specific callbacks. When the model gets attached, e.g., when
.fit() or .test() gets called, the list or a callback returned here will be merged with the list of callbacks passed to the Trainer’s callbacks argument. If a callback returned here has the same type as one or several callbacks already present in the Trainer’s callbacks list, it will take priority and replace them. In addition, Lightning will make sure ModelCheckpoint callbacks run last.
- Returns:
A callback or a list of callbacks which will extend the list of callbacks in the Trainer.
Example:
def configure_callbacks(self):
    early_stop = EarlyStopping(monitor="val_acc", mode="max")
    checkpoint = ModelCheckpoint(monitor="val_loss")
    return [early_stop, checkpoint]
- decode(tokens, tokens_len) Tuple[Tensor, Tensor][source]#
Convert discrete tokens into a continuous time-domain signal.
- Parameters:
tokens – discrete tokens for each codebook for each time frame, shape (batch, number of codebooks, number of frames)
tokens_len – valid lengths, shape (batch,)
- Returns:
Decoded output audio in the time domain and its length in number of samples audio_len. Note that audio_len will be a multiple of self.samples_per_frame.
- decode_audio(inputs, input_len) Tuple[Tensor, Tensor][source]#
Apply decoder on the input. Note that the input is a non-quantized encoder output or a dequantized representation.
- Parameters:
inputs – encoded signal
input_len – valid length for each example in the batch
- Returns:
Decoded output audio in the time domain and its length in number of samples audio_len. Note that audio_len will be a multiple of self.samples_per_frame.
- dequantize(tokens, tokens_len) Tensor[source]#
Convert the discrete tokens into a continuous encoded representation.
- Parameters:
tokens – discrete tokens for each codebook for each time frame
tokens_len – valid length of each example in the batch
- Returns:
Continuous encoded representation of the discrete input representation.
- property disc_update_prob: float#
Probability of updating the discriminator.
- property dtype#
- encode(audio, audio_len, sample_rate) Tuple[Tensor, Tensor][source]#
Convert input time-domain audio signal into a discrete representation (tokens).
- Parameters:
audio – input time-domain signal, shape (batch, number of samples)
audio_len – valid length for each example in the batch, shape (batch size,)
sample_rate – sample rate of input audio (int)
- Returns:
Tokens for each codebook for each frame, shape (batch, number of codebooks, number of frames), and the corresponding valid lengths, shape (batch,)
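The token shape can be derived from the audio length and the codec’s frame size. A short sketch, under the assumption that the encoder emits one token per codebook per frame and partial final frames are rounded up (the helper name is illustrative):

```python
import math


def encoded_token_shape(batch_size, num_codebooks, audio_len, samples_per_frame):
    # one token per codebook per frame; a partial final frame is padded up
    num_frames = math.ceil(audio_len / samples_per_frame)
    return (batch_size, num_codebooks, num_frames)
```

For example, 22050 samples with 256 samples per frame and 8 codebooks would give a (batch, 8, 87) token tensor.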
- encode_audio(audio, audio_len, sample_rate) Tuple[Tensor, Tensor][source]#
Apply encoder on the input audio signal. Input will be padded with zeros so the last frame has full self.samples_per_frame samples.
- Parameters:
audio – input time-domain signal
audio_len – valid length for each example in the batch
sample_rate – sample rate of input audio (int)
- Returns:
Encoder output encoded and its length in number of frames encoded_len
- forward(audio, audio_len, sample_rate) Tuple[Tensor, Tensor][source]#
Apply encoder, quantizer, decoder on the input time-domain signal.
- Parameters:
audio – input time-domain signal
audio_len – valid length for each example in the batch
sample_rate – sample rate of input audio (int)
- Returns:
Reconstructed time-domain signal output_audio and its length in number of samples output_audio_len.
- classmethod list_available_models() List[PretrainedModelInfo][source]#
Should list all pre-trained models available via NVIDIA NGC cloud. Note: There is no check that requires model names and aliases to be unique. In the case of a collision, whatever model (or alias) is listed first in the returned list will be instantiated.
- Returns:
A list of PretrainedModelInfo entries
- load_state_dict(state_dict, strict=True)[source]#
Copy parameters and buffers from
state_dict into this module and its descendants.
If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
Warning
If assign is True, the optimizer must be created after the call to load_state_dict unless get_swap_module_params_on_conversion() is True.
- Parameters:
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
assign (bool, optional) – When set to False, the properties of the tensors in the current module are preserved, whereas setting it to True preserves the properties of the tensors in the state dict. The only exception is the requires_grad field of Parameter, for which the value from the module is preserved. Default: False
- Returns:
missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.
unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.
- Return type:
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
- property max_steps#
- property num_codebooks#
- on_train_epoch_end()[source]#
Called in the training loop at the very end of the epoch.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the
LightningModule and access them in this hook:

class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss

    def on_train_epoch_end(self):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(self.training_step_outputs).mean()
        self.log("training_epoch_mean", epoch_mean)
        # free up the memory
        self.training_step_outputs.clear()
- pad_audio(audio, audio_len, samples_per_frame)[source]#
Zero pad the end of the audio so that we do not have a partial end frame. The output will be zero-padded to have an integer number of frames of length self.samples_per_frame.
- Parameters:
audio – input time-domain signal
audio_len – valid length for each example in the batch
- Returns:
Padded time-domain signal padded_audio and its length padded_len.
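The padding rule can be sketched in plain Python: round the valid length up to a whole number of frames and zero-fill the difference (a simplified, list-based stand-in for the tensor implementation; the helper name is illustrative):

```python
import math


def pad_audio_sketch(audio, audio_len, samples_per_frame):
    # round the valid length up to an integer number of frames
    padded_len = math.ceil(audio_len / samples_per_frame) * samples_per_frame
    # zero-fill the tail so the last frame is complete
    padded_audio = list(audio) + [0.0] * (padded_len - len(audio))
    return padded_audio, padded_len


padded, padded_len = pad_audio_sketch([0.1] * 1000, 1000, 256)
```

Here 1000 samples round up to 4 frames of 256, so padded_len is 1024.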
- quantize(encoded, encoded_len) Tensor[source]#
Quantize the continuous encoded representation into a discrete representation for each frame.
- Parameters:
encoded – encoded signal representation
encoded_len – valid length of the encoded representation in frames
- Returns:
A tensor of tokens for each codebook for each frame.
- should_update_disc(batch_idx) bool[source]#
Decide whether to update the discriminator based on the batch index and the configured discriminator update period.
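One plausible reading of this policy is a simple modulus test against the configured period (a sketch under that assumption; the actual implementation may also account for the disc_update_prob probability):

```python
def should_update_disc_sketch(batch_idx, disc_update_period):
    # update the discriminator once every disc_update_period batches
    return batch_idx % disc_update_period == 0
```

With a period of 2, the discriminator would be updated on batches 0, 2, 4, and so on.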
- state_dict(
- destination=None,
- prefix='',
- keep_vars=False,
Return a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to
None are not included.
Note
The returned object is a shallow copy. It contains references to the module’s parameters and buffers.
Warning
Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
Warning
Please avoid the use of argument destination as it is not designed for end-users.
- Parameters:
destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.
- Returns:
a dictionary containing a whole state of the module
- Return type:
dict
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
Base Classes#
The classes below are the base of the TTS pipeline.
- class nemo.collections.tts.models.base.MelToSpec(
- cfg: DictConfig,
- trainer: Trainer = None,
Bases:
ModelPT, ABC
A base class for models that convert mel spectrograms to linear (magnitude) spectrograms.
- abstractmethod convert_mel_spectrogram_to_linear(
- mel: tensor,
- **kwargs,
Accepts a batch of spectrograms and returns a batch of linear spectrograms
- Parameters:
mel – A torch tensor representing the mel spectrograms [‘B’, ‘mel_freqs’, ‘T’]
- Returns:
A torch tensor representing the linear spectrograms [‘B’, ‘n_freqs’, ‘T’]
- Return type:
spec
- class nemo.collections.tts.models.base.SpectrogramGenerator(
- cfg: DictConfig,
- trainer: Trainer = None,
Bases:
NeedsNormalizer, ModelPT, ABC
Base class for all TTS models that turn text into a spectrogram.
- abstractmethod generate_spectrogram(
- tokens: tensor,
- **kwargs,
Accepts a batch of text or text_tokens and returns a batch of spectrograms
- Parameters:
tokens – A torch tensor representing the text to be generated
- Returns:
spectrograms
- classmethod list_available_models() List[PretrainedModelInfo][source]#
This method returns a list of pre-trained models that can be instantiated directly from NVIDIA’s NGC cloud.
- Returns:
List of available pre-trained models.
- abstractmethod parse(
- str_input: str,
- **kwargs,
A helper function that accepts raw python strings and turns them into a tensor. The tensor should have 2 dimensions. The first is the batch, which should be of size 1. The second should represent time. The tensor should represent either tokenized or embedded text, depending on the model.
Note that some models have a normalize parameter in this function, which will apply the normalizer if it is available.
- class nemo.collections.tts.models.base.Vocoder(
- cfg: DictConfig,
- trainer: Trainer = None,
Bases:
ModelPT, ABC
A base class for models that convert spectrograms to audio. Note that this class takes as input either linear or mel spectrograms.
Dataset Processing Classes#
- class nemo.collections.tts.data.dataset.TTSDataset(
- manifest_filepath: str | Path | List[str] | List[Path],
- sample_rate: int,
- text_tokenizer: BaseTokenizer | Callable[[str], List[int]],
- tokens: List[str] | None = None,
- text_normalizer: Normalizer | Callable[[str], str] | None = None,
- text_normalizer_call_kwargs: Dict | None = None,
- text_tokenizer_pad_id: int | None = None,
- sup_data_types: List[str] | None = None,
- sup_data_path: Path | str | None = None,
- max_duration: float | None = None,
- min_duration: float | None = None,
- ignore_file: Path | str | None = None,
- trim: bool = False,
- trim_ref: float | None = None,
- trim_top_db: int | None = None,
- trim_frame_length: int | None = None,
- trim_hop_length: int | None = None,
- n_fft: int = 1024,
- win_length: int | None = None,
- hop_length: int | None = None,
- window: str = 'hann',
- n_mels: int = 80,
- lowfreq: int = 0,
- highfreq: int | None = None,
- segment_max_duration: int | None = None,
- pitch_augment: bool = False,
- cache_pitch_augment: bool = True,
- pad_multiple: int = 1,
- **kwargs,
Bases:
Dataset
- class nemo.collections.tts.data.dataset.VocoderDataset(
- manifest_filepath: str | Path | List[str] | List[Path],
- sample_rate: int,
- n_segments: int | None = None,
- max_duration: float | None = None,
- min_duration: float | None = None,
- ignore_file: Path | str | None = None,
- trim: bool | None = False,
- load_precomputed_mel: bool = False,
- hop_length: int | None = None,
Bases:
Dataset