nemo_automodel.components.flow_matching.adapters.base


Base classes and data structures for model adapters.

This module defines the abstract ModelAdapter class and the FlowMatchingContext dataclass used to pass data between the pipeline and adapters.

Module Contents

Classes

Name | Description
--- | ---
FlowMatchingContext | Context object passed to model adapters containing all necessary data.
ModelAdapter | Abstract base class for model-specific forward pass logic.

API

class nemo_automodel.components.flow_matching.adapters.base.FlowMatchingContext(
noisy_latents: torch.Tensor,
latents: torch.Tensor,
timesteps: torch.Tensor,
sigma: torch.Tensor,
task_type: str,
data_type: str,
device: torch.device,
dtype: torch.dtype,
batch: typing.Dict[str, typing.Any],
cfg_dropout_prob: float = 0.0
)
Dataclass

Context object passed to model adapters containing all necessary data.

It gives adapters a single interface to the data they need, so they are not coupled to the structure of the batch dictionary.

batch: Dict[str, Any]
cfg_dropout_prob: float = 0.0
data_type: str
device: torch.device
dtype: torch.dtype
latents: torch.Tensor
noisy_latents: torch.Tensor
sigma: torch.Tensor
task_type: str
timesteps: torch.Tensor
video_latents: torch.Tensor (backward-compatibility alias for the ‘latents’ field)
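A minimal sketch of building a context by hand, for example in a unit test for a custom adapter. To stay self-contained it uses a local stand-in that mirrors the documented fields instead of importing the library class. The body of the video_latents alias, the (B, C, T, H, W) tensor layout, the rectified-flow interpolation used to produce noisy_latents, the timestep scaling, and the task_type/data_type strings are all assumptions, not the library's confirmed behavior.

```python
from dataclasses import dataclass
from typing import Any, Dict

import torch


# Local stand-in mirroring the documented FlowMatchingContext fields.
# The video_latents property is an ASSUMED implementation of the alias.
@dataclass
class FlowMatchingContext:
    noisy_latents: torch.Tensor
    latents: torch.Tensor
    timesteps: torch.Tensor
    sigma: torch.Tensor
    task_type: str
    data_type: str
    device: torch.device
    dtype: torch.dtype
    batch: Dict[str, Any]
    cfg_dropout_prob: float = 0.0

    @property
    def video_latents(self) -> torch.Tensor:
        # Alias kept for older adapters that still read `video_latents`.
        return self.latents


batch_size = 2
latents = torch.randn(batch_size, 16, 4, 8, 8)  # assumed (B, C, T, H, W) layout
noise = torch.randn_like(latents)
sigma = torch.rand(batch_size)

# Assumed rectified-flow interpolation: x_t = (1 - sigma) * x_0 + sigma * noise
s = sigma.view(-1, 1, 1, 1, 1)
noisy = (1.0 - s) * latents + s * noise

ctx = FlowMatchingContext(
    noisy_latents=noisy,
    latents=latents,
    timesteps=sigma * 1000.0,  # hypothetical timestep scaling
    sigma=sigma,
    task_type="t2v",           # hypothetical task label
    data_type="video",         # hypothetical data label
    device=latents.device,
    dtype=latents.dtype,
    batch={"text_embeddings": torch.randn(batch_size, 77, 32)},  # hypothetical key
)
```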

class nemo_automodel.components.flow_matching.adapters.base.ModelAdapter()
Abstract

Abstract base class for model-specific forward pass logic.

Subclass it to add support for a new model architecture without modifying FlowMatchingPipeline.

The adapter pattern decouples the flow matching logic from model-specific details such as input preparation and forward pass conventions.
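The pattern can be sketched end to end as follows. ModelAdapter here is a local stand-in with the two documented abstract methods and an assumed pass-through default for post_process_prediction. ToyDenoiser, ToyAdapter, and run_step are hypothetical, and the order prepare_inputs, then forward, then post_process_prediction is an assumption about how FlowMatchingPipeline drives an adapter.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace
from typing import Any, Dict

import torch
import torch.nn as nn


# Local stand-in mirroring the documented ModelAdapter interface.
class ModelAdapter(ABC):
    @abstractmethod
    def prepare_inputs(self, context: Any) -> Dict[str, Any]: ...

    @abstractmethod
    def forward(self, model: nn.Module, inputs: Dict[str, Any]) -> torch.Tensor: ...

    def post_process_prediction(self, model_pred: torch.Tensor) -> torch.Tensor:
        # Assumed default: return the prediction unchanged.
        return model_pred


class ToyDenoiser(nn.Module):
    """Hypothetical model predicting a velocity from latents and timesteps."""

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.proj = nn.Linear(channels, channels)

    def forward(self, hidden_states: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
        # hidden_states: (B, N, C); timestep: (B,) broadcast over N and C.
        return self.proj(hidden_states) + timestep.view(-1, 1, 1)


class ToyAdapter(ModelAdapter):
    """Hypothetical adapter mapping context fields to ToyDenoiser kwargs."""

    def prepare_inputs(self, context: Any) -> Dict[str, Any]:
        return {"hidden_states": context.noisy_latents, "timestep": context.timesteps}

    def forward(self, model: nn.Module, inputs: Dict[str, Any]) -> torch.Tensor:
        return model(**inputs)


def run_step(adapter: ModelAdapter, model: nn.Module, context: Any) -> torch.Tensor:
    # Assumed call order used by the pipeline.
    inputs = adapter.prepare_inputs(context)
    pred = adapter.forward(model, inputs)
    return adapter.post_process_prediction(pred)


# Usage with a duck-typed context carrying only the fields ToyAdapter reads.
ctx = SimpleNamespace(noisy_latents=torch.randn(2, 10, 8), timesteps=torch.tensor([0.1, 0.9]))
pred = run_step(ToyAdapter(), ToyDenoiser(8), ctx)
```

Only the adapter knows ToyDenoiser's keyword names, which is the decoupling the class is meant to provide.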

nemo_automodel.components.flow_matching.adapters.base.ModelAdapter.forward(
model: torch.nn.Module,
inputs: typing.Dict[str, typing.Any]
) -> torch.Tensor
abstract

Execute the model forward pass.

Parameters:

model (nn.Module): The model to call.

inputs (Dict[str, Any]): Dictionary of inputs returned by prepare_inputs().

Returns: torch.Tensor

Model prediction tensor
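Because forward must return a tensor, it is a natural place to normalize a model's output convention. The sketch below assumes a hypothetical model that, like Hugging Face diffusers models called with return_dict=False, returns a tuple whose first element is the prediction. TupleOutputModel and adapter_forward are illustrative names; adapter_forward stands in for the body of a ModelAdapter.forward override.

```python
from typing import Any, Dict

import torch
import torch.nn as nn


class TupleOutputModel(nn.Module):
    """Hypothetical model returning a dict, or a (prediction, aux) tuple."""

    def forward(self, hidden_states: torch.Tensor, timestep: torch.Tensor, return_dict: bool = True):
        pred = hidden_states * (1.0 - timestep.view(-1, 1, 1))
        aux = pred.mean()
        if return_dict:
            return {"sample": pred, "aux": aux}
        return (pred, aux)


def adapter_forward(model: nn.Module, inputs: Dict[str, Any]) -> torch.Tensor:
    """Sketch of a forward body that always hands the pipeline a plain tensor."""
    out = model(**inputs, return_dict=False)
    # Unwrap tuple outputs; the first element is assumed to be the prediction.
    if isinstance(out, (tuple, list)):
        out = out[0]
    return out


inputs = {"hidden_states": torch.ones(2, 3, 4), "timestep": torch.tensor([0.0, 0.5])}
pred = adapter_forward(TupleOutputModel(), inputs)
```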

nemo_automodel.components.flow_matching.adapters.base.ModelAdapter.post_process_prediction(
model_pred: torch.Tensor
) -> torch.Tensor

Post-process the model prediction if needed.

Override this method for models that return extra outputs or whose raw output needs further transformation.

Parameters:

model_pred (torch.Tensor): Raw model output.

Returns: torch.Tensor

Processed prediction tensor
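One case where an override helps is a model that emits extra channels alongside its prediction. The sketch assumes channels on dim 1 and a hypothetical model that outputs 2 * latent_channels channels, of which only the first half is the prediction. SplitOutputPostProcess is illustrative; in real code the method would be defined on a ModelAdapter subclass.

```python
import torch


class SplitOutputPostProcess:
    """Hypothetical helper; in practice this method lives on a ModelAdapter subclass."""

    def __init__(self, latent_channels: int) -> None:
        self.latent_channels = latent_channels

    def post_process_prediction(self, model_pred: torch.Tensor) -> torch.Tensor:
        c = self.latent_channels
        # Assumed layout: channels on dim 1. If the model emitted 2 * c
        # channels, keep the first c and discard the extra output.
        if model_pred.shape[1] == 2 * c:
            model_pred, _extra = model_pred.split(c, dim=1)
        return model_pred


pp = SplitOutputPostProcess(latent_channels=4)
trimmed = pp.post_process_prediction(torch.randn(2, 8, 3, 5, 5))
```

Predictions that already have latent_channels channels pass through unchanged, so the same adapter works whether or not the extra output is present.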

nemo_automodel.components.flow_matching.adapters.base.ModelAdapter.prepare_inputs(
context: nemo_automodel.components.flow_matching.adapters.base.FlowMatchingContext
) -> typing.Dict[str, typing.Any]
abstract

Prepare model-specific inputs from the context.

Parameters:

context (FlowMatchingContext): Context containing all necessary data.

Returns: Dict[str, Any]

Dictionary of inputs to pass to the model’s forward method
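A sketch of a prepare_inputs body that reads from the context and applies classifier-free-guidance dropout using cfg_dropout_prob. This page does not say whether the library applies that dropout inside adapters, so treat it as an assumption. The "text_embeddings" batch key, the output keyword names, and zeroing the embeddings as the null condition (one common choice; some models use an empty-prompt embedding instead) are hypothetical.

```python
from types import SimpleNamespace
from typing import Any, Dict

import torch


def prepare_inputs(context: Any) -> Dict[str, Any]:
    """Sketch of prepare_inputs with per-sample CFG dropout (assumed behavior)."""
    # Hypothetical batch key holding text conditioning of shape (B, L, D).
    cond = context.batch["text_embeddings"].to(device=context.device, dtype=context.dtype)
    if context.cfg_dropout_prob > 0.0:
        # Drop each sample's conditioning independently with probability p,
        # so the model also learns an unconditional prediction.
        keep = torch.rand(cond.shape[0], device=cond.device) >= context.cfg_dropout_prob
        cond = cond * keep.view(-1, 1, 1).to(cond.dtype)
    return {
        "hidden_states": context.noisy_latents.to(device=context.device, dtype=context.dtype),
        "timestep": context.timesteps,
        "encoder_hidden_states": cond,
    }


# Duck-typed context with only the fields this sketch reads.
ctx = SimpleNamespace(
    noisy_latents=torch.randn(3, 10, 8),
    timesteps=torch.tensor([1.0, 2.0, 3.0]),
    device=torch.device("cpu"),
    dtype=torch.float32,
    batch={"text_embeddings": torch.ones(3, 5, 4)},
    cfg_dropout_prob=1.0,
)
dropped = prepare_inputs(ctx)
```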