Configuration Files

The SpeechLM2 models use YAML configuration files to define model architecture, training parameters, and data settings. This page describes the configuration structure and important parameters for each model type in the collection.

Configuration Structure

SpeechLM2 configuration files typically have the following high-level structure:

model:
  # Model architecture settings
  ...

trainer:
  # PyTorch Lightning trainer settings
  ...

exp_manager:
  # Experiment logging settings
  ...

data:
  # Dataset settings
  ...

SALM Configuration

The SALM (Speech-Augmented Language Model) configuration includes settings for the pretrained LLM, the audio perception module, and training parameters. See the SALM paper (https://arxiv.org/abs/2310.09424) for more details.

model:
  # Pretrained model paths
  pretrained_llm: "TinyLlama/TinyLlama_v1.1"  # HF model path
  pretrained_asr: "stt_en_fastconformer_hybrid_large_streaming_80ms"  # NeMo checkpoint name
  pretrained_weights: True  # Whether to load weights or just architecture

  # Special token settings
  audio_locator_tag: "<audio>"  # Tag to replace with audio embeddings

  # Freezing parameters
  freeze_params:
    - "^llm\\.model\\.layers\\.[0-4]\\..+$"  # Regex patterns for parameters to freeze
  prevent_freeze_params: []  # Override freeze_params for specific submodules

  # Optional LoRA settings for efficient fine-tuning
  lora:
    task_type: CAUSAL_LM
    r: 8
    lora_alpha: 32
    lora_dropout: 0.1

  # Audio perception module configuration
  perception:
    target: nemo.collections.speechlm2.modules.perception.AudioPerceptionModule

    preprocessor:
      normalize: 'NA'

    encoder:
      self_attention_model: rel_pos
      att_context_size: [-1, -1]
      conv_context_size: regular
      conv_norm_type: batch_norm

    modality_adapter:
      _target_: nemo.collections.asr.modules.ConformerEncoder
      feat_in: 1024
      feat_out: -1
      n_layers: 2
      d_model: 1024
      subsampling: dw_striding
      subsampling_factor: 1
      subsampling_conv_channels: 256
      causal_downsampling: false
      ff_expansion_factor: 4
      self_attention_model: rel_pos
      n_heads: 8
      att_context_size: [-1, -1]
      att_context_style: regular
      xscaling: true
      untie_biases: true
      pos_emb_max_len: 5000
      conv_kernel_size: 9
      conv_norm_type: batch_norm
      conv_context_size: null
      dropout: 0
      dropout_pre_encoder: 0
      dropout_emb: 0.0
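The audio_locator_tag marks where audio enters the prompt: the tag itself is not embedded as text, and the frames produced by the perception module take its place in the LLM input sequence. The sketch below illustrates only that splicing step, under stated assumptions: a whitespace split stands in for the LLM tokenizer, small Python lists stand in for embedding vectors, exactly one tag per prompt is supported, and splice_audio and toy_embed are hypothetical names, not part of the SpeechLM2 API.

```python
from typing import Callable, List

Vector = List[float]


def splice_audio(
    prompt: str,
    audio_frames: List[Vector],
    embed_token: Callable[[str], Vector],
    audio_locator_tag: str = "<audio>",
) -> List[Vector]:
    """Return the input embedding sequence with audio frames inserted at the tag.

    Hypothetical helper: a whitespace split stands in for a real LLM tokenizer.
    """
    tokens = prompt.split()
    if tokens.count(audio_locator_tag) != 1:
        raise ValueError(f"expected exactly one {audio_locator_tag!r} in the prompt")
    embeddings: List[Vector] = []
    for tok in tokens:
        if tok == audio_locator_tag:
            # The tag is not embedded; each audio frame becomes one sequence position.
            embeddings.extend(audio_frames)
        else:
            embeddings.append(embed_token(tok))
    return embeddings


def toy_embed(tok: str) -> Vector:
    # Toy 2-dimensional "embedding" so the result is easy to inspect.
    return [float(len(tok)), 0.0]


audio = [[0.0, 1.0], [0.0, 2.0], [0.0, 3.0]]  # three hypothetical encoder frames
seq = splice_audio("Transcribe this : <audio>", audio, toy_embed)
print(len(seq))  # 3 text tokens + 3 audio frames -> 6
```

Because the number of audio frames depends on the utterance length, the LLM input length varies per example even when the text prompt is fixed.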

DuplexS2SModel Configuration

The DuplexS2SModel configuration adds settings for speech generation:

model:
  # Pretrained model paths
  pretrained_llm: "TinyLlama/TinyLlama_v1.1"
  pretrained_audio_codec: "path/to/audio_codec.nemo"
  pretrained_asr: "stt_en_fastconformer_hybrid_large_streaming_80ms"
  scoring_asr: "stt_en_fastconformer_transducer_large"  # used only in validation

  # Loss weights
  audio_loss_weight: 4
  text_loss_weight: 3

  # Perception module config (similar to SALM)
  perception:
    # ... (similar to SALM perception module)
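The two loss weights set how much the text and audio prediction terms contribute to training. Assuming they are combined as a plain weighted sum (the exact reduction used by DuplexS2SModel, for example any normalization by the total weight, is an assumption here), the arithmetic looks like this; combined_loss is a hypothetical helper.

```python
def combined_loss(
    text_loss: float,
    audio_loss: float,
    text_loss_weight: float = 3.0,   # text_loss_weight from the config above
    audio_loss_weight: float = 4.0,  # audio_loss_weight from the config above
) -> float:
    """Hypothetical weighted sum of the text and audio loss terms."""
    return text_loss_weight * text_loss + audio_loss_weight * audio_loss


total = combined_loss(text_loss=0.5, audio_loss=2.0)
print(total)  # 3 * 0.5 + 4 * 2.0 = 9.5
```

Under this assumption, raising audio_loss_weight relative to text_loss_weight shifts the gradient signal toward the audio predictions.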

DuplexS2SSpeechDecoderModel Configuration

The DuplexS2SSpeechDecoderModel is similar to DuplexS2SModel, but adds a dedicated transformer decoder for speech generation:

model:
  # Pretrained model paths
  pretrained_llm: "TinyLlama/TinyLlama_v1.1"
  pretrained_audio_codec: "path/to/audio_codec.nemo"
  pretrained_asr: "stt_en_fastconformer_hybrid_large_streaming_80ms"

  # Speech decoder settings
  speech_decoder:
    target: nemo.collections.speechlm2.modules.speech_generation.TransformerARSpeechDecoder
    d_model: 1024
    n_layers: 12
    n_heads: 16
    d_kv: 64
    d_ff: 4096
    max_seq_len: 2048
    dropout: 0.1
    layernorm_epsilon: 1e-5
    activation_function: "gelu_new"
    init_method_std: 0.02
    use_cache: True

  # ... other settings
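The speech decoder's attention width follows from these numbers: with n_heads: 16 and d_kv: 64, each attention layer works in a 16 x 64 = 1024-dimensional space, matching d_model. The sketch below uses that to estimate how much memory use_cache: True could need per sequence. It is a back-of-the-envelope estimate, not a measurement of TransformerARSpeechDecoder, assuming a standard decoder that caches one key and one value vector per head, per layer, per position, stored in bf16 (2 bytes per value); kv_cache_bytes is a hypothetical helper.

```python
def kv_cache_bytes(n_layers: int, n_heads: int, d_kv: int,
                   seq_len: int, bytes_per_value: int = 2) -> int:
    """Estimate KV-cache size for one sequence (keys + values, all layers)."""
    per_position_per_layer = 2 * n_heads * d_kv  # one key and one value vector per head
    return n_layers * seq_len * per_position_per_layer * bytes_per_value


n_heads, d_kv, d_model = 16, 64, 1024
assert n_heads * d_kv == d_model  # inner attention width equals d_model in this config

size = kv_cache_bytes(n_layers=12, n_heads=n_heads, d_kv=d_kv, seq_len=2048)
print(size / 2**20, "MiB")  # 96.0 MiB at the full max_seq_len of 2048
```

The estimate scales linearly with batch size and sequence length, so shorter generated sequences need proportionally less cache.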

Trainer Configuration

The trainer section contains PyTorch Lightning Trainer settings:

trainer:
  devices: 1
  num_nodes: 1
  accelerator: gpu
  precision: bf16-true
  logger: false
  enable_checkpointing: false  # handled by exp_manager
  use_distributed_sampler: false   # handled by lhotse
  max_epochs: null
  max_steps: 100000
  log_every_n_steps: 10
  val_check_interval: 2000
  accumulate_grad_batches: 1
  gradient_clip_val: 1.0
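With several GPUs and gradient accumulation, each optimizer step sees more samples than a single batch_size. A small sketch of that arithmetic, assuming batch_size in the data section is counted per GPU and that batches have a fixed size (with duration-based bucketing the count per batch varies); samples_per_optimizer_step is a hypothetical helper.

```python
def samples_per_optimizer_step(
    batch_size: int,
    devices: int,
    num_nodes: int,
    accumulate_grad_batches: int,
) -> int:
    """Samples that contribute to one optimizer step (assumes batch_size is per GPU)."""
    return batch_size * devices * num_nodes * accumulate_grad_batches


# Values from the trainer section above with data.train_ds.batch_size: 16
print(samples_per_optimizer_step(16, devices=1, num_nodes=1, accumulate_grad_batches=1))  # 16
# A hypothetical multi-node run: 2 nodes x 8 GPUs, accumulating 4 batches
print(samples_per_optimizer_step(16, devices=8, num_nodes=2, accumulate_grad_batches=4))  # 1024
```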

Experiment Manager Configuration

The exp_manager section contains settings for experiment logging and model checkpointing:

exp_manager:
  explicit_log_dir: path/to/output_dir
  exp_dir: null
  name: ${name}
  create_wandb_logger: false  # set to true if you want to use wandb
  wandb_logger_kwargs:
    project: null
    name: null
  resume_if_exists: true
  resume_ignore_no_checkpoint: true
  create_checkpoint_callback: true
  checkpoint_callback_params:
    monitor: val_loss
    filename: "{step}"  # checkpoint name will be step=<step>.ckpt
    save_top_k: 1
    mode: min
  create_tensorboard_logger: false  # set to true if you want to use tensorboard
  version: null
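With monitor: val_loss, mode: min and save_top_k: 1, only the checkpoint with the lowest validation loss seen so far is kept, and its file name is built from the training step. The sketch below imitates that retention rule in plain Python on made-up validation losses. It is a simplified illustration, not Lightning's ModelCheckpoint implementation (which also handles ties, save_last, and other options); checkpoint_name and keep_top_k are hypothetical helpers.

```python
import heapq
from typing import List, Tuple


def checkpoint_name(step: int) -> str:
    # filename "{step}" with the metric name auto-inserted gives "step=<step>.ckpt"
    return f"step={step}.ckpt"


def keep_top_k(history: List[Tuple[int, float]], k: int = 1, mode: str = "min") -> List[str]:
    """Names of the k best checkpoints from (step, monitored value) pairs."""
    sign = 1.0 if mode == "min" else -1.0  # "max" mode keeps the highest values instead
    best = heapq.nsmallest(k, history, key=lambda item: sign * item[1])
    return [checkpoint_name(step) for step, _ in best]


history = [(2000, 1.92), (4000, 1.71), (6000, 1.75)]  # made-up (step, val_loss) values
print(keep_top_k(history))  # ['step=4000.ckpt']
```

The step= prefix comes from ModelCheckpoint's auto_insert_metric_name behavior, which is enabled by default.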

Data Configuration

The data section defines dataset paths, preprocessing, and data loading parameters:

data:
  train_ds:
    sample_rate: ${data.target_sample_rate}
    input_cfg:
      - type: lhotse_shar
        shar_path: /path/to/train_data
    seed: 42
    shard_seed: "randomized"
    num_workers: 4
    batch_size: 16
    # Optional bucketing settings
    # batch_duration: 100
    # bucket_duration_bins: [8.94766,10.1551,11.64118,19.30376,42.85]
    # use_bucketing: true
    # num_buckets: 5
    # bucket_buffer_size: 5000

  validation_ds:
    datasets:
      val_set_name:
        shar_path: /path/to/validation_data
    sample_rate: ${data.target_sample_rate}
    batch_size: 1
    seed: 42
    shard_seed: "randomized"
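The commented-out bucketing options group utterances of similar duration to reduce padding, and cap each batch by total audio duration in seconds (batch_duration) instead of a fixed batch_size. A minimal sketch of that idea, assuming each value in bucket_duration_bins is an upper duration boundary of a bucket and using a simple greedy packer; Lhotse's actual dynamic bucketing sampler is more involved (buffering, shuffling, and more), and bucket_index and make_batches are hypothetical helpers.

```python
from bisect import bisect_left
from collections import defaultdict
from typing import Dict, List

BINS = [8.94766, 10.1551, 11.64118, 19.30376, 42.85]  # bucket_duration_bins from above


def bucket_index(duration: float, bins: List[float] = BINS) -> int:
    """Index of the first bin whose upper boundary is >= duration (len(bins) if none)."""
    return bisect_left(bins, duration)


def make_batches(durations: List[float], batch_duration: float = 100.0) -> List[List[float]]:
    """Greedily pack same-bucket utterances until the total would exceed batch_duration."""
    buckets: Dict[int, List[float]] = defaultdict(list)
    for d in durations:
        buckets[bucket_index(d)].append(d)
    batches: List[List[float]] = []
    for idx in sorted(buckets):
        current: List[float] = []
        for d in buckets[idx]:
            if current and sum(current) + d > batch_duration:
                batches.append(current)
                current = []
            current.append(d)
        if current:
            batches.append(current)
    return batches


# Hypothetical utterance durations in seconds
print(make_batches([5.0, 40.0, 41.0, 30.0, 9.5, 3.0]))
# [[5.0, 3.0], [9.5], [40.0, 41.0], [30.0]]
```

Short utterances end up batched together and long ones are split across batches, which keeps the amount of audio per batch roughly constant.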

Depending on the model, additional options may be available under the data namespace; they are passed to the corresponding Dataset class. For example, S2S models have:

data:
  frame_length: 0.08
  source_sample_rate: 16000
  target_sample_rate: 22050
  input_roles: ["user", "User"]
  output_roles: ["agent", "Assistant"]

  train_ds: ...
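Values such as sample_rate: ${data.target_sample_rate} in the dataset sections are OmegaConf interpolations: they are resolved against another key of the same configuration, so the dataset sample rate follows data.target_sample_rate. A simplified resolver for whole-value absolute references, written in plain Python for illustration only (OmegaConf itself also supports relative references, string interpolation, custom resolvers, and more); resolve and its helpers are hypothetical names.

```python
import re
from typing import Any, Dict

# Matches a value that is exactly one absolute reference such as "${data.target_sample_rate}"
_INTERPOLATION = re.compile(r"^\$\{([\w.]+)\}$")


def _lookup(cfg: Dict[str, Any], dotted: str) -> Any:
    """Follow a dotted path such as "data.target_sample_rate" from the config root."""
    node: Any = cfg
    for key in dotted.split("."):
        node = node[key]
    return node


def _resolve_node(cfg: Dict[str, Any], node: Any) -> Any:
    if isinstance(node, dict):
        return {key: _resolve_node(cfg, value) for key, value in node.items()}
    if isinstance(node, list):
        return [_resolve_node(cfg, value) for value in node]
    if isinstance(node, str):
        match = _INTERPOLATION.match(node)
        if match:
            # The target may itself be an interpolation, so resolve it recursively
            # (reference cycles are not detected in this sketch).
            return _resolve_node(cfg, _lookup(cfg, match.group(1)))
    return node


def resolve(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Return a resolved copy of cfg; the input dictionaries are left unchanged."""
    return _resolve_node(cfg, cfg)


cfg = {
    "data": {
        "target_sample_rate": 22050,
        "train_ds": {"sample_rate": "${data.target_sample_rate}"},
        "validation_ds": {"sample_rate": "${data.target_sample_rate}"},
    }
}
print(resolve(cfg)["data"]["train_ds"]["sample_rate"])  # 22050
```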

Important Configuration Parameters

Model Parameters

  • pretrained_llm: HuggingFace model name (e.g. TinyLlama/TinyLlama_v1.1) or path of the pretrained LLM

  • pretrained_asr: Name of the pretrained NeMo ASR model used for perception

  • pretrained_audio_codec: Path to the pretrained audio codec model (for speech generation)

  • freeze_params: Regex patterns of parameters to freeze during training

  • audio_loss_weight/text_loss_weight: Relative weights of the audio and text loss terms in speech generation models
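To check which parameters a freeze_params pattern will catch, the patterns can be matched against parameter names. The sketch below assumes a parameter is frozen when its full name matches any freeze_params regex and none of the prevent_freeze_params regexes; the exact matching logic in SpeechLM2, the helper frozen_params, and the parameter names are all assumptions for illustration.

```python
import re
from typing import Iterable, List


def frozen_params(names: Iterable[str], freeze: List[str], prevent: List[str]) -> List[str]:
    """Names matching any freeze pattern and no prevent pattern (assumed semantics)."""
    freeze_res = [re.compile(p) for p in freeze]
    prevent_res = [re.compile(p) for p in prevent]
    return [
        name
        for name in names
        if any(r.match(name) for r in freeze_res)
        and not any(r.match(name) for r in prevent_res)
    ]


# The YAML string "^llm\\.model\\.layers\\.[0-4]\\..+$" is double-quoted, so "\\" is one
# backslash and the regex is the one below: layers 0 through 4 of the LLM.
freeze = [r"^llm\.model\.layers\.[0-4]\..+$"]
names = [
    "llm.model.layers.0.self_attn.q_proj.weight",  # hypothetical parameter names
    "llm.model.layers.4.mlp.up_proj.weight",
    "llm.model.layers.5.mlp.up_proj.weight",
    "llm.model.layers.10.mlp.up_proj.weight",  # "[0-4]\." needs one digit then a dot: no match
    "perception.encoder.layers.0.conv.weight",
]
print(frozen_params(names, freeze, prevent=[]))
# ['llm.model.layers.0.self_attn.q_proj.weight', 'llm.model.layers.4.mlp.up_proj.weight']
```

Under the assumed rule, adding r"^llm\.model\.layers\.4\..+$" to prevent would keep layer 4 trainable while layers 0 to 3 stay frozen.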

Perception Module

  • self_attention_model: Type of attention mechanism (“rel_pos” or “abs_pos”)

  • att_context_size: Attention context window as [left, right] frame counts; -1 means unlimited, so [-1, -1] is full context

  • conv_context_size: Context type for convolutions (“causal” or “regular”)

  • n_layers: Number of encoder layers

  • d_model: Model dimension size
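att_context_size limits how many frames each frame may attend to on its left (past) and right (future), with -1 meaning unlimited. A small sketch that builds the corresponding boolean attention mask, assuming that convention; it is illustrative and is not NeMo's masking code, and context_mask is a hypothetical helper.

```python
from typing import List


def context_mask(num_frames: int, left: int, right: int) -> List[List[bool]]:
    """mask[i][j] is True when frame i may attend to frame j; -1 means unlimited."""
    mask = []
    for i in range(num_frames):
        row = []
        for j in range(num_frames):
            ok_left = left < 0 or i - j <= left     # how far back frame j may be
            ok_right = right < 0 or j - i <= right  # how far ahead frame j may be
            row.append(ok_left and ok_right)
        mask.append(row)
    return mask


full = context_mask(4, -1, -1)   # [-1, -1]: every frame sees every frame
causal = context_mask(4, 2, 0)   # [2, 0]: two frames of history, no lookahead
for row in causal:
    print("".join("x" if allowed else "." for allowed in row))
# x...
# xx..
# xxx.
# .xxx
```

A zero right context is what makes an encoder usable for streaming, since no frame depends on audio that has not arrived yet.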

Data Parameters

  • frame_length: Frame duration in seconds

  • source_sample_rate/target_sample_rate: Sample rates for input/output audio

  • input_roles/output_roles: Speaker role names in the conversation data treated as model input (user) and model output (agent)

  • batch_size: Number of samples per batch

  • use_bucketing: Whether to use length-based bucketing for efficient batching
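With frame_length: 0.08, each model frame covers 80 ms of audio. The sketch below converts that into waveform samples at the configured source and target sample rates and counts the frames for an utterance; counting a partial final frame as a whole frame is an assumption, and the helper names are hypothetical.

```python
def samples_per_frame(frame_length: float, sample_rate: int) -> int:
    """Waveform samples covered by one frame; round() absorbs float error in 0.08 * rate."""
    return round(frame_length * sample_rate)


def num_frames(num_samples: int, frame_length: float, sample_rate: int) -> int:
    """Frames needed to cover num_samples, counting a partial last frame (an assumption)."""
    spf = samples_per_frame(frame_length, sample_rate)
    return -(-num_samples // spf)  # integer ceiling division


print(samples_per_frame(0.08, 16000))  # 1280 samples per frame of 16 kHz input audio
print(samples_per_frame(0.08, 22050))  # 1764 samples per frame of 22.05 kHz output audio
print(num_frames(10 * 16000, 0.08, 16000))  # a 10 s input utterance -> 125 frames
```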

Example Configuration Files

Example configurations for all model types can be found in the example directory:

  • SALM: examples/speechlm2/conf/salm.yaml

  • DuplexS2SModel: examples/speechlm2/conf/s2s_duplex.yaml

  • DuplexS2SSpeechDecoderModel: examples/speechlm2/conf/s2s_duplex_speech_decoder.yaml

Using Configuration Files

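You can use these configurations with the training scripts by passing the config directory with --config-path (resolved relative to the training script) and the config file name without the .yaml extension with --config-name: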

# Train SALM model
python examples/speechlm2/salm_train.py \
  --config-path=conf \
  --config-name=salm

You can also override configuration values from the command line:

python examples/speechlm2/salm_train.py \
  --config-path=conf \
  --config-name=salm \
  model.pretrained_llm="different/llm/path" \
  trainer.max_steps=1000 \
  data.train_ds.batch_size=8
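Each override has the form dotted.key=value and replaces the value at that path in the composed configuration. A simplified sketch of how such overrides map onto a nested config, assuming plain dictionaries and naive value parsing; Hydra's real override grammar is richer (for example +key=value to add a key and ~key to delete one), and parse_value and apply_override are hypothetical helpers.

```python
import ast
from typing import Any, Dict


def parse_value(text: str) -> Any:
    """Naively parse booleans, null, numbers and quoted strings; otherwise keep the text."""
    lowered = text.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    if lowered == "null":
        return None
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        return text


def apply_override(cfg: Dict[str, Any], override: str) -> None:
    """Apply one "a.b.c=value" override to a nested dict, in place."""
    dotted, _, raw = override.partition("=")
    *parents, leaf = dotted.split(".")
    node = cfg
    for key in parents:
        node = node[key]  # KeyError if an intermediate key is missing
    if leaf not in node:
        # Hydra likewise rejects overriding an absent key unless it is prefixed with "+"
        raise KeyError(f"{dotted} is not in the config")
    node[leaf] = parse_value(raw)


cfg = {
    "model": {"pretrained_llm": "TinyLlama/TinyLlama_v1.1"},
    "trainer": {"max_steps": 100000},
    "data": {"train_ds": {"batch_size": 16}},
}
# The shell strips the double quotes around "different/llm/path" before Python sees it.
for override in [
    "model.pretrained_llm=different/llm/path",
    "trainer.max_steps=1000",
    "data.train_ds.batch_size=8",
]:
    apply_override(cfg, override)
print(cfg["trainer"]["max_steps"], cfg["data"]["train_ds"]["batch_size"])  # 1000 8
```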