nemo_automodel.components.flow_matching.adapters.simple


Simple transformer model adapter for the FlowMatching pipeline.

This adapter supports transformer models with a basic forward interface, such as Wan-style models.

Module Contents

Classes

Name | Description
--- | ---
SimpleAdapter | Model adapter for simple transformer models (e.g., Wan).

API

class nemo_automodel.components.flow_matching.adapters.simple.SimpleAdapter()

Bases: ModelAdapter

Model adapter for simple transformer models (e.g., Wan).

These models are called with three inputs:

  • hidden_states: noisy latents
  • timestep: timestep values
  • encoder_hidden_states: text embeddings

Expected batch keys:

  • text_embeddings: text encoder output of shape [B, seq_len, dim]

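
The mapping from batch and context onto the three model inputs can be sketched as follows. This is a minimal, hypothetical illustration, not the library's implementation: `ToyContext`, its fields `noisy_latents`, `timesteps`, and `batch`, and the helper `prepare_simple_inputs` are assumptions made for this example. Plain nested lists stand in for tensors so the snippet runs without PyTorch.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class ToyContext:
    """Hypothetical stand-in for FlowMatchingContext (field names are assumptions)."""
    noisy_latents: Any
    timesteps: Any
    batch: Dict[str, Any] = field(default_factory=dict)


def prepare_simple_inputs(context: ToyContext) -> Dict[str, Any]:
    """Map a context onto the keyword arguments a simple (Wan-style) transformer takes."""
    if "text_embeddings" not in context.batch:
        raise KeyError("batch must contain 'text_embeddings' with shape [B, seq_len, dim]")
    return {
        "hidden_states": context.noisy_latents,  # noisy latents
        "timestep": context.timesteps,  # one timestep value per sample
        "encoder_hidden_states": context.batch["text_embeddings"],  # text encoder output
    }


# Usage: B=2 samples, seq_len=3, dim=4, with lists as placeholder tensors.
ctx = ToyContext(
    noisy_latents=[[0.0] * 8, [0.0] * 8],
    timesteps=[0.25, 0.75],
    batch={"text_embeddings": [[[0.0] * 4] * 3] * 2},
)
inputs = prepare_simple_inputs(ctx)
print(sorted(inputs))  # ['encoder_hidden_states', 'hidden_states', 'timestep']
```

Raising early on a missing `text_embeddings` key surfaces a malformed batch before the model call, rather than as an obscure error inside the transformer.
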
nemo_automodel.components.flow_matching.adapters.simple.SimpleAdapter.forward(
model: torch.nn.Module,
inputs: typing.Dict[str, typing.Any]
) -> torch.Tensor

Execute forward pass for simple transformer model.

Parameters:

model
nn.Module

The transformer model to run.

inputs
Dict[str, Any]

Dictionary returned by prepare_inputs()

Returns: torch.Tensor

Model prediction tensor
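
A forward pass of this kind can be pictured as unpacking the prepared dictionary into the model call. The sketch below is an assumption, not the library's code: the helper `simple_forward`, the `ToyTransformer` stand-in, and the handling of tuple outputs (some transformer implementations return a tuple whose first element is the prediction) are illustrative only.

```python
from typing import Any, Callable, Dict


def simple_forward(model: Callable[..., Any], inputs: Dict[str, Any]) -> Any:
    """Hypothetical sketch: call the model with the prepared keyword arguments."""
    output = model(
        hidden_states=inputs["hidden_states"],
        timestep=inputs["timestep"],
        encoder_hidden_states=inputs["encoder_hidden_states"],
    )
    # Assumption: some models return a tuple; take the first element as the prediction.
    if isinstance(output, tuple):
        output = output[0]
    return output


class ToyTransformer:
    """Stand-in model: returns a prediction with the same shape as hidden_states."""

    def __call__(self, hidden_states, timestep, encoder_hidden_states):
        # Scale each sample's latents by its timestep; a real model is far richer.
        prediction = [[x * t for x in row] for row, t in zip(hidden_states, timestep)]
        return (prediction,)


pred = simple_forward(
    ToyTransformer(),
    {
        "hidden_states": [[1.0, 2.0], [3.0, 4.0]],
        "timestep": [0.5, 2.0],
        "encoder_hidden_states": [[[0.0]], [[0.0]]],
    },
)
print(pred)  # [[0.5, 1.0], [6.0, 8.0]]
```

Passing arguments by keyword keeps the call independent of the model's positional argument order.
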

nemo_automodel.components.flow_matching.adapters.simple.SimpleAdapter.prepare_inputs(
context: nemo_automodel.components.flow_matching.adapters.base.FlowMatchingContext
) -> typing.Dict[str, typing.Any]

Prepare inputs for simple transformer model.

Parameters:

context
FlowMatchingContext

FlowMatchingContext with batch data

Returns: Dict[str, Any]

Dictionary containing: