nemo_automodel.components.speculative.dspark.loss

View as Markdown

Module Contents

Functions

Data

__all__

API

nemo_automodel.components.speculative.dspark.loss._all_reduce_loss_denominators(
    loss_terms: dict[str, torch.Tensor],
    world_size: int
) -> dict[str, torch.Tensor]
nemo_automodel.components.speculative.dspark.loss._build_loss(
    loss_terms: dict[str, torch.Tensor],
    global_denominators: dict[str, torch.Tensor],
    ce_loss_alpha: float,
    l1_loss_alpha: float,
    confidence_head_alpha: float,
    has_confidence: bool,
    world_size: int
) -> torch.Tensor
nemo_automodel.components.speculative.dspark.loss._build_loss_weight_mask(
    eval_mask: torch.Tensor,
    block_size: int,
    device: torch.device,
    loss_decay_gamma: typing.Optional[float]
) -> torch.Tensor
nemo_automodel.components.speculative.dspark.loss._collect_local_terms(
    outputs: nemo_automodel.components.speculative.dspark.common.DSparkForwardOutput,
    loss_decay_gamma: typing.Optional[float],
    l1_loss_alpha: float
) -> tuple[dict[str, torch.Tensor], bool]
nemo_automodel.components.speculative.dspark.loss._compute_accept_rate_3d(
    outputs: nemo_automodel.components.speculative.dspark.common.DSparkForwardOutput,
    aligned_target_logits: typing.Optional[torch.Tensor]
) -> typing.Optional[torch.Tensor]
nemo_automodel.components.speculative.dspark.loss._compute_local_l1_term(
    outputs: nemo_automodel.components.speculative.dspark.common.DSparkForwardOutput,
    aligned_target_logits: typing.Optional[torch.Tensor],
    loss_weight_mask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]
nemo_automodel.components.speculative.dspark.loss.compute_dspark_loss(
    outputs: nemo_automodel.components.speculative.dspark.common.DSparkForwardOutput,
    loss_decay_gamma: typing.Optional[float],
    ce_loss_alpha: float,
    l1_loss_alpha: float,
    confidence_head_alpha: float,
    return_terms: bool = False
)
nemo_automodel.components.speculative.dspark.loss.__all__ = ['compute_dspark_loss']