bridge.diffusion.models.common.dgpt_step

DGPTStep: forward step for sbd_block_diff diffusion language model training.

Module Contents

Classes

DGPTStep

Forward training step for sbd_block_diff diffusion LM.

Functions

get_batch_from_iterator

Get a batch of data from the iterator.

get_batch

Generate a batch.

_create_loss_function

_create_loss_function_sbd

_masked_loss_sbd_block_diff

Combined DLM + AR loss for sbd_block_diff training.

Data

logger

API

bridge.diffusion.models.common.dgpt_step.logger

‘getLogger(…)’

bridge.diffusion.models.common.dgpt_step.get_batch_from_iterator(
data_iterator: Iterable,
use_mtp: bool = False,
skip_getting_attention_mask_from_dataset: bool = True,
) -> dict[str, torch.Tensor]

Get a batch of data from the iterator.
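The reference gives only a one-line summary, so the following is a minimal sketch of what fetching one batch from an iterator could look like, under stated assumptions. The dataset is assumed to yield dicts keyed by "tokens", "labels", "loss_mask", "attention_mask" and "position_ids" (hypothetical key names), and skip_getting_attention_mask_from_dataset is assumed to mean that the dataset's mask is ignored and None is returned in its place. The use_mtp flag is not modeled. Plain Python lists stand in for torch.Tensor so the sketch runs without PyTorch, and get_batch_from_iterator_sketch is a hypothetical name.

```python
from typing import Any, Dict, Iterable, List, Optional


def get_batch_from_iterator_sketch(
    data_iterator: Iterable,
    skip_getting_attention_mask_from_dataset: bool = True,
) -> Dict[str, Optional[Any]]:
    """Hypothetical sketch: pull one batch dict from an iterator.

    Key names are illustrative assumptions, not confirmed by the API docs.
    """
    # iter() returns an iterator unchanged; for a plain iterable it starts
    # a fresh iterator, so the first item is returned.
    raw = next(iter(data_iterator))
    keys: List[str] = ["tokens", "labels", "loss_mask", "position_ids"]
    batch = {k: raw.get(k) for k in keys}
    # Assumption: when skipping, the mask is built elsewhere (for example a
    # causal mask inside the model), so None is returned here.
    if skip_getting_attention_mask_from_dataset:
        batch["attention_mask"] = None
    else:
        batch["attention_mask"] = raw.get("attention_mask")
    return batch
```

In this sketch, passing skip_getting_attention_mask_from_dataset=False returns the dataset's mask unchanged.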

bridge.diffusion.models.common.dgpt_step.get_batch(
data_iterator: Iterable,
cfg: megatron.bridge.training.config.ConfigContainer,
use_mtp: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

Generate a batch.

class bridge.diffusion.models.common.dgpt_step.DGPTStep(seed: int = 1234)

Forward training step for sbd_block_diff diffusion LM.


__call__(
state: megatron.bridge.training.state.GlobalState,
data_iterator: Iterable,
model: megatron.core.models.gpt.GPTModel,
return_schedule_plan: bool = False,
) -> tuple[torch.Tensor, functools.partial]

_apply_noise(tokens, labels, loss_mask, attention_mask, position_ids)

Apply simple uniform masking and concatenate [noisy | clean] for sbd_block_diff.
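_apply_noise is documented only as uniform masking followed by a [noisy | clean] concatenation. One common reading of "uniform masking" in masked diffusion language models is to draw a masking rate t uniformly per sequence, then replace each token with a mask token independently with probability t. The sketch below follows that reading. MASK_ID, the eps lower bound, and repeating position ids 0..L-1 in both halves are assumptions; handling of labels, loss_mask and attention_mask is omitted; and random.Random(1234) only imitates how the seed argument of DGPTStep might make the noise reproducible.

```python
import random
from typing import List, Tuple

MASK_ID = -1  # hypothetical mask-token id; the real id is tokenizer-specific


def apply_uniform_noise_sketch(
    tokens: List[int],
    rng: random.Random,
    eps: float = 1e-3,
) -> Tuple[List[int], List[int], List[bool]]:
    """Mask with a uniformly drawn rate, then build [noisy | clean].

    Returns (concat_tokens, position_ids, masked): concat_tokens has length
    2 * len(tokens), and masked flags the noisy-half positions that were
    replaced by MASK_ID.
    """
    # One masking rate per sequence, drawn from U(eps, 1).
    t = rng.uniform(eps, 1.0)
    # Each token is masked independently with probability t.
    masked = [rng.random() < t for _ in tokens]
    noisy = [MASK_ID if m else tok for tok, m in zip(tokens, masked)]
    # Concatenate along the sequence dimension: noisy copy first, clean second.
    concat_tokens = noisy + list(tokens)
    # Assumption: both halves reuse the clean positions 0..L-1.
    position_ids = list(range(len(tokens))) * 2
    return concat_tokens, position_ids, masked


rng = random.Random(1234)  # illustrative only; mirrors the default seed value
out, pos, masked = apply_uniform_noise_sketch([5, 6, 7, 8], rng)
```

Concatenating rather than replacing keeps the clean sequence available in the same forward pass, which is what a combined DLM + AR objective needs.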

bridge.diffusion.models.common.dgpt_step._create_loss_function(
loss_mask,
check_for_nan_in_loss,
check_for_spiky_loss,
)

bridge.diffusion.models.common.dgpt_step._create_loss_function_sbd(
loss_mask,
check_for_nan_in_loss,
check_for_spiky_loss,
dlm_loss_weight=1.0,
ar_loss_weight=1.0,
)
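The return annotation of __call__, tuple[torch.Tensor, functools.partial], matches the Megatron-Core forward-step convention as we understand it: the step returns the model output together with a loss function whose batch-specific arguments are already bound, and the pipeline schedule later calls that function with the output tensor. A minimal sketch of how _create_loss_function_sbd might build such a partial follows. sbd_loss_stub is a hypothetical stand-in for _masked_loss_sbd_block_diff that simply weights two pre-reduced losses, and plain floats replace tensors.

```python
import functools
from typing import List, Tuple


def sbd_loss_stub(
    loss_mask: List[int],
    output_tensor: Tuple[float, float],
    check_for_nan_in_loss: bool = True,
    check_for_spiky_loss: bool = False,
    dlm_loss_weight: float = 1.0,
    ar_loss_weight: float = 1.0,
) -> float:
    """Hypothetical stand-in: output_tensor is (dlm_loss, ar_loss)."""
    dlm_loss, ar_loss = output_tensor
    return dlm_loss_weight * dlm_loss + ar_loss_weight * ar_loss


def create_loss_function_sbd_sketch(
    loss_mask: List[int],
    check_for_nan_in_loss: bool,
    check_for_spiky_loss: bool,
    dlm_loss_weight: float = 1.0,
    ar_loss_weight: float = 1.0,
) -> functools.partial:
    # Bind everything except output_tensor, which only exists after the
    # model forward pass.
    return functools.partial(
        sbd_loss_stub,
        loss_mask,
        check_for_nan_in_loss=check_for_nan_in_loss,
        check_for_spiky_loss=check_for_spiky_loss,
        dlm_loss_weight=dlm_loss_weight,
        ar_loss_weight=ar_loss_weight,
    )


# The forward step returns (output, loss_fn); the schedule then calls loss_fn(output).
loss_fn = create_loss_function_sbd_sketch([1, 1, 0], True, False, dlm_loss_weight=0.5)
print(loss_fn((2.0, 3.0)))  # 0.5 * 2.0 + 1.0 * 3.0 = 4.0
```

Binding loss_mask positionally keeps the argument order of _masked_loss_sbd_block_diff (loss_mask first, output_tensor second), so the schedule only has to supply the output.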
bridge.diffusion.models.common.dgpt_step._masked_loss_sbd_block_diff(
loss_mask: torch.Tensor,
output_tensor: Tuple[torch.Tensor, ...],
check_for_nan_in_loss: bool = True,
check_for_spiky_loss: bool = False,
dlm_loss_weight: float = 1.0,
ar_loss_weight: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor, dict[str, tuple[torch.Tensor, torch.Tensor]]]

Combined DLM + AR loss for sbd_block_diff training.
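The summary states only that a DLM term and an AR term are combined, weighted by dlm_loss_weight and ar_loss_weight. A minimal sketch under assumptions: the per-token losses are already split into a DLM part (from the noisy half of the [noisy | clean] sequence) and an AR part (from the clean half), each is reduced to a masked mean, and the two means are weighted and summed. The return shape mirrors the annotated (loss, num_tokens, {name: (value, count)}), but the reporting keys "dlm loss" and "ar loss" and this particular normalization are assumptions; the actual function may normalize differently, and check_for_spiky_loss is not modeled.

```python
import math
from typing import Dict, List, Tuple


def masked_sum(per_token_loss: List[float], mask: List[int]) -> Tuple[float, int]:
    """Sum of losses where mask == 1, plus the number of such positions."""
    total = sum(loss * m for loss, m in zip(per_token_loss, mask))
    return float(total), sum(mask)


def combined_dlm_ar_loss_sketch(
    dlm_per_token: List[float],
    dlm_mask: List[int],
    ar_per_token: List[float],
    ar_mask: List[int],
    dlm_loss_weight: float = 1.0,
    ar_loss_weight: float = 1.0,
    check_for_nan_in_loss: bool = True,
) -> Tuple[float, int, Dict[str, Tuple[float, int]]]:
    dlm_sum, dlm_count = masked_sum(dlm_per_token, dlm_mask)
    ar_sum, ar_count = masked_sum(ar_per_token, ar_mask)
    # Masked means; max(..., 1) avoids division by zero for an empty mask.
    dlm_mean = dlm_sum / max(dlm_count, 1)
    ar_mean = ar_sum / max(ar_count, 1)
    loss = dlm_loss_weight * dlm_mean + ar_loss_weight * ar_mean
    if check_for_nan_in_loss and math.isnan(loss):
        raise ValueError("NaN detected in combined DLM + AR loss")
    num_tokens = dlm_count + ar_count
    # Per-term (sum, count) pairs so a caller can average across batches.
    report = {"dlm loss": (dlm_sum, dlm_count), "ar loss": (ar_sum, ar_count)}
    return loss, num_tokens, report
```

For example, DLM losses [1.0, 3.0, 5.0] under mask [1, 1, 0] and AR losses [2.0, 4.0] under mask [1, 1] with ar_loss_weight=0.5 give 1.0 * 2.0 + 0.5 * 3.0 = 3.5 over 4 tokens.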