nemo_automodel.components.flow_matching.adapters.base#
Base classes and data structures for model adapters.
This module defines the abstract ModelAdapter class and the FlowMatchingContext dataclass used to pass data between the pipeline and adapters.
Module Contents#
Classes#
FlowMatchingContext: Context object passed to model adapters containing all necessary data.

ModelAdapter: Abstract base class for model-specific forward pass logic.
API#
- class nemo_automodel.components.flow_matching.adapters.base.FlowMatchingContext#
Context object passed to model adapters containing all necessary data.
This provides a clean interface for adapters to access the data they need without coupling to the batch dictionary structure.
.. attribute:: noisy_latents
[B, C, F, H, W] or [B, C, H, W] - Noisy latents after interpolation
.. attribute:: latents
[B, C, F, H, W] for video or [B, C, H, W] for image - Original clean latents (also accessible via deprecated ‘video_latents’ property for backward compatibility)
.. attribute:: timesteps
[B] - Sampled timesteps
.. attribute:: sigma
[B] - Sigma values
.. attribute:: task_type
“t2v” or “i2v”
.. attribute:: data_type
“video” or “image”
.. attribute:: device
Device for tensor operations
.. attribute:: dtype
Data type for tensor operations
.. attribute:: cfg_dropout_prob
Probability of dropping text embeddings (setting to 0) during training for classifier-free guidance (CFG). Defaults to 0.0 for backward compatibility.
.. attribute:: batch
Original batch dictionary (for model-specific data)
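The cfg_dropout_prob attribute above can be illustrated with a minimal sketch of classifier-free-guidance dropout: with probability p, the text embedding is zeroed out so the model also learns an unconditional path. This is an illustrative stand-in (plain lists instead of tensors, names not from the library), not the pipeline's actual implementation.

```python
import random

# Hedged sketch of CFG dropout as described for cfg_dropout_prob:
# with probability p, replace the text embedding with zeros.
def apply_cfg_dropout(text_emb, p, rng=random.random):
    if rng() < p:
        # Drop conditioning -> model sees an "unconditional" embedding.
        return [0.0] * len(text_emb)
    return text_emb

# p=0.0 (the default) never drops, preserving pre-CFG behavior.
unchanged = apply_cfg_dropout([1.0, 2.0], p=0.0)
# p=1.0 always drops.
dropped = apply_cfg_dropout([1.0, 2.0], p=1.0)
```

Because `random.random()` returns values in `[0, 1)`, `p=0.0` can never trigger the drop and `p=1.0` always does, which is why the default of 0.0 is backward compatible.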
- noisy_latents: torch.Tensor#
None
- latents: torch.Tensor#
None
- timesteps: torch.Tensor#
None
- sigma: torch.Tensor#
None
- task_type: str#
None
- data_type: str#
None
- device: torch.device#
None
- dtype: torch.dtype#
None
- batch: Dict[str, Any]#
None
- cfg_dropout_prob: float#
0.0
- property video_latents: torch.Tensor#
Backward compatibility alias for ‘latents’ field.
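To make the field layout and the deprecated video_latents alias concrete, here is a minimal stand-in dataclass mirroring the documented fields. Tensors are replaced with plain lists and device/dtype with strings so the sketch runs without torch; it is not the real class.

```python
from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class FlowMatchingContextSketch:
    # Field names mirror the documented FlowMatchingContext attributes.
    noisy_latents: Any
    latents: Any
    timesteps: Any
    sigma: Any
    task_type: str
    data_type: str
    device: str
    dtype: str
    batch: Dict[str, Any] = field(default_factory=dict)
    cfg_dropout_prob: float = 0.0  # defaults to 0.0 for backward compatibility

    @property
    def video_latents(self):
        # Deprecated alias kept for backward compatibility.
        return self.latents

ctx = FlowMatchingContextSketch(
    noisy_latents=[[0.1]], latents=[[0.0]], timesteps=[500],
    sigma=[0.5], task_type="t2v", data_type="video",
    device="cpu", dtype="float32", batch={"prompt_embeds": [[1.0]]},
)
```

Adapters read `ctx.latents` directly; `ctx.video_latents` resolves to the same object, so old call sites keep working while new code migrates.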
- class nemo_automodel.components.flow_matching.adapters.base.ModelAdapter#
Bases: abc.ABC

Abstract base class for model-specific forward pass logic.
Implement this class to add support for new model architectures without modifying the FlowMatchingPipeline.
The adapter pattern decouples the flow matching logic from model-specific details like input preparation and forward pass conventions.
.. rubric:: Example

```python
class MyCustomAdapter(ModelAdapter):
    def prepare_inputs(self, context: FlowMatchingContext) -> Dict[str, Any]:
        return {
            "x": context.noisy_latents,
            "t": context.timesteps,
            "cond": context.batch["my_conditioning"],
        }

    def forward(self, model: nn.Module, inputs: Dict[str, Any]) -> torch.Tensor:
        return model(**inputs)

pipeline = FlowMatchingPipelineV2(model_adapter=MyCustomAdapter())
```
- abstractmethod prepare_inputs(context: FlowMatchingContext) → Dict[str, Any]#
Prepare model-specific inputs from the context.
- Parameters:
context – FlowMatchingContext containing all necessary data
- Returns:
Dictionary of inputs to pass to the model’s forward method
- abstractmethod forward(model: torch.nn.Module, inputs: Dict[str, Any]) → torch.Tensor#
Execute the model forward pass.
- Parameters:
model – The model to call
inputs – Dictionary of inputs from prepare_inputs()
- Returns:
Model prediction tensor
- post_process_prediction(model_pred: torch.Tensor) → torch.Tensor#
Post-process model prediction if needed.
Override this for models that return extra outputs or need transformation.
- Parameters:
model_pred – Raw model output
- Returns:
Processed prediction tensor
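A common reason to override post_process_prediction, per the docstring, is a model that returns extra outputs. The sketch below shows one such override for a hypothetical model that returns a (prediction, aux) tuple; the class and behavior are illustrative assumptions, not part of the library.

```python
# Hedged sketch: an adapter for a hypothetical model that returns
# (prediction, aux_output); the override keeps only the prediction.
class TupleOutputAdapterSketch:
    def post_process_prediction(self, model_pred):
        if isinstance(model_pred, tuple):
            # Discard auxiliary outputs; the pipeline only needs the prediction.
            return model_pred[0]
        # Default behavior: pass the prediction through unchanged.
        return model_pred

adapter = TupleOutputAdapterSketch()
```

Models that already return a bare tensor need no override, since the base implementation returns `model_pred` unchanged.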