
Lightning basic

This is intended to be a minimal, self-contained NeMo2 example.

ClassifierLossReduction

Bases: MegatronLossReduction

A class used for calculating the loss, and for logging the reduced loss across micro batches.

Source code in bionemo/example_model/lightning_basic.py
class ClassifierLossReduction(MegatronLossReduction):
    """A class used for calculating the loss, and for logging the reduced loss across micro batches."""

    def forward(self, batch: MnistItem, forward_out: Tensor) -> Tuple[Tensor, SameSizeLossDict]:
        """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

        Args:
            batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
            forward_out: the output of the forward method inside LitAutoEncoder.

        Returns:
            A tuple containing [<loss_tensor>, ReductionT] where the loss tensor will be used for
                backpropagation and the ReductionT will be passed to the reduce method
                (which currently only works for logging.).
        """
        digits = batch["label"]
        digit_logits = forward_out
        loss = nn.functional.cross_entropy(digit_logits, digits)
        return loss, {"avg": loss}

    def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
        """Works across micro-batches. (data on single gpu).

        Note: This currently only works for logging and this loss will not be used for backpropagation.

        Args:
            losses_reduced_per_micro_batch: a list of the outputs of forward

        Returns:
            A tensor that is the mean of the losses. (used for logging).
        """
        mse_losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
        return mse_losses.mean()

forward(batch, forward_out)

Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

Parameters:

- batch (MnistItem, required): A batch of data that gets passed to the original forward inside LitAutoEncoder.
- forward_out (Tensor, required): The output of the forward method inside LitAutoEncoder.

Returns:

- Tuple[Tensor, SameSizeLossDict]: A tuple containing [<loss_tensor>, ReductionT], where the loss tensor will be used for backpropagation and the ReductionT will be passed to the reduce method (which currently only works for logging).

Source code in bionemo/example_model/lightning_basic.py
def forward(self, batch: MnistItem, forward_out: Tensor) -> Tuple[Tensor, SameSizeLossDict]:
    """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

    Args:
        batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
        forward_out: the output of the forward method inside LitAutoEncoder.

    Returns:
        A tuple containing [<loss_tensor>, ReductionT] where the loss tensor will be used for
            backpropagation and the ReductionT will be passed to the reduce method
            (which currently only works for logging.).
    """
    digits = batch["label"]
    digit_logits = forward_out
    loss = nn.functional.cross_entropy(digit_logits, digits)
    return loss, {"avg": loss}

reduce(losses_reduced_per_micro_batch)

Works across micro-batches (data on a single GPU).

Note: This currently only works for logging and this loss will not be used for backpropagation.

Parameters:

- losses_reduced_per_micro_batch (Sequence[SameSizeLossDict], required): A list of the outputs of forward.

Returns:

- Tensor: A tensor that is the mean of the losses (used for logging).

Source code in bionemo/example_model/lightning_basic.py
def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
    """Works across micro-batches. (data on single gpu).

    Note: This currently only works for logging and this loss will not be used for backpropagation.

    Args:
        losses_reduced_per_micro_batch: a list of the outputs of forward

    Returns:
        A tensor that is the mean of the losses. (used for logging).
    """
    mse_losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
    return mse_losses.mean()
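
As a rough sketch (not part of the source), the loss reduction can be exercised directly on a fake micro-batch. The import path follows the "Source code in bionemo/example_model/lightning_basic.py" note above, and it is assumed that the base MegatronLossReduction needs no constructor arguments:

import torch

from bionemo.example_model.lightning_basic import ClassifierLossReduction

loss_fn = ClassifierLossReduction()              # assumed: no required constructor args
batch = {
    "data": torch.rand(8, 1, 28, 28),            # unused by this particular loss
    "label": torch.randint(0, 10, (8,)),         # digit labels
    "idx": torch.arange(8),                      # indices, unused here
}
digit_logits = torch.randn(8, 10)                # stands in for the model's forward output
loss, per_microbatch = loss_fn.forward(batch, digit_logits)
logged = loss_fn.reduce([per_microbatch])        # mean across micro-batches, used for logging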

ExampleConfig dataclass

Bases: ExampleGenericConfig['ExampleModel', 'MSELossReduction'], IOMixinWithGettersSetters

ExampleConfig is a dataclass that is used to configure the model.

Timers from ModelParallelConfig are required for megatron forward compatibility.

Source code in bionemo/example_model/lightning_basic.py
@dataclass
class ExampleConfig(ExampleGenericConfig["ExampleModel", "MSELossReduction"], iom.IOMixinWithGettersSetters):
    """ExampleConfig is a dataclass that is used to configure the model.

    Timers from ModelParallelConfig are required for megatron forward compatibility.
    """

    model_cls: Type[ExampleModel] = ExampleModel
    loss_cls: Type[MSELossReduction] = MSELossReduction
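
A minimal sketch (not from the source) of how this config is consumed. It assumes the dataclass can be built from its defaults; in the training scripts this wiring normally happens through LitAutoEncoder and the Megatron strategy rather than by hand:

from bionemo.example_model.lightning_basic import ExampleConfig, ExampleModel, MSELossReduction

config = ExampleConfig()                          # assumed: default field values are sufficient
model = config.configure_model()                  # builds ExampleModel(config); would also load initial_ckpt_path if set
assert isinstance(model, ExampleModel)
assert config.get_loss_reduction_class() is MSELossReduction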

ExampleFineTuneBothConfig dataclass

Bases: ExampleGenericConfig['ExampleFineTuneBothModel', 'MSEPlusClassifierLossReduction'], IOMixinWithGettersSetters

ExampleFineTuneBothConfig is a dataclass that is used to configure the model.

Timers from ModelParallelConfig are required for megatron forward compatibility.

Source code in bionemo/example_model/lightning_basic.py
@dataclass
class ExampleFineTuneBothConfig(
    ExampleGenericConfig["ExampleFineTuneBothModel", "MSEPlusClassifierLossReduction"], iom.IOMixinWithGettersSetters
):
    """ExampleConfig is a dataclass that is used to configure the model.

    Timers from ModelParallelConfig are required for megatron forward compatibility.
    """

    model_cls: Type[ExampleFineTuneBothModel] = ExampleFineTuneBothModel
    loss_cls: Type[MSEPlusClassifierLossReduction] = MSEPlusClassifierLossReduction

ExampleFineTuneBothModel

Bases: ExampleModel

Example of taking the example model and adding an output task.

Source code in bionemo/example_model/lightning_basic.py
class ExampleFineTuneBothModel(ExampleModel):
    """Example of taking the example model and adding an output task."""

    def __init__(self, config: ModelParallelConfig):
        super().__init__(config)
        # 10 output digits, and use the latent output layer (z) for making predictions
        self.digit_classifier = nn.Linear(self.linear2.out_features, 10)

    def forward(self, x: Tensor) -> ExampleFineTuneOutput:
        parent_out: ExampleModelOutput = super().forward(x)
        digit_logits = self.digit_classifier(parent_out["z"])
        return {
            "x_hat": parent_out["x_hat"],
            "z": parent_out["z"],
            "digit_logits": digit_logits,
        }

ExampleFineTuneDropParentConfig dataclass

Bases: ExampleGenericConfig['ExampleFineTuneDropParentModel', 'ClassifierLossReduction'], IOMixinWithGettersSetters

ExampleFineTuneDropParentConfig is a dataclass that is used to configure the model.

Timers from ModelParallelConfig are required for megatron forward compatibility.

Source code in bionemo/example_model/lightning_basic.py
@dataclass
class ExampleFineTuneDropParentConfig(
    ExampleGenericConfig["ExampleFineTuneDropParentModel", "ClassifierLossReduction"], iom.IOMixinWithGettersSetters
):
    """ExampleConfig is a dataclass that is used to configure the model.

    Timers from ModelParallelConfig are required for megatron forward compatibility.
    """

    model_cls: Type[ExampleFineTuneDropParentModel] = ExampleFineTuneDropParentModel
    loss_cls: Type[ClassifierLossReduction] = ClassifierLossReduction

ExampleFineTuneDropParentModel

Bases: ExampleModelTrunk

Example of taking the example model and replacing the output task.

Source code in bionemo/example_model/lightning_basic.py
class ExampleFineTuneDropParentModel(ExampleModelTrunk):
    """Example of taking the example model and replacing output task."""

    def __init__(self, config: ModelParallelConfig):
        super().__init__(config)
        # 10 output digits, and use the latent output layer (z) for making predictions
        self.digit_classifier = nn.Linear(self.linear2.out_features, 10)

    def forward(self, x: Tensor) -> Tensor:
        z: Tensor = super().forward(x)
        digit_logits = self.digit_classifier(z)  # to demonstrate flexibility, in this case we return a tensor
        return digit_logits

ExampleFineTuneOutput

Bases: ExampleModelOutput

Output for the fine-tuned example model implementation.

Source code in bionemo/example_model/lightning_basic.py
class ExampleFineTuneOutput(ExampleModelOutput):
    """Output for the fine-tuned example model implementation."""

    digit_logits: Tensor

ExampleGenericConfig dataclass

Bases: Generic[ExampleModelT, MegatronLossType], MegatronBioNeMoTrainableModelConfig[ExampleModelT, MegatronLossType]

ExampleGenericConfig is a dataclass that is used to configure the model.

Timers from ModelParallelConfig are required for megatron forward compatibility.

Source code in bionemo/example_model/lightning_basic.py
@dataclass
class ExampleGenericConfig(
    Generic[ExampleModelT, MegatronLossType], MegatronBioNeMoTrainableModelConfig[ExampleModelT, MegatronLossType]
):
    """ExampleConfig is a dataclass that is used to configure the model.

    Timers from ModelParallelConfig are required for megatron forward compatibility.
    """

    loss_cls: Type[MegatronLossType] = MSELossReduction  # type: ignore  # this will get overriden by children
    hidden_size: int = 64  # Needs to be set to avoid zero division error in megatron :(
    num_attention_heads: int = 1  # Needs to be set to avoid zero division error in megatron :(
    num_layers: int = 1  # Needs to be set to avoid zero division error in megatron :(
    # IMPORTANT: Since we're adding/overriding the loss_cls, and that's not how we generally track this, we need to
    #   add this into the list of config settings that we do not draw from the loaded checkpoint when restoring.
    override_parent_fields: List[str] = field(default_factory=lambda: OVERRIDE_BIONEMO_CONFIG_DEFAULTS + ["loss_cls"])

    def configure_model(self) -> ExampleModelT:
        """Uses model_cls and loss_cls to configure the model.

        Note: Must pass self into Model since model requires having a config object.

        Returns:
            The model object.
        """
        # 1. first load any settings that may exist in the checkpoint related to the model.
        if self.initial_ckpt_path:
            self.load_settings_from_checkpoint(self.initial_ckpt_path)
        # 2. then initialize the model
        model = self.model_cls(self)
        # 3. Load weights from the checkpoint into the model
        if self.initial_ckpt_path:
            self.update_model_from_checkpoint(model, self.initial_ckpt_path)
        return model

    def get_loss_reduction_class(self) -> Type[MegatronLossType]:
        """Use loss_cls to configure the loss, since we do not change the settings of the loss based on the config."""
        return self.loss_cls

configure_model()

Uses model_cls and loss_cls to configure the model.

Note: Must pass self into Model since model requires having a config object.

Returns:

- ExampleModelT: The model object.

Source code in bionemo/example_model/lightning_basic.py
def configure_model(self) -> ExampleModelT:
    """Uses model_cls and loss_cls to configure the model.

    Note: Must pass self into Model since model requires having a config object.

    Returns:
        The model object.
    """
    # 1. first load any settings that may exist in the checkpoint related to the model.
    if self.initial_ckpt_path:
        self.load_settings_from_checkpoint(self.initial_ckpt_path)
    # 2. then initialize the model
    model = self.model_cls(self)
    # 3. Load weights from the checkpoint into the model
    if self.initial_ckpt_path:
        self.update_model_from_checkpoint(model, self.initial_ckpt_path)
    return model

get_loss_reduction_class()

Use loss_cls to configure the loss, since we do not change the settings of the loss based on the config.

Source code in bionemo/example_model/lightning_basic.py
def get_loss_reduction_class(self) -> Type[MegatronLossType]:
    """Use loss_cls to configure the loss, since we do not change the settings of the loss based on the config."""
    return self.loss_cls
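
The concrete configs above all follow the same pattern: subclass this generic config and pin model_cls and loss_cls. A hedged sketch of a hypothetical variant (the real configs also mix in iom.IOMixinWithGettersSetters, omitted here for brevity):

from dataclasses import dataclass
from typing import Type

from bionemo.example_model.lightning_basic import (
    ClassifierLossReduction,
    ExampleFineTuneDropParentModel,
    ExampleGenericConfig,
)


@dataclass
class MyDropParentConfig(ExampleGenericConfig["ExampleFineTuneDropParentModel", "ClassifierLossReduction"]):
    """Hypothetical config mirroring ExampleFineTuneDropParentConfig, minus the IO mixin."""

    model_cls: Type[ExampleFineTuneDropParentModel] = ExampleFineTuneDropParentModel
    loss_cls: Type[ClassifierLossReduction] = ClassifierLossReduction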

ExampleModel

Bases: ExampleModelTrunk

Source code in bionemo/example_model/lightning_basic.py
class ExampleModel(ExampleModelTrunk):  # noqa: D101
    def __init__(self, config: ModelParallelConfig) -> None:
        """Constructor of the model.

        Args:
            config: The config object is responsible for telling the strategy what model to create.
        """
        super().__init__(config)
        self.linear3 = nn.Linear(3, 64)
        self.relu2 = nn.ReLU()
        self.linear4 = nn.Linear(64, 28 * 28)

    def forward(self, x: Tensor) -> ExampleModelOutput:
        """Forward pass of the model.

        Args:
            x: The input data.

        Returns:
            x_hat: The result of the last linear layer of the network.
        """
        z: Tensor = super().forward(x)
        x_hat = self.linear3(z)
        x_hat = self.relu2(x_hat)
        x_hat = self.linear4(x_hat)
        return {"x_hat": x_hat, "z": z}

__init__(config)

Constructor of the model.

Parameters:

- config (ModelParallelConfig, required): The config object is responsible for telling the strategy what model to create.

Source code in bionemo/example_model/lightning_basic.py
def __init__(self, config: ModelParallelConfig) -> None:
    """Constructor of the model.

    Args:
        config: The config object is responsible for telling the strategy what model to create.
    """
    super().__init__(config)
    self.linear3 = nn.Linear(3, 64)
    self.relu2 = nn.ReLU()
    self.linear4 = nn.Linear(64, 28 * 28)

forward(x)

Forward pass of the model.

Parameters:

- x (Tensor, required): The input data.

Returns:

- x_hat (ExampleModelOutput): The result of the last linear layer of the network.

Source code in bionemo/example_model/lightning_basic.py
def forward(self, x: Tensor) -> ExampleModelOutput:
    """Forward pass of the model.

    Args:
        x: The input data.

    Returns:
        x_hat: The result of the last linear layer of the network.
    """
    z: Tensor = super().forward(x)
    x_hat = self.linear3(z)
    x_hat = self.relu2(x_hat)
    x_hat = self.linear4(x_hat)
    return {"x_hat": x_hat, "z": z}

ExampleModelOutput

Bases: TypedDict

Output for the example model implementation.

Source code in bionemo/example_model/lightning_basic.py
class ExampleModelOutput(TypedDict):
    """Output for the example model implementation."""

    x_hat: Tensor
    z: Tensor

ExampleModelTrunk

Bases: MegatronModule

Source code in bionemo/example_model/lightning_basic.py
class ExampleModelTrunk(MegatronModule):
    def __init__(self, config: ModelParallelConfig) -> None:
        """Constructor of the model.

        Args:
            config: The config object is responsible for telling the strategy what model to create.
        """
        super().__init__(config)
        # FIXME add an assertion that the user is not trying to do tensor parallelism since this doesn't use
        #  parallelizable megatron linear layers.
        self.model_type: ModelType = ModelType.encoder_or_decoder
        self.linear1 = nn.Linear(28 * 28, 64)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(64, 3)

    def forward(self, x: Tensor) -> Tensor:
        # we could return a dictionary of strings to tensors here, but let's demonstrate this is not necessary
        x = x.view(x.size(0), -1)
        z = self.linear1(x)
        z = self.relu(z)
        z = self.linear2(z)
        return z

    def set_input_tensor(self, input_tensor: Optional[Tensor]) -> None:
        """This _would_ be needed for model parallel and other kinds of more complicated forward passes in megatron."""
        pass

__init__(config)

Constructor of the model.

Parameters:

- config (ModelParallelConfig, required): The config object is responsible for telling the strategy what model to create.

Source code in bionemo/example_model/lightning_basic.py
def __init__(self, config: ModelParallelConfig) -> None:
    """Constructor of the model.

    Args:
        config: The config object is responsible for telling the strategy what model to create.
    """
    super().__init__(config)
    # FIXME add an assertion that the user is not trying to do tensor parallelism since this doesn't use
    #  parallelizable megatron linear layers.
    self.model_type: ModelType = ModelType.encoder_or_decoder
    self.linear1 = nn.Linear(28 * 28, 64)
    self.relu = nn.ReLU()
    self.linear2 = nn.Linear(64, 3)

set_input_tensor(input_tensor)

This would be needed for model parallel and other kinds of more complicated forward passes in megatron.

Source code in bionemo/example_model/lightning_basic.py
def set_input_tensor(self, input_tensor: Optional[Tensor]) -> None:
    """This _would_ be needed for model parallel and other kinds of more complicated forward passes in megatron."""
    pass

LitAutoEncoder

Bases: LightningModule, IOMixin, LightningPassthroughPredictionMixin

A very basic lightning module for testing the megatron strategy and the megatron-nemo2-bionemo contract.

Source code in bionemo/example_model/lightning_basic.py
class LitAutoEncoder(pl.LightningModule, io.IOMixin, LightningPassthroughPredictionMixin):
    """A very basic lightning module for testing the megatron strategy and the megatron-nemo2-bionemo contract."""

    def __init__(self, config: MegatronBioNeMoTrainableModelConfig):
        """Initializes the model.

        Args:
            config: a Config object necessary to construct the actual nn.Module (the thing that has the parameters).
        """
        super().__init__()
        self.config = config
        self.optim = MegatronOptimizerModule(
            config=OptimizerConfig(
                lr=1e-4,
                optimizer="adam",
                use_distributed_optimizer=True,
                bf16=config.bf16,
                fp16=config.fp16,
                params_dtype=config.params_dtype,
            ),
        )
        # Bind the configure_optimizers method to the model
        self.optim.connect(self)

    def forward(self, batch: Dict, batch_idx: int) -> Any:
        """This forward will be called by the megatron scheduler and it will be wrapped.

        !!! note

            The `training_step` defines the training loop and is independent of the `forward` method here.

        Args:
            batch: A dictionary of data.
            batch_idx: The index of the batch.

        Returns:
            The output of the model.
        """
        x = batch["data"]
        return self.module(x)

    def training_step(self, batch, batch_idx: Optional[int] = None):
        """The training step is where the loss is calculated and the backpropagation is done.

        Background:
        - NeMo's Strategy overrides this method.
        - The strategies' training step will call the forward method of the model.
        - That forward method then calls the wrapped forward step of MegatronParallel which wraps the forward method of the model.
        - That wrapped forward step is then executed inside the Mcore scheduler, which calls the `_forward_step` method from the
            MegatronParallel class.
        - Which then calls the training_step function here.

        In this particular use case, we simply call the forward method of this class, the lightning module.

        Args:
            batch: A dictionary of data. requires `batch_idx` as default None.
            batch_idx: The index of the batch.
        """
        return self(batch, batch_idx)

    def training_loss_reduction(self) -> MegatronLossReduction:  # noqa: D102
        # This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss
        return self.loss_reduction_class()()

    def validation_loss_reduction(self) -> MegatronLossReduction:  # noqa: D102
        return self.loss_reduction_class()()

    def test_loss_reduction(self) -> MegatronLossReduction:  # noqa: D102
        return self.loss_reduction_class()()

    def configure_model(self) -> None:  # noqa: D102
        # Called lazily by the megatron strategy.
        self.module = self.config.configure_model()

    def loss_reduction_class(self) -> Type[MegatronLossReduction]:
        """Get the loss reduction class the user has specified in their config."""
        return self.config.get_loss_reduction_class()

__init__(config)

Initializes the model.

Parameters:

- config (MegatronBioNeMoTrainableModelConfig, required): A Config object necessary to construct the actual nn.Module (the thing that has the parameters).

Source code in bionemo/example_model/lightning_basic.py
def __init__(self, config: MegatronBioNeMoTrainableModelConfig):
    """Initializes the model.

    Args:
        config: a Config object necessary to construct the actual nn.Module (the thing that has the parameters).
    """
    super().__init__()
    self.config = config
    self.optim = MegatronOptimizerModule(
        config=OptimizerConfig(
            lr=1e-4,
            optimizer="adam",
            use_distributed_optimizer=True,
            bf16=config.bf16,
            fp16=config.fp16,
            params_dtype=config.params_dtype,
        ),
    )
    # Bind the configure_optimizers method to the model
    self.optim.connect(self)

forward(batch, batch_idx)

This forward will be called by the megatron scheduler and it will be wrapped.

Note

The training_step defines the training loop and is independent of the forward method here.

Parameters:

- batch (Dict, required): A dictionary of data.
- batch_idx (int, required): The index of the batch.

Returns:

- Any: The output of the model.

Source code in bionemo/example_model/lightning_basic.py
def forward(self, batch: Dict, batch_idx: int) -> Any:
    """This forward will be called by the megatron scheduler and it will be wrapped.

    !!! note

        The `training_step` defines the training loop and is independent of the `forward` method here.

    Args:
        batch: A dictionary of data.
        batch_idx: The index of the batch.

    Returns:
        The output of the model.
    """
    x = batch["data"]
    return self.module(x)

loss_reduction_class()

Get the loss reduction class the user has specified in their config.

Source code in bionemo/example_model/lightning_basic.py
def loss_reduction_class(self) -> Type[MegatronLossReduction]:
    """Get the loss reduction class the user has specified in their config."""
    return self.config.get_loss_reduction_class()

training_step(batch, batch_idx=None)

The training step is where the loss is calculated and the backpropagation is done.

Background:

- NeMo's Strategy overrides this method.
- The strategy's training step will call the forward method of the model.
- That forward method then calls the wrapped forward step of MegatronParallel, which wraps the forward method of the model.
- That wrapped forward step is then executed inside the Mcore scheduler, which calls the _forward_step method from the MegatronParallel class.
- That, in turn, calls the training_step function here.

In this particular use case, we simply call the forward method of this class, the lightning module.

Parameters:

- batch (required): A dictionary of data; `batch_idx` is accepted with a default of None.
- batch_idx (Optional[int], default None): The index of the batch.

Source code in bionemo/example_model/lightning_basic.py
def training_step(self, batch, batch_idx: Optional[int] = None):
    """The training step is where the loss is calculated and the backpropagation is done.

    Background:
    - NeMo's Strategy overrides this method.
    - The strategies' training step will call the forward method of the model.
    - That forward method then calls the wrapped forward step of MegatronParallel which wraps the forward method of the model.
    - That wrapped forward step is then executed inside the Mcore scheduler, which calls the `_forward_step` method from the
        MegatronParallel class.
    - Which then calls the training_step function here.

    In this particular use case, we simply call the forward method of this class, the lightning module.

    Args:
        batch: A dictionary of data. requires `batch_idx` as default None.
        batch_idx: The index of the batch.
    """
    return self(batch, batch_idx)
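
A rough sketch (not from the source) of the module outside a trainer. configure_model is normally invoked lazily by the Megatron strategy, and it is assumed here that construction (including the optimizer wiring in __init__) tolerates running without a trainer attached:

import torch

from bionemo.example_model.lightning_basic import ExampleConfig, LitAutoEncoder, MSELossReduction

lit = LitAutoEncoder(config=ExampleConfig())
lit.configure_model()                            # builds self.module = ExampleModel(config)
batch = {"data": torch.rand(4, 1, 28, 28)}
out = lit(batch, batch_idx=0)                    # forward() just calls self.module(batch["data"])
assert isinstance(lit.training_loss_reduction(), MSELossReduction)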

MNISTCustom

Bases: MNIST

Source code in bionemo/example_model/lightning_basic.py
class MNISTCustom(MNIST):  # noqa: D101
    def __getitem__(self, index: int) -> MnistItem:
        """Wraps the getitem method of the MNIST dataset such that we return a Dict
        instead of a Tuple or tensor.

        Args:
            index: The index we want to grab, an int.

        Returns:
            A dict containing the data ("x"), label ("y"), and index ("idx").
        """  # noqa: D205
        x, y = super().__getitem__(index)

        return {
            "data": x,
            "label": y,
            "idx": index,
        }

__getitem__(index)

Wraps the getitem method of the MNIST dataset such that we return a Dict instead of a Tuple or tensor.

Parameters:

- index (int, required): The index we want to grab.

Returns:

- MnistItem: A dict containing the data ("data"), label ("label"), and index ("idx").

Source code in bionemo/example_model/lightning_basic.py
def __getitem__(self, index: int) -> MnistItem:
    """Wraps the getitem method of the MNIST dataset such that we return a Dict
    instead of a Tuple or tensor.

    Args:
        index: The index we want to grab, an int.

    Returns:
        A dict containing the data ("x"), label ("y"), and index ("idx").
    """  # noqa: D205
    x, y = super().__getitem__(index)

    return {
        "data": x,
        "label": y,
        "idx": index,
    }
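
A small sketch (not from the source) of what indexing returns, using the same constructor arguments the data module below uses:

from torchvision import transforms

from bionemo.example_model.lightning_basic import MNISTCustom

dataset = MNISTCustom("./", download=True, transform=transforms.ToTensor(), train=False)
item = dataset[0]
assert item["data"].shape == (1, 28, 28)         # the transformed image tensor
assert item["idx"] == 0                          # the index we asked for; item["label"] is the digit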

MNISTDataModule

Bases: LightningDataModule

Source code in bionemo/example_model/lightning_basic.py
class MNISTDataModule(pl.LightningDataModule):  # noqa: D101
    def __init__(self, data_dir: str = "./", batch_size: int = 32, global_batch_size: int | None = None) -> None:  # noqa: D107
        super().__init__()
        self.data_dir = data_dir
        self.micro_batch_size = batch_size
        self.global_batch_size = global_batch_size or batch_size
        self.max_len = 1048  # Unused?
        self.rampup_batch_size = None

        #  Note that this sampler is sequential, meaning it does not do any shuffling. Let's wrap our data in a shuffler.
        # Wraps the datasampler with the MegatronDataSampler. The MegatronDataSampler is a wrapper that allows the sampler
        # to be used with megatron. It sets up the capability to utilize micro-batching and gradient accumulation. It is also
        # the place where the global batch size is constructed.
        self.data_sampler = MegatronDataSampler(
            seq_len=self.max_len,
            micro_batch_size=self.micro_batch_size,
            global_batch_size=self.global_batch_size,
            rampup_batch_size=self.rampup_batch_size,
        )

    def setup(self, stage: str) -> None:
        """Sets up the datasets

        Args:
            stage: can be one of train / test / predict.
        """  # noqa: D415
        self.mnist_test = PRNGResampleDataset(
            MNISTCustom(self.data_dir, download=True, transform=transforms.ToTensor(), train=False), seed=43
        )
        mnist_full = MNISTCustom(self.data_dir, download=True, transform=transforms.ToTensor(), train=True)
        mnist_train, mnist_val = torch.utils.data.random_split(
            mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
        )
        self.mnist_train = PRNGResampleDataset(mnist_train, seed=44)
        self.mnist_val = PRNGResampleDataset(mnist_val, seed=45)

    def train_dataloader(self) -> DataLoader:  # noqa: D102
        return DataLoader(self.mnist_train, batch_size=self.micro_batch_size, num_workers=0)

    def val_dataloader(self) -> DataLoader:  # noqa: D102
        return DataLoader(self.mnist_val, batch_size=self.micro_batch_size, num_workers=0)

    def test_dataloader(self) -> DataLoader:  # noqa: D102
        return DataLoader(self.mnist_test, batch_size=self.micro_batch_size, num_workers=0)

setup(stage)

Sets up the datasets.

Parameters:

- stage (str, required): Can be one of train / test / predict.

Source code in bionemo/example_model/lightning_basic.py
def setup(self, stage: str) -> None:
    """Sets up the datasets

    Args:
        stage: can be one of train / test / predict.
    """  # noqa: D415
    self.mnist_test = PRNGResampleDataset(
        MNISTCustom(self.data_dir, download=True, transform=transforms.ToTensor(), train=False), seed=43
    )
    mnist_full = MNISTCustom(self.data_dir, download=True, transform=transforms.ToTensor(), train=True)
    mnist_train, mnist_val = torch.utils.data.random_split(
        mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
    )
    self.mnist_train = PRNGResampleDataset(mnist_train, seed=44)
    self.mnist_val = PRNGResampleDataset(mnist_val, seed=45)
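
A hedged sketch (not from the source) of pulling one micro-batch outside the trainer; in real training, setup() and the dataloaders are driven by Lightning and the Megatron data sampler:

from bionemo.example_model.lightning_basic import MNISTDataModule

dm = MNISTDataModule(data_dir="./", batch_size=32)
dm.setup(stage="train")                          # downloads MNIST and builds the train/val/test splits
batch = next(iter(dm.train_dataloader()))        # default collation stacks the MnistItem dicts
assert batch["data"].shape == (32, 1, 28, 28)
assert batch["label"].shape == (32,)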

MSELossReduction

Bases: MegatronLossReduction

A class used for calculating the loss, and for logging the reduced loss across micro batches.

Source code in bionemo/example_model/lightning_basic.py
class MSELossReduction(MegatronLossReduction):
    """A class used for calculating the loss, and for logging the reduced loss across micro batches."""

    def forward(self, batch: MnistItem, forward_out: Dict[str, Tensor]) -> Tuple[Tensor, SameSizeLossDict]:
        """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

        Args:
            batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
            forward_out: the output of the forward method inside LitAutoEncoder.

        Returns:
            A tuple containing [<loss_tensor>, ReductionT] where the loss tensor will be used for
                backpropagation and the ReductionT will be passed to the reduce method
                (which currently only works for logging.).
        """
        x = batch["data"]
        x_hat = forward_out["x_hat"]
        xview = x.view(x.size(0), -1).to(x_hat.dtype)
        loss = nn.functional.mse_loss(x_hat, xview)

        return loss, {"avg": loss}

    def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
        """Works across micro-batches. (data on single gpu).

        Note: This currently only works for logging and this loss will not be used for backpropagation.

        Args:
            losses_reduced_per_micro_batch: a list of the outputs of forward

        Returns:
            A tensor that is the mean of the losses. (used for logging).
        """
        mse_losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
        return mse_losses.mean()

forward(batch, forward_out)

Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

Parameters:

- batch (MnistItem, required): A batch of data that gets passed to the original forward inside LitAutoEncoder.
- forward_out (Dict[str, Tensor], required): The output of the forward method inside LitAutoEncoder.

Returns:

- Tuple[Tensor, SameSizeLossDict]: A tuple containing [<loss_tensor>, ReductionT], where the loss tensor will be used for backpropagation and the ReductionT will be passed to the reduce method (which currently only works for logging).

Source code in bionemo/example_model/lightning_basic.py
def forward(self, batch: MnistItem, forward_out: Dict[str, Tensor]) -> Tuple[Tensor, SameSizeLossDict]:
    """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

    Args:
        batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
        forward_out: the output of the forward method inside LitAutoEncoder.

    Returns:
        A tuple containing [<loss_tensor>, ReductionT] where the loss tensor will be used for
            backpropagation and the ReductionT will be passed to the reduce method
            (which currently only works for logging.).
    """
    x = batch["data"]
    x_hat = forward_out["x_hat"]
    xview = x.view(x.size(0), -1).to(x_hat.dtype)
    loss = nn.functional.mse_loss(x_hat, xview)

    return loss, {"avg": loss}

reduce(losses_reduced_per_micro_batch)

Works across micro-batches (data on a single GPU).

Note: This currently only works for logging and this loss will not be used for backpropagation.

Parameters:

- losses_reduced_per_micro_batch (Sequence[SameSizeLossDict], required): A list of the outputs of forward.

Returns:

- Tensor: A tensor that is the mean of the losses (used for logging).

Source code in bionemo/example_model/lightning_basic.py
def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
    """Works across micro-batches. (data on single gpu).

    Note: This currently only works for logging and this loss will not be used for backpropagation.

    Args:
        losses_reduced_per_micro_batch: a list of the outputs of forward

    Returns:
        A tensor that is the mean of the losses. (used for logging).
    """
    mse_losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
    return mse_losses.mean()

MSEPlusClassifierLossReduction

Bases: MegatronLossReduction

A class used for calculating the loss, and for logging the reduced loss across micro batches.

Source code in bionemo/example_model/lightning_basic.py
class MSEPlusClassifierLossReduction(MegatronLossReduction):
    """A class used for calculating the loss, and for logging the reduced loss across micro batches."""

    def forward(self, batch: MnistItem, forward_out: ExampleFineTuneOutput) -> Tuple[Tensor, SameSizeLossDict]:
        """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

        Args:
            batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
            forward_out: the output of the forward method inside LitAutoEncoder.

        Returns:
            A tuple containing [<loss_tensor>, ReductionT] where the loss tensor will be used for
                backpropagation and the ReductionT will be passed to the reduce method
                (which currently only works for logging.).
        """
        x = batch["data"]
        digits = batch["label"]
        x_hat = forward_out["x_hat"]
        digit_logits = forward_out["digit_logits"]
        xview = x.view(x.size(0), -1).to(x_hat.dtype)
        mse_loss = nn.functional.mse_loss(x_hat, xview)
        classifier_loss = nn.functional.cross_entropy(digit_logits, digits)
        loss = classifier_loss + mse_loss
        return loss, {"avg": loss}

    def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
        """Works across micro-batches. (data on single gpu).

        Note: This currently only works for logging and this loss will not be used for backpropagation.

        Args:
            losses_reduced_per_micro_batch: a list of the outputs of forward

        Returns:
            A tensor that is the mean of the losses. (used for logging).
        """
        mse_losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
        return mse_losses.mean()

forward(batch, forward_out)

Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

Parameters:

- batch (MnistItem, required): A batch of data that gets passed to the original forward inside LitAutoEncoder.
- forward_out (ExampleFineTuneOutput, required): The output of the forward method inside LitAutoEncoder.

Returns:

- Tuple[Tensor, SameSizeLossDict]: A tuple containing [<loss_tensor>, ReductionT], where the loss tensor will be used for backpropagation and the ReductionT will be passed to the reduce method (which currently only works for logging).

Source code in bionemo/example_model/lightning_basic.py
def forward(self, batch: MnistItem, forward_out: ExampleFineTuneOutput) -> Tuple[Tensor, SameSizeLossDict]:
    """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

    Args:
        batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
        forward_out: the output of the forward method inside LitAutoEncoder.

    Returns:
        A tuple containing [<loss_tensor>, ReductionT] where the loss tensor will be used for
            backpropagation and the ReductionT will be passed to the reduce method
            (which currently only works for logging.).
    """
    x = batch["data"]
    digits = batch["label"]
    x_hat = forward_out["x_hat"]
    digit_logits = forward_out["digit_logits"]
    xview = x.view(x.size(0), -1).to(x_hat.dtype)
    mse_loss = nn.functional.mse_loss(x_hat, xview)
    classifier_loss = nn.functional.cross_entropy(digit_logits, digits)
    loss = classifier_loss + mse_loss
    return loss, {"avg": loss}

reduce(losses_reduced_per_micro_batch)

Works across micro-batches (data on a single GPU).

Note: This currently only works for logging and this loss will not be used for backpropagation.

Parameters:

- losses_reduced_per_micro_batch (Sequence[SameSizeLossDict], required): A list of the outputs of forward.

Returns:

- Tensor: A tensor that is the mean of the losses (used for logging).

Source code in bionemo/example_model/lightning_basic.py
def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
    """Works across micro-batches. (data on single gpu).

    Note: This currently only works for logging and this loss will not be used for backpropagation.

    Args:
        losses_reduced_per_micro_batch: a list of the outputs of forward

    Returns:
        A tensor that is the mean of the losses. (used for logging).
    """
    mse_losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
    return mse_losses.mean()

MnistItem

Bases: TypedDict

Training input for the MNIST dataset.

Source code in bionemo/example_model/lightning_basic.py
class MnistItem(TypedDict):
    """Training input for the MNIST dataset."""

    data: Tensor
    label: Tensor
    idx: int

SameSizeLossDict

Bases: TypedDict

This is the return type for a loss that is computed for the entire batch, where all microbatches are the same size.

Source code in bionemo/example_model/lightning_basic.py
class SameSizeLossDict(TypedDict):
    """This is the return type for a loss that is computed for the entire batch, where all microbatches are the same size."""

    avg: Tensor