nemo_automodel.components.datasets.multimodal.collate_fns


Multimodal collate functions.

BAGEL uses packed sequences: samples are concatenated along the sequence axis and delimited by a cumulative-seqlens index, rather than being left- or right-padded. The collate function is therefore essentially a pass-through. It wraps the single packed dict produced by PackedDataset in a SimpleCustomBatch, which provides pin_memory and cuda helpers.
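To make the packing idea concrete, here is a minimal pure-Python sketch of concatenating variable-length samples and recording their boundaries. It is an illustration only, not the PackedDataset implementation; the helper names `pack_samples` and `unpack_samples` and the variable `cu_seqlens` are hypothetical and chosen for this example.

```python
from itertools import accumulate


def pack_samples(samples):
    """Concatenate variable-length token lists and record their boundaries.

    Hypothetical helper for illustration; not part of nemo_automodel.
    """
    packed = [tok for sample in samples for tok in sample]
    sample_lens = [len(sample) for sample in samples]
    # cu_seqlens[i] is the offset where sample i starts;
    # the final entry equals the total packed length.
    cu_seqlens = [0, *accumulate(sample_lens)]
    return packed, sample_lens, cu_seqlens


def unpack_samples(packed, cu_seqlens):
    """Recover the original samples by slicing at the recorded boundaries."""
    return [packed[start:end] for start, end in zip(cu_seqlens, cu_seqlens[1:])]


samples = [[11, 12, 13], [21], [31, 32]]
packed, sample_lens, cu_seqlens = pack_samples(samples)
restored = unpack_samples(packed, cu_seqlens)
```

Because the boundaries travel with the data, no padding tokens are needed and the packed sequence can be split back into its samples at any point.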

Module Contents

Classes

| Name | Description |
| --- | --- |
| SimpleCustomBatch | Pass-through wrapper around one packed batch from PackedDataset. |

Functions

| Name | Description |
| --- | --- |
| bagel_packed_collate_fn | Canonical name in AM’s collate-fn registry. |
| collate_wrapper | Return the BAGEL-style identity collate (wraps a single packed dict). |

API

class nemo_automodel.components.datasets.multimodal.collate_fns.SimpleCustomBatch(
batch
)

Pass-through wrapper around one packed batch from PackedDataset.

attn_modes = data['attn_modes']
batch_data_indexes = data['batch_data_indexes']
ce_loss_indexes = data['ce_loss_indexes']
ce_loss_weights = data['ce_loss_weights']
mse_loss_indexes = data['mse_loss_indexes']
nested_attention_masks = data['nested_attention_masks']
packed_label_ids = data['packed_label_ids']
packed_latent_position_ids = data['packed_latent_position_ids']
packed_position_ids = data['packed_position_ids']
packed_text_ids = data['packed_text_ids']
packed_text_indexes = data['packed_text_indexes']
packed_timesteps = data['packed_timesteps']
packed_vae_token_indexes = data['packed_vae_token_indexes']
packed_vit_position_ids = data['packed_vit_position_ids']
packed_vit_token_indexes = data['packed_vit_token_indexes']
packed_vit_tokens = data['packed_vit_tokens']
padded_images = data['padded_images']
patchified_vae_latent_shapes = data['patchified_vae_latent_shapes']
sample_lens = data['sample_lens']
sequence_length = data['sequence_length']
split_lens = data['split_lens']
use_flex = 'nested_attention_masks' not in data.keys()
vit_token_seqlens = data['vit_token_seqlens']
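
The attribute listing suggests that each field is copied from the packed dict and that use_flex is derived purely from whether the 'nested_attention_masks' key is absent. The sketch below mimics that pattern with a small stand-in class. It is not the library's code: the class name `PackedBatchSketch` is hypothetical, only a few fields are shown, and two points are assumptions not confirmed by this page: that the constructor receives a one-element list holding the packed dict, and that 'nested_attention_masks' is an optional key.

```python
class PackedBatchSketch:
    """Hypothetical stand-in for SimpleCustomBatch (illustration only)."""

    def __init__(self, batch):
        # Assumption: `batch` is a one-element list holding the packed dict.
        data = batch[0]
        self.sequence_length = data["sequence_length"]
        self.sample_lens = data["sample_lens"]
        self.packed_text_ids = data["packed_text_ids"]
        # Mirrors the listing: use_flex is True exactly when the dict
        # carries no 'nested_attention_masks' key.
        self.use_flex = "nested_attention_masks" not in data
        if not self.use_flex:
            # Assumption: the masks are only read when they are present.
            self.nested_attention_masks = data["nested_attention_masks"]

    def to_dict(self):
        """Return a shallow copy of the stored fields as a plain dict."""
        return dict(vars(self))


base = {"sequence_length": 6, "sample_lens": [3, 1, 2],
        "packed_text_ids": [11, 12, 13, 21, 31, 32]}
flex_batch = PackedBatchSketch([base])
dense_batch = PackedBatchSketch([{**base, "nested_attention_masks": [[1]]}])
```

Here `flex_batch.use_flex` is True because no masks were supplied, while `dense_batch.use_flex` is False and the masks are kept on the object.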

nemo_automodel.components.datasets.multimodal.collate_fns.SimpleCustomBatch.cuda(
device
)

nemo_automodel.components.datasets.multimodal.collate_fns.SimpleCustomBatch.pin_memory()

nemo_automodel.components.datasets.multimodal.collate_fns.SimpleCustomBatch.to_dict()

nemo_automodel.components.datasets.multimodal.collate_fns.bagel_packed_collate_fn(
batch
)

Canonical name in AM’s collate-fn registry.

nemo_automodel.components.datasets.multimodal.collate_fns.collate_wrapper()

Return the BAGEL-style identity collate (wraps a single packed dict).
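
An identity collate of this kind can be sketched as a factory that returns a function wrapping its input without stacking or padding. The code below is a hedged illustration, not the module's implementation: `make_identity_collate` and `DictBatch` are hypothetical names, and unlike collate_wrapper (which takes no arguments) the factory here accepts the wrapper class as a parameter purely to keep the example self-contained. It also assumes the loader hands the collate function a one-element list, which would typically mean a batch size of 1 on the data loader.

```python
def make_identity_collate(wrapper_cls):
    """Return a collate function that wraps the batch instead of stacking it.

    Hypothetical factory for illustration; not part of nemo_automodel.
    """

    def collate_fn(batch):
        # No padding or stacking: the dataset already produced a packed dict.
        return wrapper_cls(batch)

    return collate_fn


class DictBatch:
    """Hypothetical minimal wrapper around a single packed dict."""

    def __init__(self, batch):
        # Assumption: exactly one packed dict arrives per step.
        if len(batch) != 1:
            raise ValueError(f"expected one packed dict, got {len(batch)}")
        self.data = batch[0]


collate = make_identity_collate(DictBatch)
wrapped = collate([{"sequence_length": 4, "sample_lens": [1, 3]}])
```

The resulting `collate` callable could then be passed wherever a collate function is expected; the length check makes a misconfigured batch size fail loudly instead of silently dropping samples.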