Cosmos-Embed1#

Cosmos-Embed1 is a video-text embedding model designed for physical AI applications. It supports the following tasks:

  • train

  • evaluate

  • inference

  • export

You can invoke these tasks using the TAO Launcher:

tao model cosmos-embed1 <action> -e /path/to/spec.yaml [overrides]

Where:

  • <action>: One of train, evaluate, inference, or export

  • overrides: Optional dot-notation key-value pairs that override configuration fields

Architecture Overview#

Cosmos-Embed1 is a dual-encoder video-text embedding model that produces a joint embedding space for videos and text. The model architecture combines:

  • EVA-ViT-G visual encoder: A large Vision Transformer (ViT) backbone that processes sampled video frames at 224 × 224 pixel resolution

  • Q-Former module: A Querying Transformer that distills visual features into a fixed number of learnable query token embeddings, initialized from a pretrained Bidirectional Encoder Representations from Transformers (BERT) checkpoint

  • Text encoder: A BERT-based text encoder that projects text captions into the same embedding space as the video encoder

  • Contrastive loss: Either CLIP (Contrastive Language-Image Pretraining)-style or SigLIP-style contrastive learning to align video and text embeddings
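
Concretely, the CLIP-style objective treats the matched video-text pairs in a batch as positives and every other pairing as a negative. The NumPy sketch below illustrates the symmetric contrastive loss; it is a minimal illustration of the objective, not the model's actual implementation:

```python
import numpy as np

def clip_style_loss(video_emb, text_emb, logit_scale=100.0):
    """Symmetric contrastive loss over a batch of paired embeddings.

    video_emb, text_emb: (B, D) arrays; row i of each is a matched pair.
    """
    # L2-normalize so dot products are cosine similarities.
    v = video_emb / np.linalg.norm(video_emb, axis=1, keepdims=True)
    t = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)
    logits = logit_scale * v @ t.T          # (B, B) similarity matrix
    labels = np.arange(len(logits))         # matched pairs sit on the diagonal

    def xent(l):
        # Numerically stable cross-entropy toward the diagonal targets.
        l = l - l.max(axis=1, keepdims=True)
        logp = l - np.log(np.exp(l).sum(axis=1, keepdims=True))
        return -logp[labels, labels].mean()

    # Average the video-to-text and text-to-video directions.
    return 0.5 * (xent(logits) + xent(logits.T))
```

SigLIP-style training replaces this batch-wide softmax with an independent sigmoid loss per video-text pair, which scales better to large batches.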

The model supports LoRA (Low-Rank Adaptation) for parameter-efficient fine-tuning, which significantly reduces the number of trainable parameters by adding low-rank decomposition matrices to the visual encoder and Q-Former attention layers.
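The savings are easy to see concretely. The sketch below (hypothetical shapes, not the model's actual layers) shows the LoRA forward pass and the parameter count for a single d×d weight at the default rank and alpha:

```python
import numpy as np

d, r, alpha = 1024, 8, 16          # hidden size, LoRA rank, LoRA alpha (defaults)
W = np.random.randn(d, d)          # frozen pretrained weight (not trained)
A = np.random.randn(r, d) * 0.01   # trainable low-rank factor
B = np.zeros((d, r))               # B starts at zero, so the adapter is a no-op initially

def lora_forward(x):
    # Effective weight is W + (alpha / r) * B @ A, but it is never materialized.
    return x @ W.T + (alpha / r) * (x @ A.T) @ B.T

full_params = d * d                # 1,048,576 trainable parameters without LoRA
lora_params = d * r + r * d        # 16,384 with rank-8 adapters (64x fewer)
```

Because B is initialized to zero, training starts from exactly the pretrained behavior, and only the small A and B matrices receive gradients.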

Use Cases#

Cosmos-Embed1 is designed for physical AI data curation tasks that require understanding the semantic content of video collections:

  • Text-to-video retrieval: Find the most semantically relevant videos for a text query, such as “flooding on a highway” or “robotic arm picking an object”

  • Video-to-video retrieval: Find videos that are semantically similar to a query video

  • Semantic deduplication: Identify near-duplicate or highly similar videos in a large corpus by comparing their embedding vectors

  • Targeted filtering: Filter a video dataset by semantic category or anomaly type using text-based queries
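
Semantic deduplication, for example, reduces to thresholding pairwise cosine similarity between embedding vectors. A minimal sketch follows; the 0.95 threshold is an illustrative choice, and the brute-force double loop would be replaced by an approximate nearest-neighbor index for large corpora:

```python
import numpy as np

def find_near_duplicates(embeddings, threshold=0.95):
    """Return index pairs (i, j), i < j, whose cosine similarity exceeds threshold."""
    # Normalize rows so the dot product equals cosine similarity.
    e = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    sim = e @ e.T
    pairs = []
    for i in range(len(e)):
        for j in range(i + 1, len(e)):
            if sim[i, j] > threshold:
                pairs.append((i, j))
    return pairs
```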

Hardware Requirements#

The following hardware is required to run Cosmos-Embed1:

Minimum:

  • One NVIDIA GPU with at least 40 GB of GPU memory (for single-GPU training at 224p resolution)

  • Ubuntu 20.04 or later

  • CUDA 12.1 or later

Recommended:

  • Multiple NVIDIA A100 or H100 GPUs for faster training

  • High-bandwidth storage for video data access

  • Multiple CPU cores for parallel data loading

Data Input for Cosmos-Embed1#

Cosmos-Embed1 uses a JSON or JSONL metadata file to describe the video dataset. Each entry in the metadata file maps a video path to a text caption (and optionally a label).

Supported Dataset Types#

Cosmos-Embed1 supports the following dataset types, configured with the dataset_type field. For the full list of configuration parameters, see the SingleDatasetConfig table. For metadata format details and required fields per dataset type, see the Dataset Format Reference.

| dataset_type | Description |
| --- | --- |
| mock | Generates random data for testing and validation without real video files. Configured entirely through resolution and num_video_frames; no metadata file is needed. |
| vad_r1 | Loads the VAD-R1 video anomaly detection dataset from a JSONL metadata file. VAD-R1 provides Perception-to-Cognition Chain-of-Thought annotations (what, when, where, why, how) for anomalous events in surveillance videos. Requires metadata and data_root; supports path_prefix_mapping and caption_field. |
| vad_r1_chunks | Loads VAD-R1 with temporal chunking, splitting each video into fixed-duration segments for fine-grained anomaly detection at the chunk level. Requires metadata and data_root; supports chunk_size_sec, shared_normal_label, split, and caption_field. |
| msrvtt | Loads the MSR-VTT video description dataset. MSR-VTT provides 10K web video clips with 200K clip-sentence pairs across diverse categories. Requires mp4_urls (glob pattern) and metadata; supports split, random_caption, and caption_to_label. |
| kinetics | Loads the Kinetics action-recognition dataset. Requires mp4_urls (glob pattern) and metadata (CSV with youtube_id and label columns). |
| http | Loads videos streamed from HTTP/HTTPS URLs listed in a JSON metadata file. Requires metadata with a url field per entry; supports random_caption and caption_to_label. |

Metadata File Format#

The metadata file is a JSON or JSONL file where each entry contains at minimum a path to a video file and a caption field. The following example shows the VAD-R1 format:

[
  {
    "path": "/data/videos/clip_001.mp4",
    "anomaly_type": "car crash",
    "split": "train",
    "total_frames": 450,
    "video_duration_sec": 15.0,
    "fps": 30.0,
    "anomaly_start_frame": 90,
    "anomaly_end_frame": 300,
    "anomaly_start_sec": 3.0,
    "anomaly_end_sec": 10.0,
    "what": "a car runs a red light and collides with another vehicle",
    "when": "[0.2, 0.67]",
    "where": "center of the intersection",
    "why": "the driver failed to stop at the red light",
    "how": "risk of injury to occupants and traffic disruption"
  }
]

The core fields describe the video and the anomaly timing:

  • path: Path to the video file on disk.

  • anomaly_type: Category label for the anomaly (e.g., “car crash”). Used as the default caption field.

  • split: Dataset split, e.g., “train”, “val”, or “test”.

  • total_frames: Total number of frames in the video.

  • video_duration_sec: Total video duration in seconds.

  • fps: Frames per second of the video.

  • anomaly_start_frame / anomaly_end_frame: Start and end frame indices of the annotated anomaly region.

  • anomaly_start_sec / anomaly_end_sec: Start and end timestamps (in seconds) of the annotated anomaly region.

The Perception-to-Cognition Chain-of-Thought (P2C-CoT) reasoning fields provide structured annotations that describe the anomaly from multiple perspectives. Any of these fields can be used as alternative caption fields via the caption_field configuration parameter:

  • what: Describes the anomalous event itself (e.g., “a car runs a red light and collides with another vehicle”).

  • when: Normalized time range of the anomaly as a proportion of the total video duration, expressed as [start, end] (e.g., [0.2, 0.67]).

  • where: Spatial location of the anomaly within the video frame (e.g., “center of the intersection”).

  • why: Explains why the event is considered anomalous (e.g., “the driver failed to stop at the red light”).

  • how: Describes the potential consequences or severity of the anomaly (e.g., “risk of injury to occupants and traffic disruption”).
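
Loading and sanity-checking such a metadata file is straightforward. The helper below (illustrative, not part of the toolkit) reads either a JSON array or a JSONL file and verifies the core fields documented above:

```python
import json

CORE_FIELDS = {"path", "anomaly_type", "split", "total_frames",
               "video_duration_sec", "fps"}

def load_metadata(path):
    """Load a JSON array or a JSONL file into a list of entry dicts."""
    with open(path) as f:
        text = f.read().strip()
    if text.startswith("["):      # JSON array
        entries = json.loads(text)
    else:                         # JSONL: one object per line
        entries = [json.loads(line) for line in text.splitlines() if line.strip()]
    for i, entry in enumerate(entries):
        missing = CORE_FIELDS - entry.keys()
        if missing:
            raise ValueError(f"entry {i} is missing fields: {sorted(missing)}")
    return entries
```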

For the vad_r1_chunks dataset type, each entry also contains a chunks array that defines temporal segments of the video. Each chunk specifies a time range, frame range, and whether that segment contains an anomaly:

[
  {
    "video_path": "/data/videos/clip_001.mp4",
    "anomaly_type": "Animals Obstructing Traffic",
    "split": "test",
    "video_duration_sec": 29.42,
    "fps": 30.0,
    "what": "a giraffe walks onto the road and blocks traffic",
    "chunks": [
      {
        "chunk_index": 0,
        "start_time_sec": 0.0,
        "end_time_sec": 5.0,
        "duration_sec": 5.0,
        "start_frame": 0,
        "end_frame": 150,
        "is_anomaly": true,
        "overlap_ratio": 1.0
      },
      {
        "chunk_index": 1,
        "start_time_sec": 5.0,
        "end_time_sec": 10.0,
        "duration_sec": 5.0,
        "start_frame": 150,
        "end_frame": 300,
        "is_anomaly": true,
        "overlap_ratio": 1.0
      }
    ]
  }
]

The chunk_size_sec configuration parameter controls the duration of each temporal chunk (default: 5.0 seconds). The is_anomaly field indicates whether the chunk overlaps with an annotated anomaly region, and overlap_ratio indicates the degree of overlap.
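
The chunk entries can be derived mechanically from the video duration, fps, and the annotated anomaly window. The sketch below shows one plausible derivation (illustrative; the toolkit's actual chunking code may differ in rounding and boundary handling):

```python
def make_chunks(duration_sec, fps, anomaly_start_sec, anomaly_end_sec,
                chunk_size_sec=5.0):
    """Split a video into fixed-duration chunks and flag anomaly overlap."""
    chunks, start, index = [], 0.0, 0
    while start < duration_sec:
        end = min(start + chunk_size_sec, duration_sec)
        # Overlap between [start, end] and the annotated anomaly window.
        overlap = max(0.0, min(end, anomaly_end_sec) - max(start, anomaly_start_sec))
        chunks.append({
            "chunk_index": index,
            "start_time_sec": start,
            "end_time_sec": end,
            "duration_sec": end - start,
            "start_frame": round(start * fps),
            "end_frame": round(end * fps),
            "is_anomaly": overlap > 0.0,
            "overlap_ratio": overlap / (end - start),
        })
        start, index = end, index + 1
    return chunks
```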

The caption_field configuration parameter controls which metadata field the model uses as the text caption. You can set it to a single field name (for example, anomaly_type) or to a list of field names (for example, ["anomaly_type", "what"]). When a list is given, the first available field is used by default; enable random_caption to randomly sample one of the listed fields per sample during training.
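
The selection logic can be sketched as follows (a hypothetical helper mirroring the documented behavior of caption_field and random_caption, not toolkit code):

```python
import random

def pick_caption(entry, caption_field, random_caption=False):
    """Resolve the caption for a metadata entry.

    caption_field: a single field name or a list of field names.
    random_caption: when caption_field is a list, sample one of the available
    fields at random (training-time augmentation) instead of taking the first.
    """
    if isinstance(caption_field, str):
        return entry[caption_field]
    fields = [f for f in caption_field if f in entry]
    chosen = random.choice(fields) if random_caption else fields[0]
    return entry[chosen]
```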

Training Cosmos-Embed1#

Training fine-tunes the video and Q-Former encoders using contrastive and captioning losses. By default, the visual encoder weights are frozen and only the Q-Former and projection heads are trained.

To Train the Model#

tao model cosmos-embed1 train \
    -e cosmos_embed1/configs/experiment_specs/finetune_224p.yaml \
    results_dir=/results/my_experiment \
    model.pretrained_model_path=/data/checkpoints/finetune_224p.pth \
    dataset.train_dataset.metadata=/data/train.json \
    dataset.train_dataset.data_root=/data/videos \
    dataset.val_dataset.metadata=/data/val.json \
    dataset.val_dataset.data_root=/data/videos

Required arguments:

  • -e: Path to the experiment specification file

  • results_dir: Directory in which to save checkpoints and logs

Optional arguments:

  • model.pretrained_model_path: Path to a pretrained checkpoint to initialize from; accepts a local file path or a HuggingFace repository ID such as nvidia/Cosmos-Embed1-224p

  • train.num_gpus: Number of GPUs to use; set to -1 to use all available GPUs; default 1

  • train.max_iter: Maximum number of training iterations; default 50000

  • train.optim.lr: Learning rate; default 1e-5

  • dataset.train_dataset.batch_size: Batch size per GPU; default 4

Low-Rank Adaptation (LoRA) enables parameter-efficient fine-tuning by adding low-rank adapter matrices to the attention and multilayer perceptron (MLP) layers of the visual encoder and Q-Former. LoRA requires setting model.network.visual_encoder.transformer_engine to false because the Parameter-Efficient Fine-Tuning (PEFT) library cannot inject adapters into Transformer Engine layers.

tao model cosmos-embed1 train \
    -e cosmos_embed1/configs/experiment_specs/finetune_224p_lora.yaml \
    results_dir=/results/my_lora_experiment \
    model.pretrained_model_path=/data/checkpoints/finetune_224p.pth \
    dataset.train_dataset.metadata=/data/train.json \
    dataset.train_dataset.data_root=/data/videos

LoRA-specific arguments:

  • model.lora.enabled: Whether to enable LoRA; default false

  • model.lora.lora_rank: Rank of the low-rank matrices; default 8

  • model.lora.lora_alpha: Scaling factor for LoRA; default 16

Multi-GPU Training#

To run training on multiple GPUs, set the train.num_gpus parameter:

tao model cosmos-embed1 train \
    -e cosmos_embed1/configs/experiment_specs/finetune_224p.yaml \
    results_dir=/results/my_experiment \
    train.num_gpus=4

To use all available GPUs automatically, set train.num_gpus=-1.

Resuming Training#

To resume training from a checkpoint, set the train.resume_training_checkpoint_path parameter:

tao model cosmos-embed1 train \
    -e cosmos_embed1/configs/experiment_specs/finetune_224p.yaml \
    results_dir=/results/my_experiment \
    train.resume_training_checkpoint_path=/results/my_experiment/train/cosmos_embed1_model_latest.pth

Training Outputs#

Training saves the following files to <results_dir>/train/:

  • cosmos_embed1_model_latest.pth: Most recent checkpoint

  • cosmos_embed1_model_<iter>.pth: Periodic checkpoints saved every checkpoint_iter iterations

  • experiment.yaml: Resolved configuration snapshot for reproducibility

  • console.log, debug.log: Training logs

  • wandb/: Weights and Biases run data (when Weights and Biases logging is enabled)

Evaluating Cosmos-Embed1#

Evaluation computes top-K hit rate classification metrics and optionally generates Uniform Manifold Approximation and Projection (UMAP) embedding visualizations on a test dataset.

To Evaluate the Model#

tao model cosmos-embed1 evaluate \
    -e cosmos_embed1/configs/experiment_specs/evaluate_224p.yaml \
    results_dir=/results/my_experiment \
    evaluate.checkpoint=/results/my_experiment/train/cosmos_embed1_model_latest.pth \
    dataset.test_dataset.metadata=/data/test.json \
    dataset.test_dataset.data_root=/data/videos

Required arguments:

  • -e: Path to the experiment specification file

  • evaluate.checkpoint: Path to the model checkpoint

Optional arguments:

  • evaluate.callbacks.topk_classification: Whether to compute top-K hit rate metrics; default true

  • evaluate.callbacks.embedding_visualization: Whether to generate UMAP embedding visualizations; default false

  • evaluate.callbacks.top_k_values: List of K values for top-K hit rate computation; default [1, 3, 5, 10]

  • evaluate.num_gpus: Number of GPUs for evaluation; default 1
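
The top-K hit rate counts, for each video, whether its true label appears among the K labels whose text embeddings are most similar to the video embedding. A minimal NumPy sketch of the metric (illustrative, not the toolkit's implementation):

```python
import numpy as np

def topk_hit_rate(video_embs, text_embs, labels, k):
    """Fraction of videos whose true label is among the K most similar text labels.

    video_embs: (N, D) video embeddings; text_embs: (C, D), one embedding per
    class label; labels: (N,) integer class index of each video.
    """
    v = video_embs / np.linalg.norm(video_embs, axis=1, keepdims=True)
    t = text_embs / np.linalg.norm(text_embs, axis=1, keepdims=True)
    sim = v @ t.T                              # (N, C) cosine similarities
    topk = np.argsort(-sim, axis=1)[:, :k]     # indices of the K best classes
    hits = (topk == np.asarray(labels)[:, None]).any(axis=1)
    return hits.mean()
```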

Caching Embeddings#

For repeated evaluations against the same dataset, you can cache the model embeddings to a pickle file and reload them on subsequent runs, skipping model inference:

# First run: generate and save embeddings
tao model cosmos-embed1 evaluate -e evaluate_224p.yaml \
    evaluate.checkpoint=/results/my_experiment/train/cosmos_embed1_model_latest.pth \
    evaluate.save_dataset_pkl=/results/evaluate/embeddings.pkl \
    dataset.test_dataset.metadata=/data/test.json

# Subsequent runs: load cached embeddings (skips model inference)
tao model cosmos-embed1 evaluate -e evaluate_224p.yaml \
    evaluate.load_dataset_pkl=/results/evaluate/embeddings.pkl

Running Inference with Cosmos-Embed1#

Inference performs a top-K similarity search against a video corpus using text or video queries. The model encodes all videos in the search database and returns the K most similar videos to each query.
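
Given precomputed embeddings, the search itself is a normalized dot product followed by a top-K sort. A minimal NumPy sketch (illustrative shapes; not the toolkit's code):

```python
import numpy as np

def search(query_emb, db_embs, k=5):
    """Return (indices, scores) of the k database videos most similar to the query."""
    q = query_emb / np.linalg.norm(query_emb)
    d = db_embs / np.linalg.norm(db_embs, axis=1, keepdims=True)
    scores = d @ q                       # cosine similarity to every database video
    order = np.argsort(-scores)[:k]      # best k, highest similarity first
    return order, scores[order]
```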

Caching the Search Database#

For repeated searches against the same video corpus, cache the encoded search database to avoid re-running model inference:

# Build and cache the search database
tao model cosmos-embed1 inference -e inference_224p.yaml \
    inference.checkpoint=/results/my_experiment/train/cosmos_embed1_model_latest.pth \
    dataset.inference_dataset.metadata=/data/videos.json \
    inference.save_dataset_pkl=/results/inference/db_embeddings.pkl \
    'inference.query.input_texts=["flooding"]'

# Fast subsequent searches using the cached database
tao model cosmos-embed1 inference -e inference_224p.yaml \
    inference.load_dataset_pkl=/results/inference/db_embeddings.pkl \
    'inference.query.input_texts=["a car crash", "normal traffic"]'

Exporting Cosmos-Embed1#

The export task converts a trained Cosmos-Embed1 checkpoint to Open Neural Network Exchange (ONNX) format for deployment, or to HuggingFace format for sharing and downstream use.

ONNX Export#

Cosmos-Embed1 supports exporting the video encoder, text encoder, or both encoders together:

# Export the video encoder (default)
tao model cosmos-embed1 export \
    -e cosmos_embed1/configs/experiment_specs/export_onnx_224p.yaml \
    export.checkpoint=/results/my_experiment/train/cosmos_embed1_model_latest.pth \
    export.mode=video

# Export the text encoder
tao model cosmos-embed1 export \
    -e cosmos_embed1/configs/experiment_specs/export_onnx_224p.yaml \
    export.checkpoint=/results/my_experiment/train/cosmos_embed1_model_latest.pth \
    export.mode=text

# Export both encoders in a single ONNX graph
tao model cosmos-embed1 export \
    -e cosmos_embed1/configs/experiment_specs/export_onnx_224p.yaml \
    export.checkpoint=/results/my_experiment/train/cosmos_embed1_model_latest.pth \
    export.mode=combined

Required arguments:

  • export.checkpoint: Path to the model checkpoint

Optional arguments:

  • export.mode: ONNX export mode; valid options: video, text, combined; default video

  • export.onnx_file: Output ONNX file path; when not set, the path is auto-derived from the checkpoint path and mode

  • export.batch_size: Batch size for export; set to -1 for a dynamic batch dimension; default 1

  • export.opset_version: ONNX opset version; default 17

  • export.simplify: Whether to apply onnxsim simplification after export; default false

HuggingFace Export#

To export a trained checkpoint to HuggingFace format (sharded safetensors and a config.json):

tao model cosmos-embed1 export \
    -e cosmos_embed1/configs/experiment_specs/export_hf_224p.yaml \
    export.checkpoint=/results/my_experiment/train/cosmos_embed1_model_latest.pth \
    export.hf_output_dir=/results/my_experiment/hf_export

Creating a Configuration File#

Cosmos-Embed1 uses YAML experiment specification files that follow the ExperimentConfig dataclass schema. You can start from one of the provided experiment specification files in cosmos_embed1/configs/experiment_specs/ and override individual fields on the command line using dot-notation:

tao model cosmos-embed1 train -e finetune_224p.yaml \
    results_dir=/my/output \
    train.max_iter=10000 \
    train.num_gpus=4 \
    dataset.train_dataset.batch_size=8
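
Conceptually, each dot-notation pair walks the nested configuration and replaces a leaf value. The sketch below shows that semantics on a plain dict (illustrative; the real launcher also coerces value types and handles lists):

```python
def apply_overrides(config, overrides):
    """Apply dot-notation overrides like 'train.max_iter=10000' to a nested dict."""
    for item in overrides:
        dotted_key, value = item.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = config
        for key in parents:                 # walk (or create) intermediate nodes
            node = node.setdefault(key, {})
        node[leaf] = value                  # real parsers also coerce types here
    return config
```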

The following sections describe all available configuration parameters.

Experiment Configuration#

| Parameter | Type | Description | Default |
| --- | --- | --- | --- |
| wandb | WandbConfig | Weights and Biases logging configuration. Auto-disables if no API key is found. | See Weights and Biases Logging Configuration |
| model | ModelConfig | Model configuration. | See Model Configuration |
| dataset | DatasetConfig | Dataset configuration. | See Dataset Configuration |
| train | TrainConfig | Training experiment configuration. | See Training Configuration |
| evaluate | EvaluateConfig | Evaluation experiment configuration. | See Evaluation Configuration |
| inference | InferenceConfig | Inference experiment configuration. | See Inference Configuration |
| export | ExportConfig | ONNX export experiment configuration. | See Export Configuration |
| results_dir | str | Directory to save results, checkpoints, and logs. | “/results” |
| encryption_key | Optional[str] | Encryption key for model export (TAO compatibility). | None |
| model_name | str | Model name identifier. | “cosmos_embed1” |

Weights and Biases Logging Configuration#

| Parameter | Type | Description | Default |
| --- | --- | --- | --- |
| enable | bool | Enable Weights and Biases logging. | False |
| project | str | Weights and Biases project name. | “cosmos_embed1” |
| group | str | Run group for organizing related runs in the dashboard. | “” |
| name | str | Run name. Empty string auto-generates a name. | “” |
| tags | list[str] | List of tags for filtering runs in the dashboard. | [] |
| save_code | bool | Save a copy of the training code to Weights and Biases. | False |
| api_key | str | API key. If empty, falls back to the WANDB_API_KEY env var. | “” |

Model Configuration#

| Parameter | Type | Description | Default |
| --- | --- | --- | --- |
| network | NetworkConfig | Network architecture configuration. | See Network Configuration |
| pretrained_model_path | Optional[str] | Path to a pretrained checkpoint. Accepts a local file path (.pth, .safetensors) or a HuggingFace repo ID. | None |
| pretrained_model_strict | bool | Strict state_dict matching when loading pretrained weights. Missing or unexpected keys raise an error when True. | True |
| precision | Precision | Training precision. Valid options: “bf16”, “fp16”, “fp32”. | “bf16” |
| input_hw | list[int] | Data-loader input resolution [H, W]. Distinct from model.network.spatial_resolution. | [224, 224] |
| fsdp | FSDPConfig | Fully Sharded Data Parallel configuration for distributed training. | See FSDP Configuration |
| fsdp_shard_size | int | Legacy FSDP shard size used by the model loader. | 8 |
| lora | LoRAConfig | LoRA configuration. When enabled, wraps the network with PEFT adapters. Requires transformer_engine=False. | See LoRA Configuration |

Network Configuration#

| Parameter | Type | Description | Default |
| --- | --- | --- | --- |
| visual_encoder | VisualEncoderConfig | Visual encoder configuration. | See Visual Encoder Configuration |
| embed_dim | int | Output embedding dimension for video-text alignment. | 256 |
| num_query_tokens | int | Number of learnable query tokens in the Q-Former. | 32 |
| max_txt_len | int | Maximum text token sequence length. | 128 |
| num_video_frames | int | Number of input video frames. | 8 |
| spatial_resolution | list[int] | Spatial resolution [H, W] for input video frames. | [224, 224] |
| temporal_encoding_type | TemporalEncodingType | Type of temporal encoding. | “neighboring_token_propagation” |
| contrastive_type | ContrastiveType | Contrastive loss type. Valid options: “clip”, “siglip”. | “clip” |
| qformer_pretrain_ckpt | Optional[str] | Path or HuggingFace repo ID for the Q-Former pretrained checkpoint. | None |
| query_pooling_type | QueryPoolingType | Query pooling method after the Q-Former. Valid options: “avg”, “attention”, “identity”. | “avg” |
| pretrained_text_encoder | bool | Load pretrained BERT weights for the text encoder. | False |
| pretrained_visual_encoder | bool | Load pretrained weights for the visual encoder from S3 or HuggingFace. | False |
| num_heldout_frames | int | Number of held-out frames for certain training strategies. | 0 |

Visual Encoder Configuration#

| Parameter | Type | Description | Default |
| --- | --- | --- | --- |
| type | VisualEncoderType | Visual encoder type. | “eva_vit_g” |
| img_size | int | Input image size for the visual encoder. | 224 |
| pretrained | bool | Load pretrained visual encoder weights from S3. | False |
| use_fp8 | bool | Use FP8 precision with Transformer Engine (requires transformer_engine=true). | False |
| transformer_engine | bool | Use Transformer Engine for optimized attention computation. | True |
| checkpoint_activations | bool | Use gradient checkpointing for activations to reduce memory usage. | False |
| checkpoint_attention | bool | Use gradient checkpointing for attention (requires transformer_engine=true). | False |

FSDP Configuration#

| Parameter | Type | Description | Default |
| --- | --- | --- | --- |
| enabled | bool | Enable Fully Sharded Data Parallel. | False |
| shard_size | Optional[int] | FSDP shard group size. None auto-selects one shard per node. | None |
| replica_size | Optional[int] | FSDP replica group size. None auto-selects. | None |

Dataset Configuration#

| Parameter | Type | Description | Default |
| --- | --- | --- | --- |
| train_dataset | SingleDatasetConfig | Training dataset configuration. | See Single Dataset Configuration |
| val_dataset | SingleDatasetConfig | Validation dataset configuration (used during training validation). | See Single Dataset Configuration |
| test_dataset | SingleDatasetConfig | Test/evaluation dataset configuration (used by the evaluate action). | See Single Dataset Configuration |
| inference_dataset | SingleDatasetConfig | Inference search database configuration (used by the inference action). | See Single Dataset Configuration |

Single Dataset Configuration#

The following parameters apply to each of the four dataset splits: train_dataset, val_dataset, test_dataset, and inference_dataset.

| Parameter | Type | Description | Default |
| --- | --- | --- | --- |
| dataset_type | DatasetType | Dataset class to use. Valid options: “mock”, “vad_r1”, “vad_r1_chunks”, “msrvtt”, “kinetics”, “http”. | “mock” |
| metadata | Optional[str] | Path to the metadata JSON or JSONL file. | None |
| data_root | Optional[str] | Root directory for video data. | None |
| num_video_frames | int | Number of video frames to sample from each video. | 8 |
| resolution | list[int] | Video frame resolution [H, W]. | [224, 224] |
| batch_size | int | Batch size per GPU. | 4 |
| workers | int | Number of dataloader worker processes. | 4 |
| drop_last | bool | Drop the last incomplete batch when the dataset size is not divisible by batch_size. | True |
| prefetch_factor | int | Number of batches to prefetch per worker process. | 2 |
| pin_memory | bool | Pin memory buffers for faster GPU transfer. | True |
| split | Optional[str] | Split filter for VadR1 datasets, e.g., “train”, “test”. None means no filtering. | None |
| random_caption | bool | When caption_field is a list, randomly sample one field per sample instead of always using the first. | False |
| path_prefix_mapping | dict[str, str] | Remap video file paths, e.g., {“/old/path/”: “/new/path/”}. | {} |
| skip_missing_files | bool | Skip dataset entries whose video files are missing. | True |
| caption_field | Any | Metadata field(s) to use as captions. String or list of strings, e.g., “anomaly_type”. | “anomaly_type” |
| mp4_urls | Optional[str] | Glob pattern for video files used by MSRVTTDataset and KineticsDataset. | None |
| caption_to_label | dict[str, int] | Mapping from caption text to integer label ID. | {} |
| chunk_size_sec | float | Duration of each temporal chunk in seconds (VadR1ChunksDataset only). | 5.0 |
| shared_normal_label | bool | When True, all normal (non-anomaly) samples share a single label ID instead of per-class labels. | True |

Dataset Format Reference#

| dataset_type | Metadata Format | Entry Schema | Required Config Fields |
| --- | --- | --- | --- |
| “mock” | None | No metadata file needed. Generates random frames using resolution and num_video_frames. | None |
| “vad_r1” | JSON or JSONL | Each entry: path (video file path), anomaly_type (caption). Optional: split, start, end, total_frames, what, when, where, why, how. | metadata, data_root |
| “vad_r1_chunks” | JSON or JSONL | Each entry: video_path, anomaly_type. Optional: split, chunks (list of chunk dicts with start_time_sec, end_time_sec, is_anomaly). | metadata, data_root |
| “msrvtt” | JSON with video/caption pairs | Each entry: video_id, caption. Video files located via the mp4_urls glob pattern. | mp4_urls, metadata |
| “kinetics” | CSV with youtube_id and label | Each row: youtube_id, label. Video files located via the mp4_urls glob pattern. | mp4_urls, metadata |
| “http” | JSON or JSONL | Each entry: url (HTTP/HTTPS video URL), captions (list of caption strings). Optional: video_id, caption_to_label. | metadata |

Training Configuration#

| Parameter | Type | Description | Default |
| --- | --- | --- | --- |
| optim | OptimConfig | Optimizer configuration. | See Optimizer Configuration |
| loss_weights | LossWeightsConfig | Per-loss weight configuration. | See Loss Weights Configuration |
| seed | int | Random seed for reproducibility. | 1234 |
| max_iter | int | Maximum number of training iterations. | 50000 |
| num_nodes | int | Number of nodes for distributed training. | 1 |
| num_gpus | int | Number of GPUs per node. Use -1 to auto-detect all available GPUs, 0 for CPU only. | 1 |
| gpu_ids | list[int] | List of GPU device IDs to use. Overrides num_gpus for device selection. | [0] |
| validation_iter | int | Frequency of validation runs, in iterations. | 1000 |
| checkpoint_iter | int | Frequency of checkpoint saves, in iterations. | 1000 |
| clip_grad_norm | float | Gradient clipping norm. Set to 0.0 to disable gradient clipping. | 0.0 |
| precision | Precision | Training precision. Valid options: “bf16”, “fp16”, “fp32”. | “bf16” |
| resume_training_checkpoint_path | Optional[str] | Path to a checkpoint to resume training from. | None |
| callbacks | dict[str, Any] | Dict mapping callback name to parameter overrides. Keys must match CALLBACK_REGISTRY. | {wandb, clamp_logit_scale, …} |
| max_val_iter | Optional[int] | Maximum number of validation batches per GPU. None runs the full validation set. | None |
| freeze_visual_encoder | bool | Freeze the visual encoder weights during training. | True |
| use_captioning_loss | bool | Enable the captioning loss during training. | True |
| use_text_matching_loss | bool | Enable the text matching loss during training. | False |
| ema | EMAConfig | Exponential Moving Average configuration. | See EMA Configuration |
| spectral_reparam | bool | Enable spectral reparameterization. | False |
| damp | DAMPConfig | DAMP (Decoupled Attention and Momentum Path) training technique configuration. | See DAMP Configuration |
| load_training_state | bool | Restore optimizer and scheduler state when resuming training. | False |
| strict_resume | bool | Strict state_dict matching when resuming from a checkpoint. | False |

Training Callbacks#

The callbacks field in TrainConfig is a dict mapping callback names to parameter overrides. Each key must match an entry in the CALLBACK_REGISTRY. To disable a callback, remove its key from the dict. To customize a callback, set the key to a dict of parameter overrides:

train:
  callbacks:
    gradient_clip:
      clip_norm: 5.0
    iter_speed:
      every_n: 100
    grad_norm_monitor:
      every_n: 200
      verbose: true
    # Add optional callbacks not included by default:
    validation_eval: {}

The default callbacks and their parameters are listed below. The validation_eval callback is not included by default but can be added to enable evaluation metrics during training validation.

| Callback | Default Parameters | Description |
| --- | --- | --- |
| “wandb” | {} | Logs training metrics to Weights and Biases. |
| “clamp_logit_scale” | {} | Clamps the logit scale parameter to prevent instability. |
| “logit_parameters_monitor” | {} | Logs logit scale and bias parameters. |
| “iter_speed” | every_n: 50, save_s3: False | Logs iteration throughput (samples/sec) every N iterations. |
| “gradient_clip” | clip_norm: 3.0 | Clips gradients to a maximum L2 norm. |
| “grad_norm_monitor” | every_n: 500, verbose: False | Logs gradient norms every N iterations. |
| “spectral_norm_monitor” | every_n: 1000, verbose: True | Logs spectral norms of weight matrices every N iterations. |
| “ema” | {} | Updates the Exponential Moving Average model shadow weights. |
| “log_losses” | every_n: 50, verbose: True | Logs all loss components every N iterations. |
| “text_frames_visualizer” | every_n: 500 | Logs video frame and text caption pairs to Weights and Biases. |
| “pca_feature_map_visualizer” | every_n: 500 | Logs PCA-projected feature map visualizations to Weights and Biases. |
| “validation_eval” | {} | Runs full evaluation metrics during training validation. Not included by default; add to enable. |

Optimizer Configuration#

| Parameter | Type | Description | Default |
| --- | --- | --- | --- |
| optim | OptimizerType | Optimizer type. Valid options: “adamw”, “fused_adamw”, “adam”, “sgd”. | “adamw” |
| lr | float | Learning rate. | 1e-05 |
| weight_decay | float | Weight decay coefficient. | 1e-05 |
| betas | list[float] | Adam and AdamW beta coefficients. | [0.9, 0.98] |
| warmup_steps | int | Number of warmup steps for the learning rate scheduler. | 1000 |
| policy | LRPolicy | Learning rate schedule policy. Valid options: “cosine”, “linear”, “constant”. | “cosine” |
| lr_decay_iters | int | Number of iterations over which to decay the learning rate (cosine scheduler). | 50000 |

Loss Weights Configuration#

| Parameter | Type | Description | Default |
| --- | --- | --- | --- |
| contrastive_loss | float | Weight for the contrastive loss term. | 1.0 |
| captioning_loss | float | Weight for the captioning loss term. | 1.0 |
| matching_loss | float | Weight for the text matching loss term. | 1.0 |

LoRA Configuration#

| Parameter | Type | Description | Default |
| --- | --- | --- | --- |
| enabled | bool | Enable LoRA fine-tuning. | False |
| lora_rank | int | Rank of the low-rank adapter matrices. Higher rank means more trainable parameters. | 8 |
| lora_alpha | int | Alpha scaling factor for LoRA. Typically set to 2× lora_rank. | 16 |
| lora_dropout | float | Dropout probability applied to LoRA layers. | 0.1 |
| bias | LoraBias | Bias handling for LoRA. Valid options: “none”, “all”, “lora_only”. | “none” |
| use_rslora | bool | Use Rank-Stabilized LoRA for more stable training at higher ranks. | False |
| use_dora | bool | Use DoRA (Weight-Decomposed Low-Rank Adaptation). | False |
| target_modules | Optional[list[str]] | Module name patterns to apply LoRA to. | [“qkv”, “fc1”, “fc2”, “attn.proj”, “query”, “value”, “key”, “dense”, “vision_proj”, “text_proj”, “itm_proj”] |
| modules_to_save | Optional[list[str]] | Modules to keep fully trainable (bypassing LoRA). | [“temporal_encoding”, “query_pooling”] |

EMA Configuration#

| Parameter | Type | Description | Default |
| --- | --- | --- | --- |
| enabled | bool | Enable Exponential Moving Average weight tracking. | False |
| beta | float | EMA decay rate. | 0.9999 |
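
After each optimizer step, EMA updates shadow weights as shadow = beta * shadow + (1 - beta) * weight; with the default beta of 0.9999, the shadow model moves slowly and smooths out training noise. A minimal sketch (illustrative, not the toolkit's code):

```python
class EMA:
    """Track an exponential moving average of a parameter vector."""

    def __init__(self, params, beta=0.9999):
        self.beta = beta
        self.shadow = list(params)   # copy of the initial parameters

    def update(self, params):
        # shadow <- beta * shadow + (1 - beta) * current parameters
        b = self.beta
        self.shadow = [b * s + (1.0 - b) * p for s, p in zip(self.shadow, params)]
```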

DAMP Configuration#

| Parameter | Type | Description | Default |
| --- | --- | --- | --- |
| enabled | bool | Enable DAMP. | False |
| beta | float | DAMP beta coefficient. | 0.1 |
| mode | DAMPMode | DAMP mode. Valid options: “const”, “dynamic”. | “const” |

Evaluation Configuration#

| Parameter | Type | Description | Default |
| --- | --- | --- | --- |
| checkpoint | Optional[str] | Path to the model checkpoint for evaluation. | None |
| max_val_batches | int | Maximum number of validation batches to run. -1 runs all batches. | -1 |
| num_gpus | int | Number of GPUs for evaluation. | 1 |
| callbacks | ValidationEvalConfig | Validation evaluation callback configuration. | See Validation Evaluation Callbacks Configuration |
| load_dataset_pkl | Optional[str] | Path to load pre-computed eval embeddings from. When set and the file exists, model inference is skipped. | None |
| save_dataset_pkl | Optional[str] | Path to save generated eval embeddings to. When set, embeddings are saved after generation (rank 0 only). | None |

Validation Evaluation Callbacks Configuration#

| Parameter | Type | Description | Default |
| --- | --- | --- | --- |
| topk_classification | bool | Enable top-K hit rate classification metrics. | True |
| embedding_visualization | bool | Enable UMAP embedding visualization. | False |
| top_k_values | list[int] | List of K values for top-K hit rate computation. | [1, 3, 5, 10] |
| max_eval_samples | int | Maximum number of samples to use during evaluation. | 2000 |

Inference Configuration#

| Parameter | Type | Description | Default |
| --- | --- | --- | --- |
| checkpoint | Optional[str] | Path to the model checkpoint for inference. | None |
| query | QueryConfig | Query inputs (text and/or video) for similarity search. | See Query Configuration |
| num_gpus | int | Number of GPUs for inference. | 1 |
| k | int | Number of nearest-neighbor results to return per query. | 5 |
| load_dataset_pkl | Optional[str] | Path to load pre-computed search database embeddings from. When set and the file exists, model inference is skipped. | None |
| save_dataset_pkl | Optional[str] | Path to save generated search database embeddings to. When set, embeddings are saved after generation. | None |

Query Configuration#

| Parameter | Type | Description | Default |
| --- | --- | --- | --- |
| input_videos | list[str] | List of video file paths to use as queries. | [] |
| input_texts | list[str] | List of text strings to use as queries. | [] |

Export Configuration#

| Parameter | Type | Description | Default |
| --- | --- | --- | --- |
| checkpoint | Optional[str] | Path to the model checkpoint for export. | None |
| onnx_file | Optional[str] | Output ONNX file path. If None, the path is auto-derived from the checkpoint path and mode. | None |
| mode | ExportMode | Export mode. Valid options: “video”, “text”, “combined”, “huggingface”. | “video” |
| opset_version | int | ONNX opset version. | 17 |
| batch_size | int | Batch size for export. Set to -1 for a dynamic batch dimension. | 1 |
| on_cpu | bool | Run export on CPU instead of GPU. | False |
| verbose | bool | Print verbose ONNX export information. | False |
| simplify | bool | Apply onnxsim simplification after export. | False |
| hf_output_dir | Optional[str] | Output directory for HuggingFace export. If None, auto-derived from the checkpoint path. Only used when mode=huggingface. | None |