Lightning

DataStep = Callable[[Iterator[DataT]], DataT] module-attribute

Batches together an iterator of individual examples.

Necessary for compatibility with Megatron. This function type is similar to PyTorch's collate function.

A DataStep function takes an iterator over individual examples. Each example may be a tensor, sequence of tensors, or a set of named tensors (provided as a dict mapping str names to each Tensor). Each iteration must yield the same type.

The output of this function mirrors the structure of each yielded example and is a concatenation of all of the examples in the iterator.
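
Below is a minimal sketch of a DataStep for dict-shaped examples, not the library's own implementation; the function name and the assumption that the iterator yields a finite group of examples are illustrative only.

from typing import Dict, Iterator

import torch
from torch import Tensor


def dict_data_step(example_iter: Iterator[Dict[str, Tensor]]) -> Dict[str, Tensor]:
    """Hypothetical DataStep: collate named-tensor examples into a single batch per key."""
    examples = list(example_iter)  # assumes the iterator yields a finite set of examples
    if not examples:
        raise ValueError("Expected at least one example")
    # Stacks per-example tensors into a batch; use torch.cat instead if each example
    # already carries a batch dimension.
    return {key: torch.stack([ex[key] for ex in examples], dim=0) for key in examples[0]}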

ForwardStep = Callable[[MegatronModelType, DataT], DataT] module-attribute

Megatron-compatible forward pass function.
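
A hedged sketch of a ForwardStep for the same dict-shaped batches; the batch keys ("input_ids", "attention_mask") are assumptions and will differ per model.

from typing import Dict

from torch import Tensor


def dict_forward_step(model, batch: Dict[str, Tensor]) -> Tensor:
    """Hypothetical ForwardStep: unpack a named-tensor batch and call the Megatron model."""
    # Real models may need additional keys (e.g. position ids, or labels so the
    # forward pass can also produce the loss, as Megatron expects during training).
    return model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])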

BionemoLightningModule

Bases: Generic[MegatronModelType, MegatronLossType], LightningModule, IOMixin, ConnectorMixin, LightningPassthroughPredictionMixin

Reusable PyTorch Lightning module for Megatron models that is compatible with NeMo's conventions.

Source code in bionemo/llm/lightning.py
class BionemoLightningModule(
    Generic[MegatronModelType, MegatronLossType],
    pl.LightningModule,
    nlio.IOMixin,
    nlio.ConnectorMixin,
    LightningPassthroughPredictionMixin,
):
    """Reusable PyTorch Lightning module for Megatron models that is compatible with NeMo's conventions."""

    def __init__(
        self,
        config: BionemoTrainableModelConfig[MegatronModelType, MegatronLossType],
        forward_step: ForwardStep,
        data_step: DataStep,
        # TODO: Add transformer_layer_spec when we update mcore
        optimizer: MegatronOptimizerModule,
        model_transform: Optional[Callable[[MegatronModelType], MegatronModelType]] = None,
        **model_construct_args,
    ) -> None:
        """Constructor.

        Args:
            config: Serializable configuration object that allows one to construct a new model instance and loss
                function. Necessary for Megatron-based training as the model itself cannot be serialized and
                distributed to nodes. Instead, we serialize the procedure for making the model and distribute that.
            forward_step: Performs forward pass using the model and a batch of data.
            data_step: Custom batch-creating function for the model.
            optimizer: Megatron-compatible distributed optimizer instance. Defaults to using ADAM with a 1e-4 learning
                rate.
            model_transform: Optional. The model transform function.
            **model_construct_args: Optional. Arguments necessary for the supplied model configuration's
                `configure_model` method, which will make an instance of the model.
        """
        super().__init__()
        self.config = config
        self.module_construct_args: Optional[dict[str, Any]] = model_construct_args
        # **must** be set up in configure_model() -- megatron constraint
        # also, must be called `module`: nemo expects the actual model to be stored this way
        self.module: Optional[MegatronModelType] = None
        self.loss_reduction_class: type[MegatronLossType] = config.get_loss_reduction_class()
        self.optim = optimizer
        self.optim.connect(self)  # This will bind the `configure_optimizers` method
        self._data_step = data_step
        self._forward_step = forward_step
        self.model_transform = model_transform

    def configure_model(self) -> None:
        """Updates internal state: instantiates the model from the object's config, assigns to `model` attribute.

        NOTE: this method is idempotent; successive calls have no effect. The model is only initialized once.

        Raises:
            ValueError iff the internal config's configure_model method returns None.
        """
        if self.module is None:
            model: MegatronModelType = (
                self.config.configure_model(**self.module_construct_args)
                if self.module_construct_args is not None
                else self.config.configure_model()
            )
            self.module = model
        if self.module is None:
            raise ValueError("Invalid semantics: configure_model method **MUST** initialize the model.")

    def forward(self, *args, **kwargs) -> DataT:
        """Call the forward method of the underlying model, and return whatever it outputs."""
        # safe to do because configure_model is idempotent
        self.configure_model()
        assert self.module is not None, "ERROR: configure_model() method has been incorrectly overridden!"
        prediction = self.module(*args, **kwargs)  # for now just pass through to the underlying model
        return prediction

    def data_step(self, dataloader_iter: Iterator[DataT]) -> DataT:  # noqa: D102
        return self._data_step(dataloader_iter)

    def forward_step(self, batch) -> Tensor:
        """Megatron-required: the training forward step for the model, which is required to produce the loss.

        Normally, the forward pass of a model means its inference. Loss is computed using the predictions
        from the forward pass against labels. Megatron unfortunately conflates these two different concepts
        and instead has models "forward" method produce the loss. See the Megatron docs for details:
        https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py#L170

        To get actual predictions, use the :func:`forward` method instead.
        """
        # safe to do because configure_model is idempotent
        self.configure_model()
        assert self.module is not None
        return self._forward_step(self.module, batch)

    def training_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
        """In mcore the loss-function is part of the forward-pass when labels are provided."""
        return self.forward_step(batch)

    def validation_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
        """In mcore the loss-function is part of the forward-pass when labels are provided."""
        return self.forward_step(batch)

    def predict_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
        """Alias for forward_step."""
        return self.forward_step(batch)

    def training_loss_reduction(self) -> MegatronLossType:
        """This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss."""
        return self.loss_reduction_class()

    def validation_loss_reduction(self) -> MegatronLossType:  # noqa: D102
        return self.loss_reduction_class(validation_step=True)

    def test_loss_reduction(self) -> MegatronLossType:  # noqa: D102
        return self.loss_reduction_class(validation_step=True)
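
A usage sketch tying the pieces together; MyModelConfig is an assumed BionemoTrainableModelConfig subclass and the two step functions are the illustrative ones sketched earlier, so treat this as wiring guidance rather than a runnable recipe.

# Sketch only: MyModelConfig and the step functions are assumptions for illustration.
module = BionemoLightningModule(
    config=MyModelConfig(),          # assumed BionemoTrainableModelConfig subclass
    forward_step=dict_forward_step,  # ForwardStep sketch from above
    data_step=dict_data_step,        # DataStep sketch from above
    optimizer=default_megatron_optimizer(),
)
# The underlying Megatron model is built lazily, on the first call to configure_model().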

__init__(config, forward_step, data_step, optimizer, model_transform=None, **model_construct_args)

Constructor.

Parameters:

Name Type Description Default
config BionemoTrainableModelConfig[MegatronModelType, MegatronLossType]

Serializable configuration object that allows one to construct a new model instance and loss function. Necessary for Megatron-based training as the model itself cannot be serialized and distributed to nodes. Instead, we serialize the procedure for making the model and distribute that.

required
forward_step ForwardStep

Performs forward pass using the model and a batch of data.

required
data_step DataStep

Custom batch-creating function for the model.

required
optimizer MegatronOptimizerModule

Megatron-compatible distributed optimizer instance. Defaults to using ADAM with a 1e-4 learning rate.

required
model_transform Optional[Callable[[MegatronModelType], MegatronModelType]]

Optional. The model transform function.

None
**model_construct_args

Optional. Arguments necessary for the supplied model configuration's configure_model method, which will make an instance of the model.

{}
Source code in bionemo/llm/lightning.py
def __init__(
    self,
    config: BionemoTrainableModelConfig[MegatronModelType, MegatronLossType],
    forward_step: ForwardStep,
    data_step: DataStep,
    # TODO: Add transformer_layer_spec when we update mcore
    optimizer: MegatronOptimizerModule,
    model_transform: Optional[Callable[[MegatronModelType], MegatronModelType]] = None,
    **model_construct_args,
) -> None:
    """Constructor.

    Args:
        config: Serializable configuration object that allows one to construct a new model instance and loss
            function. Necessary for Megatron-based training as the model itself cannot be serialized and
            distributed to nodes. Instead, we serialize the procedure for making the model and distribute that.
        forward_step: Performs forward pass using the model and a batch of data.
        data_step: Custom batch-creating function for the model.
        optimizer: Megatron-compatible distributed optimizer instance. Defaults to using ADAM with a 1e-4 learning
            rate.
        model_transform: Optional. The model transform function.
        **model_construct_args: Optional. Arguments necessary for the supplied model configuration's
            `configure_model` method, which will make an instance of the model.
    """
    super().__init__()
    self.config = config
    self.module_construct_args: Optional[dict[str, Any]] = model_construct_args
    # **must** be set up in configure_model() -- megatron constraint
    # also, must be called `module`: nemo expects the actual model to be stored this way
    self.module: Optional[MegatronModelType] = None
    self.loss_reduction_class: type[MegatronLossType] = config.get_loss_reduction_class()
    self.optim = optimizer
    self.optim.connect(self)  # This will bind the `configure_optimizers` method
    self._data_step = data_step
    self._forward_step = forward_step
    self.model_transform = model_transform

configure_model()

Updates internal state: instantiates the model from the object's config and assigns it to the module attribute.

NOTE: this method is idempotent; successive calls have no effect. The model is only initialized once.

Source code in bionemo/llm/lightning.py
def configure_model(self) -> None:
    """Updates internal state: instantiates the model from the object's config, assigns to `model` attribute.

    NOTE: this method is idempotent; successive calls have no effect. The model is only initialized once.

    Raises:
        ValueError iff the internal config's configure_model method returns None.
    """
    if self.module is None:
        model: MegatronModelType = (
            self.config.configure_model(**self.module_construct_args)
            if self.module_construct_args is not None
            else self.config.configure_model()
        )
        self.module = model
    if self.module is None:
        raise ValueError("Invalid semantics: configure_model method **MUST** initialize the model.")

forward(*args, **kwargs)

Call the forward method of the underlying model, and return whatever it outputs.

Source code in bionemo/llm/lightning.py
def forward(self, *args, **kwargs) -> DataT:
    """Call the forward method of the underlying model, and return whatever it outputs."""
    # safe to do because configure_model is idempotent
    self.configure_model()
    assert self.module is not None, "ERROR: configure_model() method has been incorrectly overridden!"
    prediction = self.module(*args, **kwargs)  # for now just pass through to the underlying model
    return prediction

forward_step(batch)

Megatron-required: the training forward step for the model, which is required to produce the loss.

Normally, the forward pass of a model means its inference, and the loss is computed from the forward pass predictions against the labels. Megatron unfortunately conflates these two concepts and instead has the model's "forward" method produce the loss. See the Megatron source for details: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py#L170

To get actual predictions, use the forward method instead.

Source code in bionemo/llm/lightning.py
def forward_step(self, batch) -> Tensor:
    """Megatron-required: the training forward step for the model, which is required to produce the loss.

    Normally, the forward pass of a model means its inference. Loss is computed using the predictions
    from the forward pass against labels. Megatron unfortunately conflates these two different concepts
    and instead has models "forward" method produce the loss. See the Megatron docs for details:
    https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py#L170

    To get actual predictions, use the :func:`forward` method instead.
    """
    # safe to do because configure_model is idempotent
    self.configure_model()
    assert self.module is not None
    return self._forward_step(self.module, batch)

predict_step(batch, batch_idx=None)

Alias for forward_step.

Source code in bionemo/llm/lightning.py
def predict_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
    """Alias for forward_step."""
    return self.forward_step(batch)

training_loss_reduction()

This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss.

Source code in bionemo/llm/lightning.py
def training_loss_reduction(self) -> MegatronLossType:
    """This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss."""
    return self.loss_reduction_class()

training_step(batch, batch_idx=None)

In mcore the loss-function is part of the forward-pass when labels are provided.

Source code in bionemo/llm/lightning.py
def training_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
    """In mcore the loss-function is part of the forward-pass when labels are provided."""
    return self.forward_step(batch)

validation_step(batch, batch_idx=None)

In mcore the loss-function is part of the forward-pass when labels are provided.

Source code in bionemo/llm/lightning.py
def validation_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
    """In mcore the loss-function is part of the forward-pass when labels are provided."""
    return self.forward_step(batch)

LightningPassthroughPredictionMixin

A mixin that allows your model to do inference on the predict step by hijacking nemo's loss reduction mechanism.

Source code in bionemo/llm/lightning.py
class LightningPassthroughPredictionMixin:
    """A mixin that allows your model to do inference on the predict step by hijacking nemo's loss reduction mechanism."""

    def predict_loss_reduction(self) -> PassthroughLossReduction:
        """For the predict step, pass through the forward pass output."""
        return PassthroughLossReduction()

predict_loss_reduction()

For the predict step, pass through the forward pass output.

Source code in bionemo/llm/lightning.py
def predict_loss_reduction(self) -> PassthroughLossReduction:
    """For the predict step, pass through the forward pass output."""
    return PassthroughLossReduction()

PassthroughLossReduction

Bases: MegatronLossReduction, Generic[DataT]

A workaround for nemo/megatron to perform inference.

Internally in NeMo2.0 the forward step is always expected to return a loss reduction class, and forward is expected to return a loss. This class hijacks that mechanism to instead pass through the forward output unperturbed as the loss (to enable inference in the predict step), and then the reduce method is used to collate the batch of forward outputs into a single batch. This supports the model forward output being a tensor, dict, tuple, or list of tensors. The inner type must always be a Tensor.

Source code in bionemo/llm/lightning.py
class PassthroughLossReduction(MegatronLossReduction, Generic[DataT]):
    """A workaround for nemo/megatron to perform inference.

    Internally in NeMo2.0 the forward step is always expected to return a loss reduction class, and forward is
    expected to return a loss. This class hijacks that mechanism to instead pass through the forward output unperturbed
    as the loss (to enable inference in the predict step), and then the reduce method is used to collate the batch of
    forward outputs into a single batch. This supports the model forward output being a tensor, dict, tuple, or list of
    tensors. The inner type _must always be a Tensor_.
    """

    def forward(self, batch: DataT, forward_out: DataT) -> Tuple[Tensor, DataT]:
        """Passes through the `forward_out` value as the 2nd tuple element.

        Args:
            batch: The batch of data that was passed through the model to generate output. NOTE: this value is ignored.
            forward_out: The output from your model's forward pass.

        Returns:
            A tuple containing the loss tensor (dummy in this case) and the forward output (unmodified).
        """
        dtype, device = get_dtype_device(forward_out)
        return torch.zeros(1, device=device, dtype=dtype), forward_out

    def reduce(self, forward_out: List[DataT]) -> DataT:
        """Collates list of model's outputs into a single output."""
        return batch_collator(forward_out)
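
A small illustration of the contract this class implements, assuming the bionemo.llm.lightning import path; forward returns a dummy loss alongside the untouched output, and reduce concatenates microbatch outputs on the batch dimension.

import torch

from bionemo.llm.lightning import PassthroughLossReduction  # import path assumed

reduction = PassthroughLossReduction()
out_a = {"logits": torch.randn(2, 8)}
out_b = {"logits": torch.randn(2, 8)}

dummy_loss, passthrough = reduction.forward(batch=None, forward_out=out_a)  # batch is ignored
assert torch.equal(passthrough["logits"], out_a["logits"])  # output passes through unchanged

merged = reduction.reduce([out_a, out_b])
assert merged["logits"].shape == (4, 8)  # microbatch outputs concatenated on the batch dim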

forward(batch, forward_out)

Passes through the forward_out value as the 2nd tuple element.

Parameters:

Name Type Description Default
batch DataT

The batch of data that was passed through the model to generate output. NOTE: this value is ignored.

required
forward_out DataT

The output from your model's forward pass.

required

Returns:

Type Description
Tuple[Tensor, DataT]

A tuple containing the loss tensor (dummy in this case) and the forward output (unmodified).

Source code in bionemo/llm/lightning.py
def forward(self, batch: DataT, forward_out: DataT) -> Tuple[Tensor, DataT]:
    """Passes through the `forward_out` value as the 2nd tuple element.

    Args:
        batch: The batch of data that was passed through the model to generate output. NOTE: this value is ignored.
        forward_out: The output from your model's forward pass.

    Returns:
        A tuple containing the loss tensor (dummy in this case) and the forward output (unmodified).
    """
    dtype, device = get_dtype_device(forward_out)
    return torch.zeros(1, device=device, dtype=dtype), forward_out

reduce(forward_out)

Collates list of model's outputs into a single output.

Source code in bionemo/llm/lightning.py
def reduce(self, forward_out: List[DataT]) -> DataT:
    """Collates list of model's outputs into a single output."""
    return batch_collator(forward_out)

PerplexityLoggingCallback

Bases: Callback, CallbackMethods

Megatron Callback to log perplexity in validation and optionally training.

NeMo2.0 checks whether a callback is an instance of {LightningModule,LightningDataModule,Callback} but only megatron_hooks are useful.

Source code in bionemo/llm/lightning.py
class PerplexityLoggingCallback(pl.Callback, CallbackMethods):
    """Megatron Callback to log perplexity in validation and optionally training.

    NeMo2.0 checks whether a callback is an instance of {LightningModule,LightningDataModule,Callback} but only megatron_hooks are useful.
    """

    def __init__(self, log_train: bool = False, log_val: bool = True):
        """Initialize PerplexityLoggingCallback.

        Args:
            log_train: whether to log train perplexity. Defaults to False.
            log_val: whether to log validation perplexity. Defaults to True.
        """
        super().__init__()
        self.log_train = log_train
        self.log_val = log_val

    def _pad_to_max_length(
        self, microbatch_outputs: List[Dict[str, Dict[str, Tensor]]], key1: str, key2: str, pad_value: int = 0
    ) -> Tensor:
        """Pad tensors to max length in microbatch_outputs."""
        max_sequence_length: int = max(output[key1][key2].size(1) for output in microbatch_outputs)

        tensors: List[Tensor] = []
        for microbatch_output in microbatch_outputs:
            tensor = microbatch_output[key1][key2]
            assert (
                tensor.dim() >= 2
            ), f"Tensor in microbatch_outputs must have at least 2 dimensions, but got {tensor.dim()} dimensions"
            tensors.append(
                torch.nn.functional.pad(  # padding reverse in order
                    tensor,
                    (0, 0) * (tensor.dim() - 2)
                    + (0, max_sequence_length - tensor.shape[1], 0, 0),  # [b s *] -> [* s b]
                    value=pad_value,
                )
            )

        return torch.cat(tensors, dim=0)  # concat on batch dim

    @override
    def on_megatron_reduce_microbatches_end(
        self,
        step: MegatronStep,
        microbatch_outputs: List[Any],
        loss_reduction: MegatronLossReduction,
        reduced: Tensor | dict[str, Tensor],
    ) -> None:
        """Log after MegatronReductionLoss.reduce is called.

        Expected microbatch_outputs to be a list of dicts with the following keys:
            - batch: dict of tensors with the following keys:
                - labels: [b s]
                - loss_mask: [b s]; 1 means included 0 means ignored
            - forward_out: dict of tensors with the following keys:
                - token_logits: [b s vocab]
        """
        if step.trainer.training and not self.log_train:
            return

        if not parallel_state.is_pipeline_last_stage():
            return

        assert step.num_microbatches is not None, "num_microbatches must be initialized to non-None"
        assert step.num_microbatches > 0, "num_microbatches must be greater than 0"
        assert (
            len(microbatch_outputs) == step.num_microbatches
        ), "microbatch_outputs length does not match num_microbatches"
        labels = self._pad_to_max_length(microbatch_outputs, "batch", "labels", pad_value=-100)
        loss_mask = self._pad_to_max_length(microbatch_outputs, "batch", "loss_mask")
        token_logits = self._pad_to_max_length(microbatch_outputs, "forward_out", "token_logits")

        unreduced_token_loss = unreduced_token_loss_fn(
            token_logits.clone(),  # unreduced_token_loss_fn has inplace operation on token_logits
            labels.clone(),
        )  # [b s]

        cp_size = parallel_state.get_context_parallel_world_size()
        if cp_size == 1:
            ppl = torch.exp((unreduced_token_loss * loss_mask).sum() / loss_mask.sum())
        else:
            raise NotImplementedError("Context parallel perplexity logging is not supported yet")

        if self.log_val and not step.trainer.training:
            step.pl_module.log("val_ppl", ppl, prog_bar=True, on_epoch=True)
        elif self.log_train and step.trainer.training:
            step.pl_module.log("train_ppl", ppl, prog_bar=True, batch_size=1, sync_dist=False)
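
The logged value is the standard masked perplexity: the exponential of the mean token loss over unmasked positions. A self-contained illustration of the same reduction used above, on toy numbers:

import torch

# Toy per-token losses [b, s] and a loss mask (1 = counted, 0 = ignored), mirroring
# the unreduced_token_loss and loss_mask tensors used by the callback.
token_loss = torch.tensor([[2.0, 1.0, 3.0], [1.5, 0.5, 2.5]])
loss_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]])

ppl = torch.exp((token_loss * loss_mask).sum() / loss_mask.sum())
print(ppl)  # exp((2.0 + 1.0 + 1.5 + 0.5 + 2.5) / 5) = exp(1.5) ≈ 4.48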

__init__(log_train=False, log_val=True)

Initialize PerplexityLoggingCallback.

Parameters:

Name Type Description Default
log_train bool

whether to log train perplexity. Defaults to False.

False
log_val bool

whether to log validation perplexity. Defaults to True.

True
Source code in bionemo/llm/lightning.py
def __init__(self, log_train: bool = False, log_val: bool = True):
    """Initialize PerplexityLoggingCallback.

    Args:
        log_train: whether to log train perplexity. Defaults to False.
        log_val: whether to log validation perplexity. Defaults to True.
    """
    super().__init__()
    self.log_train = log_train
    self.log_val = log_val

on_megatron_reduce_microbatches_end(step, microbatch_outputs, loss_reduction, reduced)

Log after MegatronReductionLoss.reduce is called.

Expected microbatch_outputs to be a list of dicts with the following keys:
  • batch: dict of tensors with the following keys:
    • labels: [b s]
    • loss_mask: [b s]; 1 means included 0 means ignored
  • forward_out: dict of tensors with the following keys:
    • token_logits: [b s vocab]
Source code in bionemo/llm/lightning.py
@override
def on_megatron_reduce_microbatches_end(
    self,
    step: MegatronStep,
    microbatch_outputs: List[Any],
    loss_reduction: MegatronLossReduction,
    reduced: Tensor | dict[str, Tensor],
) -> None:
    """Log after MegatronReductionLoss.reduce is called.

    Expected microbatch_outputs to be a list of dicts with the following keys:
        - batch: dict of tensors with the following keys:
            - labels: [b s]
            - loss_mask: [b s]; 1 means included 0 means ignored
        - forward_out: dict of tensors with the following keys:
            - token_logits: [b s vocab]
    """
    if step.trainer.training and not self.log_train:
        return

    if not parallel_state.is_pipeline_last_stage():
        return

    assert step.num_microbatches is not None, "num_microbatches must be initialized to non-None"
    assert step.num_microbatches > 0, "num_microbatches must be greater than 0"
    assert (
        len(microbatch_outputs) == step.num_microbatches
    ), "microbatch_outputs length does not match num_microbatches"
    labels = self._pad_to_max_length(microbatch_outputs, "batch", "labels", pad_value=-100)
    loss_mask = self._pad_to_max_length(microbatch_outputs, "batch", "loss_mask")
    token_logits = self._pad_to_max_length(microbatch_outputs, "forward_out", "token_logits")

    unreduced_token_loss = unreduced_token_loss_fn(
        token_logits.clone(),  # unreduced_token_loss_fn has inplace operation on token_logits
        labels.clone(),
    )  # [b s]

    cp_size = parallel_state.get_context_parallel_world_size()
    if cp_size == 1:
        ppl = torch.exp((unreduced_token_loss * loss_mask).sum() / loss_mask.sum())
    else:
        raise NotImplementedError("Context parallel perplexity logging is not supported yet")

    if self.log_val and not step.trainer.training:
        step.pl_module.log("val_ppl", ppl, prog_bar=True, on_epoch=True)
    elif self.log_train and step.trainer.training:
        step.pl_module.log("train_ppl", ppl, prog_bar=True, batch_size=1, sync_dist=False)

batch_collator(batches)

Takes a sequence of batches and collates them into a single batch. This is distinct from the standard PyTorch default_collate since it does not add a batch dimension; the batch dimension is assumed to already be present in the input, as would be the case when parallelizing across minibatches.

IMPORTANT: The underlying data primitive must be a torch Tensor. The input to this function is a recursive type; there can be any amount of nesting between dictionaries, tuples, and lists, as long as the inner type is an n-d Tensor.

Examples:

Outer container = Dict:
    [{'a': Tensor([1]), 'b': Tensor([2])}, {'a': Tensor([2]), 'b': Tensor([3])}] -> {'a': Tensor([1, 2]), 'b': Tensor([2, 3])}
Outer container = List:
    [[Tensor([1]), Tensor([2])], [Tensor([2]), Tensor([3])]] -> [Tensor([1, 2]), Tensor([2, 3])]
Outer container = Tuple:
    ([Tensor([1]), Tensor([2])], [Tensor([2]), Tensor([3])]) -> (Tensor([1, 2]), Tensor([2, 3]))

Parameters:

Name Type Description Default
batches Optional[Sequence[ReductionT]]

sequence of batches to collate into a single batch.

required

Returns:

Type Description
Optional[ReductionT]

A single batch of the same type as the elements of your input sequence.

Source code in bionemo/llm/lightning.py
def batch_collator(batches: Optional[Union[Tuple[ReductionT], List[ReductionT]]]) -> Optional[ReductionT]:
    """Takes a sequence of batches and collates them into a single batch.
        This is distinct from the standard pytorch default_collator since it does
        not add the batch dimension, it's assumed the batch
        dimension is already present in the input, as would be the case when
        parallelizing across minibatches.

    IMPORTANT: The underlying data primitive _must_ be a torch Tensor. The input to this function is a recursive type;
    there can be any amount of nesting between dictionaries, tuples, and lists, as long as the inner type is an n-d Tensor.

    Examples:
        Outer container = Dict:
            [{'a': Tensor([1]), 'b': Tensor([2])}, {'a': Tensor([2]), 'b': Tensor([3])}] -> {'a': Tensor([1, 2]), 'b': Tensor([2, 3])}
        Outer container = List:
            [[Tensor([1]), Tensor([2])], [Tensor([2]), Tensor([3])]] -> [Tensor([1, 2]), Tensor([2, 3])]
        Outer container = Tuple:
            ([Tensor([1]), Tensor([2])], [Tensor([2]), Tensor([3])]) -> (Tensor([1, 2]), Tensor([2, 3]))

    Args:
        batches (Optional[Sequence[ReductionT]]): sequence of batches to collate into a single batch.

    Returns:
        A single batch of the same type as the elements of your input sequence.
    """  # noqa: D205
    match batches:
        case [Tensor(), *_]:
            return torch.cat(batches, dim=0)
        case [dict(), *_]:
            return {key: batch_collator([batch[key] for batch in batches]) for key in batches[0]}
        case [tuple(), *_]:
            return tuple(batch_collator([batch[i] for batch in batches]) for i in range(len(batches[0])))
        case [list(), *_]:
            return [batch_collator([batch[i] for batch in batches]) for i in range(len(batches[0]))]
        case None:
            return None
        case []:
            raise ValueError("Cannot process an empty sequence")
        case _:
            raise ValueError("Unsupported input structure in batch_collator")
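
A short usage sketch mirroring the dict example above; the import path is assumed.

import torch

from bionemo.llm.lightning import batch_collator  # import path assumed

merged = batch_collator(
    [
        {"a": torch.tensor([1]), "b": torch.tensor([2])},
        {"a": torch.tensor([2]), "b": torch.tensor([3])},
    ]
)
print(merged)  # {'a': tensor([1, 2]), 'b': tensor([2, 3])}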

default_megatron_optimizer()

Default distributed optimizer uses Adam with a 1e-4 learning rate.

Source code in bionemo/llm/lightning.py
def default_megatron_optimizer() -> MegatronOptimizerModule:
    """Default distributed optimizer uses Adam with a 1e-4 learning rate."""
    return MegatronOptimizerModule(
        config=OptimizerConfig(lr=1e-4, optimizer="adam", use_distributed_optimizer=True),
    )
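
To use different optimizer settings, construct the module directly with a custom OptimizerConfig; a hedged sketch, assuming the import paths implied by this module.

from megatron.core.optimizer import OptimizerConfig  # import paths assumed
from nemo.lightning.pytorch.optim import MegatronOptimizerModule

optimizer = MegatronOptimizerModule(
    config=OptimizerConfig(lr=5e-5, optimizer="adam", use_distributed_optimizer=True),
)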

some_first(seq)

Returns the first non-None value from the sequence, or raises a ValueError if none is found.

Source code in bionemo/llm/lightning.py
def some_first(seq: Iterable[Optional[T]]) -> T:
    """Returns the first non-None value from the sequence or fails"""  # noqa: D415
    for s in seq:
        if s is not None:
            return s
    raise ValueError("non-None value not found")