bridge.diffusion.models.common.dgpt_step
DGPTStep: forward step for sbd_block_diff diffusion language model training.
Module Contents
Classes
DGPTStep | Forward training step for sbd_block_diff diffusion LM.
Functions
get_batch_from_iterator | Get a batch of data from the iterator.
get_batch | Generate a batch.
_masked_loss_sbd_block_diff | Combined DLM + AR loss for sbd_block_diff training.
Data
API
- bridge.diffusion.models.common.dgpt_step.logger
‘getLogger(…)’
- bridge.diffusion.models.common.dgpt_step.get_batch_from_iterator(data_iterator: Iterable, use_mtp: bool = False, skip_getting_attention_mask_from_dataset: bool = True)
Get a batch of data from the iterator.
- bridge.diffusion.models.common.dgpt_step.get_batch(data_iterator: Iterable, cfg: megatron.bridge.training.config.ConfigContainer, use_mtp: bool = False)
Generate a batch.
- class bridge.diffusion.models.common.dgpt_step.DGPTStep(seed: int = 1234)
Forward training step for sbd_block_diff diffusion LM.
Initialization
- __call__(state: megatron.bridge.training.state.GlobalState, data_iterator: Iterable, model: megatron.core.models.gpt.GPTModel, return_schedule_plan: bool = False)
- _apply_noise(tokens, labels, loss_mask, attention_mask, position_ids)
Apply simple uniform masking and concatenate [noisy | clean] for sbd_block_diff.
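The masking-and-concatenation idea described above can be illustrated with a small, framework-free sketch. This is not the library's implementation: the helper name `apply_uniform_mask`, the `mask_token_id` argument, and the choice of a single masking probability drawn uniformly per sequence are assumptions made for illustration only. The real method works on batched tensors and also receives labels, loss_mask, attention_mask and position_ids, which are left out here.

```python
import random
from typing import List, Tuple


def apply_uniform_mask(
    tokens: List[int],
    mask_token_id: int,
    rng: random.Random,
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: mask tokens at a uniformly sampled rate.

    Returns (model_input, masked_flags), where model_input is the
    concatenation [noisy | clean] and masked_flags marks which noisy
    positions were replaced by mask_token_id.
    """
    # Assumption: one masking probability t ~ U(0, 1) per sequence.
    t = rng.random()
    masked_flags = [1 if rng.random() < t else 0 for _ in tokens]
    # Replace the selected positions with the mask token.
    noisy = [mask_token_id if m else tok for tok, m in zip(tokens, masked_flags)]
    # Concatenate the noisy copy with the untouched clean copy.
    return noisy + list(tokens), masked_flags


# 1234 mirrors the default seed shown in the DGPTStep signature.
rng = random.Random(1234)
model_input, masked_flags = apply_uniform_mask([11, 12, 13, 14], mask_token_id=0, rng=rng)
```

In this sketch the second half of `model_input` is always the original sequence, so the clean tokens remain available alongside the noisy ones.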
- bridge.diffusion.models.common.dgpt_step._create_loss_function(loss_mask, check_for_nan_in_loss, check_for_spiky_loss)
- bridge.diffusion.models.common.dgpt_step._create_loss_function_sbd(loss_mask, check_for_nan_in_loss, check_for_spiky_loss, dlm_loss_weight=1.0, ar_loss_weight=1.0)
- bridge.diffusion.models.common.dgpt_step._masked_loss_sbd_block_diff(loss_mask: torch.Tensor, output_tensor: Tuple[torch.Tensor, ...], check_for_nan_in_loss: bool = True, check_for_spiky_loss: bool = False, dlm_loss_weight: float = 1.0, ar_loss_weight: float = 1.0)
Combined DLM + AR loss for sbd_block_diff training.
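One plausible reading of "combined DLM + AR loss" is a weighted sum of two masked mean losses, one per objective. The sketch below assumes that form. The helper names `masked_mean` and `combined_loss`, the plain-list inputs, and the reuse of a single mask for both terms are illustrative assumptions, not confirmed details of `_masked_loss_sbd_block_diff`. Only the weight names `dlm_loss_weight` and `ar_loss_weight` and their defaults of 1.0 come from the signature above.

```python
from typing import Sequence


def masked_mean(per_token_loss: Sequence[float], loss_mask: Sequence[float]) -> float:
    """Average per-token losses over positions where loss_mask is nonzero."""
    total = sum(loss * m for loss, m in zip(per_token_loss, loss_mask))
    count = sum(loss_mask)
    # Guard against an all-zero mask to avoid dividing by zero.
    return total / count if count > 0 else 0.0


def combined_loss(
    dlm_losses: Sequence[float],
    ar_losses: Sequence[float],
    loss_mask: Sequence[float],
    dlm_loss_weight: float = 1.0,
    ar_loss_weight: float = 1.0,
) -> float:
    """Hypothetical weighted sum of a diffusion (DLM) and an autoregressive (AR) term."""
    dlm = masked_mean(dlm_losses, loss_mask)
    ar = masked_mean(ar_losses, loss_mask)
    return dlm_loss_weight * dlm + ar_loss_weight * ar


# The third position is masked out, so it contributes to neither term.
loss = combined_loss([2.0, 4.0, 6.0], [1.0, 3.0, 5.0], loss_mask=[1, 1, 0])
weighted = combined_loss(
    [2.0, 4.0, 6.0], [1.0, 3.0, 5.0], loss_mask=[1, 1, 0],
    dlm_loss_weight=0.5, ar_loss_weight=2.0,
)
```

With the default weights the two terms contribute equally. Changing `dlm_loss_weight` or `ar_loss_weight` shifts the balance between the diffusion and autoregressive objectives under this assumed form.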