Development › API Reference › Full Library Reference › Nemo Automodel › Nemo Automodel › Components › Speculative › Dspark
nemo_automodel.components.speculative.dspark.loss
nemo_automodel.components.speculative.dspark.loss
Module Contents
Functions
Data
API
nemo_automodel.components.speculative.dspark.loss._all_reduce_loss_denominators( loss_terms: dict[str, torch.Tensor], world_size: int ) -> dict[str, torch.Tensor]
nemo_automodel.components.speculative.dspark.loss._build_loss( loss_terms: dict[str, torch.Tensor], global_denominators: dict[str, torch.Tensor], ce_loss_alpha: float, l1_loss_alpha: float, confidence_head_alpha: float, has_confidence: bool, world_size: int ) -> torch.Tensor
nemo_automodel.components.speculative.dspark.loss._build_loss_weight_mask( eval_mask: torch.Tensor, block_size: int, device: torch.device, loss_decay_gamma: typing.Optional[float] ) -> torch.Tensor
nemo_automodel.components.speculative.dspark.loss._collect_local_terms( outputs: nemo_automodel.components.speculative.dspark.common.DSparkForwardOutput, loss_decay_gamma: typing.Optional[float], l1_loss_alpha: float ) -> tuple[dict[str, torch.Tensor], bool]
nemo_automodel.components.speculative.dspark.loss._compute_accept_rate_3d( outputs: nemo_automodel.components.speculative.dspark.common.DSparkForwardOutput, aligned_target_logits: typing.Optional[torch.Tensor] ) -> typing.Optional[torch.Tensor]
nemo_automodel.components.speculative.dspark.loss._compute_local_l1_term( outputs: nemo_automodel.components.speculative.dspark.common.DSparkForwardOutput, aligned_target_logits: typing.Optional[torch.Tensor], loss_weight_mask: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]
nemo_automodel.components.speculative.dspark.loss.compute_dspark_loss( outputs: nemo_automodel.components.speculative.dspark.common.DSparkForwardOutput, loss_decay_gamma: typing.Optional[float], ce_loss_alpha: float, l1_loss_alpha: float, confidence_head_alpha: float, return_terms: bool = False )
nemo_automodel.components.speculative.dspark.loss.__all__ = ['compute_dspark_loss']