
Lightning

BertBatch

Bases: BertBatchCore

Input datatype for inference with BERT-like models.

Source code in bionemo/llm/model/biobert/lightning.py
class BertBatch(BertBatchCore, total=False):
    """Input datatype for inference with BERT-like models."""

    cu_seqlens: Tensor

BertBatchCore

Bases: TypedDict

Input datatype for inference with BERT-like models.

Source code in bionemo/llm/model/biobert/lightning.py
class BertBatchCore(TypedDict):
    """Input datatype for inference with BERT-like models."""

    text: Tensor
    attention_mask: Tensor
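
For illustration, here is a minimal sketch of how these TypedDicts can be populated. The vocabulary size and tensor shapes are arbitrary assumptions, not constraints imposed by the classes.

import torch

from bionemo.llm.model.biobert.lightning import BertBatch, BertBatchCore

# Core batch: token IDs plus an attention mask (shapes are illustrative only).
core: BertBatchCore = {
    "text": torch.randint(0, 32_000, (2, 128)),              # (batch, seq_len) token IDs
    "attention_mask": torch.ones(2, 128, dtype=torch.bool),  # True where tokens are real
}

# BertBatch is declared with total=False, so cu_seqlens is optional and only needs to be
# present when sequences are packed into a dense batch.
packed: BertBatch = {
    "text": core["text"],
    "attention_mask": core["attention_mask"],
    "cu_seqlens": torch.tensor([0, 100, 228], dtype=torch.int32),
}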

BertModel

Bases: Protocol[DataT]

Interface for BERT-like models.

Source code in bionemo/llm/model/biobert/lightning.py
class BertModel(Protocol[DataT]):
    """Interface for BERT-like models."""

    def forward(
        self, input_ids: Tensor, attention_mask: Tensor, packed_seq_params: Optional[PackedSeqParams] = None
    ) -> DataT:
        """Inference for BERT-like models.

        Inference for BERT-like models requires their tokenized inputs by IDs, an attention mask over the input,
        and the original sequence lengths if the sequences are packed into a dense batch.
        """
        ...

forward(input_ids, attention_mask, packed_seq_params=None)

Inference for BERT-like models.

Inference for BERT-like models requires their tokenized inputs by IDs, an attention mask over the input, and the original sequence lengths if the sequences are packed into a dense batch.

Source code in bionemo/llm/model/biobert/lightning.py
def forward(
    self, input_ids: Tensor, attention_mask: Tensor, packed_seq_params: Optional[PackedSeqParams] = None
) -> DataT:
    """Inference for BERT-like models.

    Inference for BERT-like models requires their tokenized inputs by IDs, an attention mask over the input,
    and the original sequence lengths if the sequences are packed into a dense batch.
    """
    ...
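
Any object with a structurally compatible forward method satisfies this protocol. The toy class below is purely illustrative and not part of the library; the PackedSeqParams import path is assumed from Megatron-LM core.

from typing import Optional

import torch
from megatron.core.packed_seq_params import PackedSeqParams  # assumed import path
from torch import Tensor


class ToyBert:
    """Illustrative stand-in that structurally satisfies BertModel[Tensor]."""

    def forward(
        self, input_ids: Tensor, attention_mask: Tensor, packed_seq_params: Optional[PackedSeqParams] = None
    ) -> Tensor:
        # Return a dummy per-token hidden state; a real model would run its transformer stack here.
        return torch.zeros(*input_ids.shape, 8)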

BioBertLightningModule

Bases: LightningModule, IOMixinWithGettersSetters, ConnectorMixin, LightningPassthroughPredictionMixin

Source code in bionemo/llm/model/biobert/lightning.py
class BioBertLightningModule(
    pl.LightningModule, iom.IOMixinWithGettersSetters, nlio.ConnectorMixin, LightningPassthroughPredictionMixin
):
    def __init__(
        self,
        config: MegatronBioNeMoTrainableModelConfig,
        # TODO: Add transformer_layer_spec when we update mcore
        tokenizer: Optional[TokenizerSpec] = None,
        optimizer: MegatronOptimizerModule = MegatronOptimizerModule(
            config=OptimizerConfig(lr=1e-4, optimizer="adam", use_distributed_optimizer=True),
        ),
        data_step_function: DataStepFunction = biobert_data_step,
        forward_step_function: ForwardStepFunction = bert_forward_step,
        model_transform: Callable | None = None,
    ):
        """A pytorch lightning module for BioBert-derived models. This module is designed to be used with the Megatron-LM strategy and nemo 2.0 conventions.
        To change your loss, pass in a different config object that returns a different loss reduction class. To change your model and what it outputs,
        pass in a different config object that returns a different model. Do not modify this function unless you need to change higher level logic. You may
        need to modify the various step and forward functions towards the bottom of this file to handle new/different keys in the batch. In the future some of
        those functions may need to be refactored out into the config object or a different place so that they live closer to the model definition.

        Args:
            config (MegatronBioNeMoTrainableModelConfig): The model configuration object.
            tokenizer (Optional[TokenizerSpec], optional): The tokenizer object. Defaults to None.
            optimizer (MegatronOptimizerModule, optional): The optimizer object. Defaults to MegatronOptimizerModule(config=OptimizerConfig(lr=1e-4, optimizer="adam", use_distributed_optimizer=True)).
            data_step_function (DataStepFunction, optional): The data step function. Defaults to biobert_data_step.
            forward_step_function (ForwardStepFunction, optional): The forward step function. Defaults to bert_forward_step.
            model_transform (Callable, optional): The model transform function. Defaults to None.
        """  # noqa: D205
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer
        self.loss_reduction_class = config.get_loss_reduction_class()
        # TODO replace the self.configure_optimizer call with the optimizer below
        #  once it all works. This is the future direction for how things are going.

        self.optim = optimizer
        self.optim.connect(self)  # This will bind the `configure_optimizers` method
        self.data_step_function: DataStepFunction = data_step_function
        self.forward_step_function: ForwardStepFunction = forward_step_function
        if model_transform is not None:
            self.model_transform = model_transform

    def configure_model(self) -> None:
        self.module = self.config.configure_model(self.tokenizer)

    # This is now replaced by the init hook on self.optimizer
    # def configure_optimizers(self) -> Optimizer:
    #     return bert_default_optimizer(self)

    def forward(
        self,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Call the forward method of the underlying model, and return whatever it outputs."""
        output_tensor = self.module(*args, **kwargs)  # for now just pass through to the underlying model
        return output_tensor

    def data_step(self, dataloader_iter) -> DataStepOutput:
        return self.data_step_function(dataloader_iter)

    def forward_step(self, batch) -> DataT:
        return self.forward_step_function(self, batch)

    def training_step(self, batch, batch_idx=None) -> DataT:
        # In mcore the loss-function is part of the forward-pass (when labels are provided)
        return self.forward_step(batch)

    def validation_step(self, batch, batch_idx=None) -> DataT:
        # In mcore the loss-function is part of the forward-pass (when labels are provided)
        return self.forward_step(batch)

    def predict_step(self, batch, batch_idx=None) -> DataT:
        return self.forward_step(batch)

    def training_loss_reduction(self) -> MegatronLossReduction:
        # This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss
        #  This function will
        return self.loss_reduction_class()

    # The predict step comes from the LightningPassthroughPredictionMixin

    def validation_loss_reduction(self) -> MegatronLossReduction:
        return self.loss_reduction_class(validation_step=True)

    def test_loss_reduction(self) -> MegatronLossReduction:
        return self.loss_reduction_class(validation_step=True)

    def copy(self) -> "BioBertLightningModule":
        return self.__class__(
            self.config, self.tokenizer, self.optim, self.data_step_function, self.forward_step_function
        )
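
A hedged construction sketch: `my_config` and `my_tokenizer` below are placeholders for a concrete MegatronBioNeMoTrainableModelConfig instance and tokenizer; training additionally requires a NeMo trainer with the Megatron strategy, which is not shown here.

from bionemo.llm.model.biobert.lightning import BioBertLightningModule

# my_config / my_tokenizer are hypothetical placeholders for your own config and tokenizer objects.
module = BioBertLightningModule(config=my_config, tokenizer=my_tokenizer)

# The optimizer, data_step_function, and forward_step_function keep their documented defaults
# (MegatronOptimizerModule with adam at lr=1e-4, biobert_data_step, bert_forward_step) unless overridden.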

__init__(config, tokenizer=None, optimizer=MegatronOptimizerModule(config=OptimizerConfig(lr=0.0001, optimizer='adam', use_distributed_optimizer=True)), data_step_function=biobert_data_step, forward_step_function=bert_forward_step, model_transform=None)

A PyTorch Lightning module for BioBert-derived models. This module is designed to be used with the Megatron-LM strategy and NeMo 2.0 conventions. To change your loss, pass in a different config object that returns a different loss reduction class. To change your model and what it outputs, pass in a different config object that returns a different model. Do not modify this function unless you need to change higher-level logic. You may need to modify the various step and forward functions towards the bottom of this file to handle new/different keys in the batch. In the future some of those functions may need to be refactored out into the config object or a different place so that they live closer to the model definition.

Parameters:

    config (MegatronBioNeMoTrainableModelConfig): The model configuration object. Required.
    tokenizer (Optional[TokenizerSpec]): The tokenizer object. Defaults to None.
    optimizer (MegatronOptimizerModule): The optimizer object. Defaults to MegatronOptimizerModule(config=OptimizerConfig(lr=1e-4, optimizer="adam", use_distributed_optimizer=True)).
    data_step_function (DataStepFunction): The data step function. Defaults to biobert_data_step.
    forward_step_function (ForwardStepFunction): The forward step function. Defaults to bert_forward_step.
    model_transform (Callable): The model transform function. Defaults to None.

Source code in bionemo/llm/model/biobert/lightning.py
def __init__(
    self,
    config: MegatronBioNeMoTrainableModelConfig,
    # TODO: Add transformer_layer_spec when we update mcore
    tokenizer: Optional[TokenizerSpec] = None,
    optimizer: MegatronOptimizerModule = MegatronOptimizerModule(
        config=OptimizerConfig(lr=1e-4, optimizer="adam", use_distributed_optimizer=True),
    ),
    data_step_function: DataStepFunction = biobert_data_step,
    forward_step_function: ForwardStepFunction = bert_forward_step,
    model_transform: Callable | None = None,
):
    """A pytorch lightning module for BioBert-derived models. This module is designed to be used with the Megatron-LM strategy and nemo 2.0 conventions.
    To change your loss, pass in a different config object that returns a different loss reduction class. To change your model and what it outputs,
    pass in a different config object that returns a different model. Do not modify this function unless you need to change higher level logic. You may
    need to modify the various step and forward functions towards the bottom of this file to handle new/different keys in the batch. In the future some of
    those functions may need to be refactored out into the config object or a different place so that they live closer to the model definition.

    Args:
        config (MegatronBioNeMoTrainableModelConfig): The model configuration object.
        tokenizer (Optional[TokenizerSpec], optional): The tokenizer object. Defaults to None.
        optimizer (MegatronOptimizerModule, optional): The optimizer object. Defaults to MegatronOptimizerModule(config=OptimizerConfig(lr=1e-4, optimizer="adam", use_distributed_optimizer=True)).
        data_step_function (DataStepFunction, optional): The data step function. Defaults to biobert_data_step.
        forward_step_function (ForwardStepFunction, optional): The forward step function. Defaults to bert_forward_step.
        model_transform (Callable, optional): The model transform function. Defaults to None.
    """  # noqa: D205
    super().__init__()
    self.config = config
    self.tokenizer = tokenizer
    self.loss_reduction_class = config.get_loss_reduction_class()
    # TODO replace the self.configure_optimizer call with the optimizer below
    #  once it all works. This is the future direction for how things are going.

    self.optim = optimizer
    self.optim.connect(self)  # This will bind the `configure_optimizers` method
    self.data_step_function: DataStepFunction = data_step_function
    self.forward_step_function: ForwardStepFunction = forward_step_function
    if model_transform is not None:
        self.model_transform = model_transform

forward(*args, **kwargs)

Call the forward method of the underlying model, and return whatever it outputs.

Source code in bionemo/llm/model/biobert/lightning.py
def forward(
    self,
    *args,
    **kwargs,
) -> torch.Tensor:
    """Call the forward method of the underlying model, and return whatever it outputs."""
    output_tensor = self.module(*args, **kwargs)  # for now just pass through to the underlying model
    return output_tensor

SequenceBatch

Bases: SequenceBatchCore

Input datatype for inference with BERT-like models.

Source code in bionemo/llm/model/biobert/lightning.py
class SequenceBatch(SequenceBatchCore, total=False):
    """Input datatype for inference with BERT-like models."""

    cu_seqlens_argmin: Tensor
    max_seqlen: Tensor

SequenceBatchCore

Bases: TypedDict

Input datatype for inference with BERT-like models.

Source code in bionemo/llm/model/biobert/lightning.py
class SequenceBatchCore(TypedDict):
    """Input datatype for inference with BERT-like models."""

    cu_seqlens: Tensor
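
A minimal sketch of a packed-sequence batch with illustrative values. cu_seqlens holds the cumulative sequence boundaries (padded with -1 by the collate function), while the optional keys can be precomputed in the dataset for performance.

import torch

from bionemo.llm.model.biobert.lightning import SequenceBatch

seq_batch: SequenceBatch = {
    # micro-batch size 1; two packed sequences of lengths 100 and 28, with trailing -1 collate padding.
    "cu_seqlens": torch.tensor([[0, 100, 128, -1, -1]], dtype=torch.int32),
    # Optional (total=False) keys, typically precomputed in the dataset class:
    "cu_seqlens_argmin": torch.tensor([3]),  # index of the first -1 padding entry
    "max_seqlen": torch.tensor([100]),
}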

bert_default_optimizer(model)

Returns the default optimizer for the BERT model.

Parameters:

    model (Module): The BERT model. Required.

Returns:

    FusedAdam: The default optimizer initialized for this BERT module's parameters.
        Uses a learning rate of 1e-4 and weight decay of 1e-2.

Source code in bionemo/llm/model/biobert/lightning.py
def bert_default_optimizer(model: torch.nn.Module) -> FusedAdam:
    """Returns the default optimizer for the BERT model.

    Args:
        model: The BERT model.

    Returns:
        The default optimizer initialized for this BERT module's parameters.
        Uses a learning rate of 1e-4 and weight decay of 1e-2.
    """
    return FusedAdam(model.parameters(), lr=1e-4, weight_decay=0.01)
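
A brief usage sketch; FusedAdam comes from NVIDIA Apex, so this assumes a CUDA-enabled Apex installation, and the Linear model is only a stand-in.

import torch

from bionemo.llm.model.biobert.lightning import bert_default_optimizer

toy_model = torch.nn.Linear(16, 16)            # stand-in for a BERT module
optimizer = bert_default_optimizer(toy_model)  # FusedAdam with lr=1e-4, weight_decay=0.01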

bert_forward_step(model, batch)

Performs the model's forward pass using the batch, for Megatron compatibility.

This subsets the batch keys to the ones actually used by the model's forward pass, and then calls that forward pass. If "cu_seqlens" is defined in the batch, then the packed sequence parameters are also passed to the model for forward-pass efficiency.

Source code in bionemo/llm/model/biobert/lightning.py
def bert_forward_step(model: BertModel[DataT], batch: BertBatch) -> DataT:
    """Performs the model's forward pass using the batch, for Megatron compatibility.

    This subsets the batch keys to the ones actually used by the model's forward pass, and then calls that forward
    pass. If "cu_seqlens" is defined in the batch, then the packed sequence parameters are also passed to the
    model for forward pass efficiency.
    """
    if "cu_seqlens" in batch:
        forward_results = model.forward(
            input_ids=batch["text"],
            attention_mask=batch["attention_mask"],
            packed_seq_params=get_packed_seq_params(cast(SequenceBatch, batch)),
        )
    else:
        forward_results = model.forward(input_ids=batch["text"], attention_mask=batch["attention_mask"])
    # TODO support losses that also include the binary head, this means doing something more fancy than the one
    #      default GPT reduction function above MaskedTokenLossReduction()
    return forward_results
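
To make the dispatch concrete, here is a self-contained sketch with a toy model that satisfies the BertModel protocol; nothing here is BioNeMo-specific beyond bert_forward_step itself.

import torch
from torch import Tensor

from bionemo.llm.model.biobert.lightning import BertBatch, bert_forward_step


class EchoModel:
    """Toy model for illustration: returns the token IDs masked by the attention mask."""

    def forward(self, input_ids: Tensor, attention_mask: Tensor, packed_seq_params=None) -> Tensor:
        return input_ids.float() * attention_mask


batch: BertBatch = {"text": torch.randint(0, 100, (2, 8)), "attention_mask": torch.ones(2, 8)}
out = bert_forward_step(EchoModel(), batch)  # no "cu_seqlens" key, so no packed_seq_params are built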

biobert_data_step(dataloader_iter)

Preprocesses a batch of data for the GeneFormer model, ingesting a single batch from the dataloader iterator. Only the necessary batch keys are subsetted and passed to the model's forward pass and the loss forward pass, depending on the pipeline stage. TODO: document how parallel_state pipeline stages work.

Parameters:

    dataloader_iter: An iterator over the dataloader. Required.

Returns:

    output (Dict[str, Tensor]): A dictionary of this batch, limited to the relevant keys.

Source code in bionemo/llm/model/biobert/lightning.py
def biobert_data_step(dataloader_iter) -> Dict[str, Tensor]:
    """Preprocesses a batch of data for the GeneFormer model, and ingest a single batch of data from the dataloader iterator.
        only necessary batch keys are subsetted and passed to the model's forward pass, and the loss forward pass, depending on stage.
        TODO document how parallel_state pipeline stages work.

    Args:
        dataloader_iter: An iterator over the dataloader.

    Returns:
        output: A dictionary of this batch limiting to relevant keys.

    """  # noqa: D205
    # Based on: https://github.com/NVIDIA/Megatron-LM/blob/main/pretrain_gpt.py#L87
    # https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L828-L842

    batch = next(dataloader_iter)

    if isinstance(batch, tuple) and len(batch) == 3:
        _batch: dict = batch[0]
    else:
        _batch = batch

    required_keys = set()
    required_keys.add("attention_mask")
    if parallel_state.is_pipeline_first_stage():
        required_keys.add("text")
    if parallel_state.is_pipeline_last_stage():
        required_keys.update(("labels", "loss_mask", "types", "is_random"))
    # if self.get_attention_mask_from_fusion:
    #     required_keys.remove('attention_mask')

    _batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in _batch.items()}
    # slice batch along sequence dimension for context parallelism
    output = get_batch_on_this_context_parallel_rank(_batch)

    return output
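
A hedged sketch of how this is typically driven. It assumes Megatron's parallel_state has already been initialized (normally done by the Megatron strategy) and that a CUDA device is available, since the required keys are moved to the GPU.

import torch

from bionemo.llm.model.biobert.lightning import biobert_data_step

# Illustrative batch with the keys the first and last pipeline stages expect.
dummy_batch = {
    "text": torch.randint(0, 100, (2, 8)),
    "attention_mask": torch.ones(2, 8, dtype=torch.bool),
    "labels": torch.randint(0, 100, (2, 8)),
    "loss_mask": torch.ones(2, 8),
    "types": torch.zeros(2, 8, dtype=torch.long),
    "is_random": torch.zeros(2, dtype=torch.long),
}
output = biobert_data_step(iter([dummy_batch]))  # keys not needed by this pipeline stage come back as None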

biobert_lightning_module(config, optimizer=None, tokenizer=None, data_step=biobert_data_step, forward_step=bert_forward_step, model_transform=None, **model_construct_args)

A PyTorch Lightning module for BioBert-derived models.

This module is designed to be used with the Megatron-LM strategy and NeMo 2.0 conventions. To change your loss, pass in a different config object that returns a different loss reduction class. To change your model and what it outputs, pass in a different config object that returns a different model. Do not modify this function unless you need to change higher-level logic. You may need to modify the various step and forward functions towards the bottom of this file to handle new/different keys in the batch. In the future some of those functions may need to be refactored out into the config object or a different place so that they live closer to the model definition.

Source code in bionemo/llm/model/biobert/lightning.py
def biobert_lightning_module(
    config: BioBertConfig[MegatronBioBertModel, MegatronLossReduction],
    optimizer: Optional[MegatronOptimizerModule] = None,
    tokenizer: Optional[TokenizerSpec | PreTrainedTokenizerBase] = None,
    data_step: DataStep = biobert_data_step,
    forward_step: ForwardStep = bert_forward_step,
    model_transform: Optional[Callable] = None,
    **model_construct_args,
) -> BionemoLightningModule[MegatronBioBertModel, MegatronLossReduction]:
    """A pytorch lightning module for BioBert-derived models.

    This module is designed to be used with the Megatron-LM strategy and nemo 2.0 conventions.
    To change your loss, pass in a different config object that returns a different loss reduction class.
    To change your model and what it outputs, pass in a different config object that returns a different model.
    Do not modify this function unless you need to change higher level logic. You may need to modify the various step
    and forward functions towards the bottom of this file to handle new/different keys in the batch. In the future some
    of those functions may need to be refactored out into the config object or a different place so that they live
    closer to the model definition.
    """
    return BionemoLightningModule(
        config=config,
        optimizer=optimizer if optimizer is not None else default_megatron_optimizer(),
        data_step=data_step,
        forward_step=forward_step,
        tokenizer=tokenizer,
        model_transform=model_transform,
        **model_construct_args,
    )
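
A hedged construction sketch; `my_config` and `my_tokenizer` are placeholders for a concrete BioBertConfig subclass instance and tokenizer, and the returned BionemoLightningModule is then handed to a NeMo trainer (not shown).

from bionemo.llm.model.biobert.lightning import biobert_lightning_module

# my_config / my_tokenizer are hypothetical placeholders for your own config and tokenizer objects.
module = biobert_lightning_module(config=my_config, tokenizer=my_tokenizer)
# optimizer defaults to default_megatron_optimizer(); data_step and forward_step default to the
# biobert_data_step and bert_forward_step functions documented above.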

get_batch_on_this_context_parallel_rank(batch, in_place=True)

Ensures that the input batch is in the right format for context parallel rank.

Modifies the batch data based on the context parallel rank, if the context parallel world size is greater than 1. Otherwise, the batch is returned as-is.

Parameters:

    batch (Dict[str, Tensor]): The input batch data. Required.
    in_place (bool): If true, then the input is mutated and the returned dict is a reference to the input. Otherwise, the input data is always shallow-copied and this copy is modified and returned. Defaults to True.

Returns:

    dict (Dict[str, Tensor]): The modified batch data based on the context parallel rank.

Source code in bionemo/llm/model/biobert/lightning.py
def get_batch_on_this_context_parallel_rank(batch: Dict[str, Tensor], in_place: bool = True) -> Dict[str, Tensor]:
    """Ensures that the input batch is in the right format for context parallel rank.

    Modifies the batch data based on the context parallel rank, if the context parallel world size is greater than 1.
    Otherwise, the batch is returned as-is.


    Args:
        batch: The input batch data.
        in_place: If true, then the input is mutated. The returned dict is a reference to the input.
                  Otherwise, the input data is always shallow-copied and this copy is modified and returned.

    Returns:
        dict: The modified batch data based on the context parallel rank.
    """
    if not in_place:
        batch: dict[str, Tensor] = dict(**batch)

    # Parentheses ensure cp_size binds the world size (an int), not the result of the comparison.
    if (cp_size := parallel_state.get_context_parallel_world_size()) > 1:
        num_valid_tokens_in_ub: Tensor | None = None
        if "loss_mask" in batch and batch["loss_mask"] is not None:
            num_valid_tokens_in_ub = batch["loss_mask"].sum()

        cp_rank = parallel_state.get_context_parallel_rank()
        for key, val in batch.items():
            if val is not None:
                seq_dim = 1 if key != "attention_mask" else 2
                _val = val.view(
                    *val.shape[0:seq_dim],
                    2 * cp_size,
                    val.shape[seq_dim] // (2 * cp_size),
                    *val.shape[(seq_dim + 1) :],
                )
                index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda(
                    non_blocking=True
                )
                _val = _val.index_select(seq_dim, index)
                _val = _val.view(*val.shape[0:seq_dim], -1, *_val.shape[(seq_dim + 2) :])
                batch[key] = _val
        batch["num_valid_tokens_in_ub"] = num_valid_tokens_in_ub  # type: ignore

    return batch
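
To make the chunk selection concrete, here is a standalone illustration (plain torch, no Megatron state) of the same reshape-and-index_select pattern for cp_size=2 and cp_rank=0: each rank keeps chunk cp_rank and its mirror chunk 2*cp_size - cp_rank - 1.

import torch

cp_size, cp_rank, seq_dim = 2, 0, 1
val = torch.arange(8).view(1, 8)  # (batch=1, seq=8): [[0, 1, 2, 3, 4, 5, 6, 7]]

# Split the sequence dimension into 2 * cp_size chunks of equal length.
chunks = val.view(*val.shape[:seq_dim], 2 * cp_size, val.shape[seq_dim] // (2 * cp_size), *val.shape[seq_dim + 1:])
index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1])  # this rank's chunk and its mirror
picked = chunks.index_select(seq_dim, index).view(*val.shape[:seq_dim], -1)
print(picked)  # tensor([[0, 1, 6, 7]]) -- rank 0 of 2 keeps the first and last chunks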

get_packed_seq_params(batch)

Get the packed sequence parameters for the given batch.

This function should only be called if cu_seqlens is defined in the batch.

Parameters:

    batch (SequenceBatch): The input batch to pack. Required.

Returns:

    PackedSeqParams: The packed sequence parameters containing the following attributes:
        - cu_seqlens_q (Tensor): The sequence lengths for query.
        - cu_seqlens_kv (Tensor): The sequence lengths for key and value.
        - max_seqlen_q (Tensor, optional): The maximum sequence length for query.
        - max_seqlen_kv (Tensor, optional): The maximum sequence length for key and value.
        - qkv_format (str): The format of query, key, and value tensors.

Source code in bionemo/llm/model/biobert/lightning.py
def get_packed_seq_params(batch: SequenceBatch) -> PackedSeqParams:
    """Get the packed sequence parameters for the given batch.

    This function should only be called if `cu_seqlens` is defined in the batch.

    Args:
        batch: The input batch to pack.

    Returns:
        PackedSeqParams: The packed sequence parameters containing the following attributes:
            - cu_seqlens_q (Tensor): The sequence lengths for query.
            - cu_seqlens_kv (Tensor): The sequence lengths for key and value.
            - max_seqlen_q (Tensor, optional): The maximum sequence length for query.
            - max_seqlen_kv (Tensor, optional): The maximum sequence length for key and value.
            - qkv_format (str): The format of query, key, and value tensors.

    """
    cu_seqlens = batch["cu_seqlens"].squeeze()  # remove batch size dimension (mbs=1)
    # remove -1 "paddings" added in collate_fn
    # Parentheses ensure cu_seqlens_argmin binds the tensor (or None), not the result of the comparison.
    if (cu_seqlens_argmin := batch.get("cu_seqlens_argmin", None)) is not None:
        # pre-compute cu_seqlens_argmin in dataset class for perf
        cu_seqlens = cu_seqlens[: cu_seqlens_argmin.item()]
    else:
        cu_seqlens = cu_seqlens[: torch.argmin(cu_seqlens)]

    # pre-compute max_seqlens in dataset class for perf
    max_seqlen = batch["max_seqlen"].squeeze() if "max_seqlen" in batch else None

    # these args are passed eventually into TEDotProductAttention.forward()
    return PackedSeqParams(
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_kv=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_kv=max_seqlen,
        qkv_format="thd",
    )
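
A small usage sketch with illustrative values. Note that cu_seqlens keeps a leading micro-batch dimension of 1 and is padded with -1 by the collate function, which is why the function squeezes and then truncates it.

import torch

from bionemo.llm.model.biobert.lightning import SequenceBatch, get_packed_seq_params

batch: SequenceBatch = {
    # micro-batch size 1; two sequences of lengths 100 and 28, padded with -1 by collate_fn
    "cu_seqlens": torch.tensor([[0, 100, 128, -1, -1]], dtype=torch.int32),
    "max_seqlen": torch.tensor([100]),
}
params = get_packed_seq_params(batch)
# params.cu_seqlens_q -> tensor([0, 100, 128], dtype=torch.int32); params.qkv_format == "thd"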