nemo_automodel.components.flow_matching.adapters.base#

Base classes and data structures for model adapters.

This module defines the abstract ModelAdapter class and the FlowMatchingContext dataclass used to pass data between the pipeline and adapters.

Module Contents#

Classes#

FlowMatchingContext

Context object passed to model adapters containing all necessary data.

ModelAdapter

Abstract base class for model-specific forward pass logic.

API#

class nemo_automodel.components.flow_matching.adapters.base.FlowMatchingContext#

Context object passed to model adapters containing all necessary data.

This provides a clean interface for adapters to access the data they need without coupling to the batch dictionary structure.

.. attribute:: noisy_latents

[B, C, F, H, W] or [B, C, H, W] - Noisy latents after interpolation

.. attribute:: latents

[B, C, F, H, W] for video or [B, C, H, W] for image - Original clean latents (also accessible via deprecated ‘video_latents’ property for backward compatibility)

.. attribute:: timesteps

[B] - Sampled timesteps

.. attribute:: sigma

[B] - Sigma values

.. attribute:: task_type

“t2v” or “i2v”

.. attribute:: data_type

“video” or “image”

.. attribute:: device

Device for tensor operations

.. attribute:: dtype

Data type for tensor operations

.. attribute:: cfg_dropout_prob

Probability of dropping text embeddings (setting to 0) during training for classifier-free guidance (CFG). Defaults to 0.0 for backward compatibility.

.. attribute:: batch

Original batch dictionary (for model-specific data)

noisy_latents: torch.Tensor#

None

latents: torch.Tensor#

None

timesteps: torch.Tensor#

None

sigma: torch.Tensor#

None

task_type: str#

None

data_type: str#

None

device: torch.device#

None

dtype: torch.dtype#

None

batch: Dict[str, Any]#

None

cfg_dropout_prob: float#

0.0
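How `cfg_dropout_prob` is consumed is adapter-specific; a common pattern, sketched below in plain Python with a hypothetical helper name (`maybe_drop_text_embedding` is not part of this module, and real embeddings are tensors rather than lists), is to zero a sample's text conditioning with that probability:

```python
import random

def maybe_drop_text_embedding(embedding, cfg_dropout_prob, rng=random.random):
    # With probability cfg_dropout_prob, replace the conditioning with
    # zeros so the model also learns the unconditional distribution
    # needed for classifier-free guidance at sampling time.
    if rng() < cfg_dropout_prob:
        return [0.0] * len(embedding)
    return embedding

# cfg_dropout_prob=0.0 (the default) never drops the conditioning.
assert maybe_drop_text_embedding([0.5, 0.5], 0.0) == [0.5, 0.5]
# cfg_dropout_prob=1.0 always drops it.
assert maybe_drop_text_embedding([0.5, 0.5], 1.0) == [0.0, 0.0]
```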

property video_latents: torch.Tensor#

Backward compatibility alias for ‘latents’ field.
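The deprecated alias is a plain read-only property delegating to the renamed field. The stand-in dataclass below is illustrative only (plain Python values instead of `torch.Tensor`, and only two of the context's fields), not the actual `FlowMatchingContext`:

```python
from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class ContextSketch:
    # Stand-in for FlowMatchingContext; real fields are torch.Tensor.
    latents: Any = None
    batch: Dict[str, Any] = field(default_factory=dict)

    @property
    def video_latents(self) -> Any:
        # Backward-compatibility alias: old code reading
        # `context.video_latents` keeps working after the rename.
        return self.latents

ctx = ContextSketch(latents=[1.0, 2.0, 3.0])
assert ctx.video_latents is ctx.latents
```

Because it is a property, the alias always reflects the current value of `latents` and adds no storage.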

class nemo_automodel.components.flow_matching.adapters.base.ModelAdapter#

Bases: abc.ABC

Abstract base class for model-specific forward pass logic.

Implement this class to add support for new model architectures without modifying the FlowMatchingPipeline.

The adapter pattern decouples the flow matching logic from model-specific details like input preparation and forward pass conventions.

.. rubric:: Example

class MyCustomAdapter(ModelAdapter):
    def prepare_inputs(self, context: FlowMatchingContext) -> Dict[str, Any]:
        return {
            "x": context.noisy_latents,
            "t": context.timesteps,
            "cond": context.batch["my_conditioning"],
        }

def forward(self, model: nn.Module, inputs: Dict[str, Any]) -> torch.Tensor:
    return model(**inputs)

pipeline = FlowMatchingPipelineV2(model_adapter=MyCustomAdapter())

abstractmethod prepare_inputs(
context: nemo_automodel.components.flow_matching.adapters.base.FlowMatchingContext,
) → Dict[str, Any]#

Prepare model-specific inputs from the context.

Parameters:

context – FlowMatchingContext containing all necessary data

Returns:

Dictionary of inputs to pass to the model’s forward method

abstractmethod forward(
model: torch.nn.Module,
inputs: Dict[str, Any],
) → torch.Tensor#

Execute the model forward pass.

Parameters:
  • model – The model to call

  • inputs – Dictionary of inputs from prepare_inputs()

Returns:

Model prediction tensor

post_process_prediction(model_pred: torch.Tensor) → torch.Tensor#

Post-process model prediction if needed.

Override this for models that return extra outputs or need transformation.

Parameters:

model_pred – Raw model output

Returns:

Processed prediction tensor
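The three hooks compose as prepare_inputs → forward → post_process_prediction. The self-contained sketch below uses a minimal stand-in base class and a toy model (both hypothetical, with plain Python values instead of tensors) to show a typical `post_process_prediction` override for a model that returns a tuple:

```python
from abc import ABC, abstractmethod
from typing import Any, Dict

class AdapterSketch(ABC):
    # Minimal stand-in for ModelAdapter, illustrating the hook order
    # the pipeline uses: prepare_inputs -> forward -> post_process_prediction.
    @abstractmethod
    def prepare_inputs(self, context: Dict[str, Any]) -> Dict[str, Any]: ...

    @abstractmethod
    def forward(self, model, inputs: Dict[str, Any]) -> Any: ...

    def post_process_prediction(self, model_pred: Any) -> Any:
        # Default: identity, matching the base-class behavior.
        return model_pred

class TupleOutputAdapter(AdapterSketch):
    def prepare_inputs(self, context):
        return {"x": context["noisy_latents"], "t": context["timesteps"]}

    def forward(self, model, inputs):
        return model(**inputs)

    def post_process_prediction(self, model_pred):
        # Override: this model returns (prediction, aux); keep only
        # the prediction for the flow-matching loss.
        return model_pred[0]

def toy_model(x, t):
    # Hypothetical model returning a (prediction, aux) tuple.
    return ([xi + t for xi in x], {"aux": 0.0})

adapter = TupleOutputAdapter()
inputs = adapter.prepare_inputs({"noisy_latents": [1.0, 2.0], "timesteps": 0.5})
pred = adapter.post_process_prediction(adapter.forward(toy_model, inputs))
assert pred == [1.5, 2.5]
```

Keeping the tuple-unpacking in `post_process_prediction` rather than in `forward` leaves `forward` a faithful passthrough to the model's own calling convention.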