Testing callbacks

MetadataSaveCallback

Bases: Callback

A callback that saves metadata about the current training at the second validation epoch.

Source code in bionemo/testing/testing_callbacks.py
class MetadataSaveCallback(Callback):
    """A callback that saves metadata about the current training at the second validation epoch."""

    def __init__(
        self, metadata_path: pathlib.Path, metrics_getter: dict[str, Callable[[pl.Trainer, pl.LightningModule], Any]]
    ):
        """Initialises callback with path and called information.

        Args:
            metadata_path (pathlib.Path): Path where the metadata will be saved.
            metrics_getter (dict[str, Callable[[pl.Trainer, pl.LightningModule], Any]]): A dictionary of metadata keys and their corresponding functions.

        See Also: bionemo.testing.stop_and_go
        """
        self.metadata_path = metadata_path
        self.pickle_file_path = os.path.join(self.metadata_path, "checkpoints/metadata.pkl")
        self.called = False  # indicates if callback was already called
        self.metrics_getter = metrics_getter

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str):
        """Set up the testing callbacks and removes lingering metadata."""
        super().setup(trainer, pl_module, stage)
        if trainer.is_global_zero and os.path.exists(self.pickle_file_path):
            os.remove(self.pickle_file_path)

    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        """Stores requisite metadata at the end of the first non-warmup validation epoch.

        Executes on the second validation epoch -only- due to how warmups are handled. May not work as intended in the
        absence of a warmup.

        Args:
            trainer (Trainer): The Lightning Trainer object.
            pl_module (LightningModule): The LightningModule being trained.

        Notes:
            - If `called` is True and `trainer.is_global_zero` is True, the function saves metadata to compare after resuming with a checkpoint.
            - The metadata is obtained using the `metrics_getter` dict and results are saved as a pickle file.

        """
        if self.called and trainer.is_global_zero:
            # save metadata to compare to after resuming with checkpoint
            metadata = {}
            for metadata_key, func in self.metrics_getter.items():
                metadata_value = func(trainer, pl_module)
                metadata[metadata_key] = metadata_value
            # prepare paths for metadata save
            pickle_file_path = self.pickle_file_path
            os.makedirs(os.path.dirname(pickle_file_path), exist_ok=True)
            with open(pickle_file_path, "wb") as metadata_file:
                pickle.dump(metadata, metadata_file)
            # check that pickle file was saved correctly
            assert os.path.isfile(pickle_file_path), f"No file found at {pickle_file_path}"
        else:
            # first time this callback is called is before the ModelCheckpoint callback
            # since that one is always executed last. Therefore, we skip the first validation
            # round and only save metadata at the second validation round
            self.called = True
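
The sketch below is illustrative only: it shows one way this callback might be wired into a Lightning Trainer together with a metrics_getter dictionary. The metric keys, getter lambdas, path, and Trainer settings are assumptions made for the example, not values taken from bionemo.testing.stop_and_go.

# Illustrative usage sketch; the metric keys and getters below are hypothetical.
import pathlib
import pytorch_lightning as pl

from bionemo.testing.testing_callbacks import MetadataSaveCallback

metadata_path = pathlib.Path("/tmp/stop_and_go_test")
metrics_getter = {
    # Each value is a Callable[[pl.Trainer, pl.LightningModule], Any].
    "global_step": lambda trainer, pl_module: trainer.global_step,
    "learning_rate": lambda trainer, pl_module: trainer.optimizers[0].param_groups[0]["lr"],
}

trainer = pl.Trainer(
    max_steps=8,
    val_check_interval=2,  # ensure at least two validation rounds so the metadata is actually written
    callbacks=[MetadataSaveCallback(metadata_path, metrics_getter)],
)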

__init__(metadata_path, metrics_getter)

Initialises the callback with the metadata path and metric getter functions.

Parameters:

metadata_path (Path): Path where the metadata will be saved. Required.
metrics_getter (dict[str, Callable[[Trainer, LightningModule], Any]]): A dictionary of metadata keys and their corresponding functions. Required.

See Also: bionemo.testing.stop_and_go

Source code in bionemo/testing/testing_callbacks.py
def __init__(
    self, metadata_path: pathlib.Path, metrics_getter: dict[str, Callable[[pl.Trainer, pl.LightningModule], Any]]
):
    """Initialises callback with path and called information.

    Args:
        metadata_path (pathlib.Path): Path where the metadata will be saved.
        metrics_getter (dict[str, Callable[[pl.Trainer, pl.LightningModule], Any]]): A dictionary of metadata keys and their corresponding functions.

    See Also: bionemo.testing.stop_and_go
    """
    self.metadata_path = metadata_path
    self.pickle_file_path = os.path.join(self.metadata_path, "checkpoints/metadata.pkl")
    self.called = False  # indicates if callback was already called
    self.metrics_getter = metrics_getter

on_validation_epoch_end(trainer, pl_module)

Stores requisite metadata at the end of the first non-warmup validation epoch.

Executes on the second validation epoch only, due to how warmups are handled. May not work as intended in the absence of a warmup.

Parameters:

trainer (Trainer): The Lightning Trainer object. Required.
pl_module (LightningModule): The LightningModule being trained. Required.

Notes
  • If called is True and trainer.is_global_zero is True, the function saves metadata to compare after resuming with a checkpoint.
  • The metadata is obtained using the metrics_getter dict and results are saved as a pickle file.
Source code in bionemo/testing/testing_callbacks.py
def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
    """Stores requisite metadata at the end of the first non-warmup validation epoch.

    Executes on the second validation epoch -only- due to how warmups are handled. May not work as intended in the
    absence of a warmup.

    Args:
        trainer (Trainer): The Lightning Trainer object.
        pl_module (LightningModule): The LightningModule being trained.

    Notes:
        - If `called` is True and `trainer.is_global_zero` is True, the function saves metadata to compare after resuming with a checkpoint.
        - The metadata is obtained using the `metrics_getter` dict and results are saved as a pickle file.

    """
    if self.called and trainer.is_global_zero:
        # save metadata to compare to after resuming with checkpoint
        metadata = {}
        for metadata_key, func in self.metrics_getter.items():
            metadata_value = func(trainer, pl_module)
            metadata[metadata_key] = metadata_value
        # prepare paths for metadata save
        pickle_file_path = self.pickle_file_path
        os.makedirs(os.path.dirname(pickle_file_path), exist_ok=True)
        with open(pickle_file_path, "wb") as metadata_file:
            pickle.dump(metadata, metadata_file)
        # check that pickle file was saved correctly
        assert os.path.isfile(pickle_file_path), f"No file found at {pickle_file_path}"
    else:
        # first time this callback is called is before the ModelCheckpoint callback
        # since that one is always executed last. Therefore, we skip the first validation
        # round and only save metadata at the second validation round
        self.called = True

setup(trainer, pl_module, stage)

Sets up the testing callbacks and removes lingering metadata.

Source code in bionemo/testing/testing_callbacks.py
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str):
    """Set up the testing callbacks and removes lingering metadata."""
    super().setup(trainer, pl_module, stage)
    if trainer.is_global_zero and os.path.exists(self.pickle_file_path):
        os.remove(self.pickle_file_path)

RaiseAfterMetadataCallback

Bases: Callback

A callback that raises a StopAndGoException if the metadata from the MetadataSaveCallback was saved successfully beforehand.

Use this callback for pytest-based Stop-and-Go tests.

Source code in bionemo/testing/testing_callbacks.py
class RaiseAfterMetadataCallback(Callback):
    """A callback that raises a StopAndGoException kills it if the metadata from the MetadataSaveCallback was saved successfully beforehand.

    Use this callback for pytest based Stop and go tests.
    """

    def __init__(self, metadata_path: pathlib.Path):  # noqa: D107
        self.metadata_path = metadata_path

    def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int):
        """PTL callback that raises a StopAndGoException if metadata exists."""
        pickle_file_path = os.path.join(self.metadata_path, "checkpoints/metadata.pkl")
        if os.path.exists(pickle_file_path):
            # Metadata has been saved, so end this training phase early.
            raise StopAndGoException("Terminating early, checkpoint exists.")
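
As a rough sketch of how a pytest-based Stop-and-Go test could use this callback (the model, data module, fixtures, and import location of StopAndGoException below are placeholders and assumptions), the first training phase is expected to end with StopAndGoException once the metadata file has been written:

# Hypothetical "stop" phase of a stop-and-go test; MyModule, MyDataModule, and the metrics_getter fixture are placeholders.
import pytest
import pytorch_lightning as pl

from bionemo.testing.stop_and_go import StopAndGoException  # assumed location of the exception
from bionemo.testing.testing_callbacks import MetadataSaveCallback, RaiseAfterMetadataCallback

def test_stop_phase(tmp_path, metrics_getter):
    trainer = pl.Trainer(
        max_steps=8,
        val_check_interval=2,
        callbacks=[
            MetadataSaveCallback(tmp_path, metrics_getter),
            RaiseAfterMetadataCallback(tmp_path),
            pl.callbacks.ModelCheckpoint(dirpath=tmp_path / "checkpoints", save_last=True),
        ],
    )
    with pytest.raises(StopAndGoException):
        trainer.fit(MyModule(), datamodule=MyDataModule())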

on_train_batch_start(trainer, pl_module, batch, batch_idx)

PTL callback that raises a StopAndGoException if metadata exists.

Source code in bionemo/testing/testing_callbacks.py
def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int):
    """PTL callback that raises a StopAndGoException if metadata exists."""
    pickle_file_path = os.path.join(self.metadata_path, "checkpoints/metadata.pkl")
    if os.path.exists(pickle_file_path):
        # Metadata has been saved, so end this training phase early.
        raise StopAndGoException("Terminating early, checkpoint exists.")

TestCheckpointIntegrityCallback

Bases: Callback

Callback that tests if current metrics match those saved in the associated metadata file.

This callback expects to be invoked only after resuming a model that used the MetadataSaveCallback. When training begins, it checks the value of each metric and compares it to the metadata stored in the metadata pickle file. Any deviation is assumed to indicate a failure in checkpoint restoration.

Source code in bionemo/testing/testing_callbacks.py
class TestCheckpointIntegrityCallback(Callback):
    """Callback that tests if current metrics match those saved in the associated metadata file.

    This callback expects to be invoked _only_ after resuming a model that used the MetadataSaveCallback. When training begins, it checks the value of each metric and compares it to the metadata stored in the metadata pickle file. Any deviation is assumed to indicate a failure in checkpoint restoration.
    """

    def __init__(
        self, metadata_path: pathlib.Path, metrics_getter: dict[str, Callable[[pl.Trainer, pl.LightningModule], Any]]
    ):
        """Initialises callback with path and called information.

        Args:
            metadata_path (pathlib.Path): Path where the metadata will be saved.
            metrics_getter (dict[str, Callable[[pl.Trainer, pl.LightningModule], Any]]): A dictionary of metadata keys and their corresponding functions. Must be a subset of the dictionary passed to MetadataSaveCallback.
        """
        self.metadata_path = metadata_path
        self.metrics_getter = metrics_getter

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
        """Loads associated metadata and compares with current metrics."""
        pickle_file_path = os.path.join(self.metadata_path, "checkpoints/metadata.pkl")
        # check that pickle file exists
        assert os.path.isfile(pickle_file_path), f"No file found at {pickle_file_path}"
        with open(pickle_file_path, "rb") as metadata_file:
            metadata_dict = pickle.load(metadata_file)
        current_metadata = {}
        for metadata_key, func in self.metrics_getter.items():
            expected_value = metadata_dict[metadata_key]
            current_value = func(trainer, pl_module)
            current_metadata[metadata_key] = current_value

        # TODO (SKH): Ideally this would collect _all_ failures instead of failing on the first one.
        for metadata_key in current_metadata:
            expected_value = metadata_dict[metadata_key]
            current_value = current_metadata[metadata_key]
            assert (
                expected_value == current_value
            ), f"Value mismatch for key {metadata_key}: stored_value={expected_value}, current_value={current_value}"
        # Cleanup
        os.remove(pickle_file_path)
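
A matching "go" phase might look like the sketch below (again illustrative; the checkpoint filename and the placeholder model and data module are assumptions): it resumes from the checkpoint written during the stop phase and reuses the same metrics_getter so the callback can compare the restored state against the saved metadata.

# Hypothetical resume phase of a stop-and-go test; MyModule and MyDataModule are placeholders.
import pytorch_lightning as pl

from bionemo.testing.testing_callbacks import TestCheckpointIntegrityCallback

def test_go_phase(tmp_path, metrics_getter):
    trainer = pl.Trainer(
        max_steps=8,
        callbacks=[TestCheckpointIntegrityCallback(tmp_path, metrics_getter)],
    )
    # The exact checkpoint filename depends on the ModelCheckpoint configuration used in the stop phase.
    trainer.fit(MyModule(), datamodule=MyDataModule(), ckpt_path=tmp_path / "checkpoints" / "last.ckpt")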

__init__(metadata_path, metrics_getter)

Initialises the callback with the metadata path and metric getter functions.

Parameters:

metadata_path (Path): Path where the metadata will be saved. Required.
metrics_getter (dict[str, Callable[[Trainer, LightningModule], Any]]): A dictionary of metadata keys and their corresponding functions. Must be a subset of the dictionary passed to MetadataSaveCallback. Required.

Source code in bionemo/testing/testing_callbacks.py
def __init__(
    self, metadata_path: pathlib.Path, metrics_getter: dict[str, Callable[[pl.Trainer, pl.LightningModule], Any]]
):
    """Initialises callback with path and called information.

    Args:
        metadata_path (pathlib.Path): Path where the metadata will be saved.
        metrics_getter (dict[str, Callable[[pl.Trainer, pl.LightningModule], Any]]): A dictionary of metadata keys and their corresponding functions. Must be a subset of the dictionary passed to MetadataSaveCallback.
    """
    self.metadata_path = metadata_path
    self.metrics_getter = metrics_getter

on_train_start(trainer, pl_module)

Loads associated metadata and compares with current metrics.

Source code in bionemo/testing/testing_callbacks.py
def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
    """Loads associated metadata and compares with current metrics."""
    pickle_file_path = os.path.join(self.metadata_path, "checkpoints/metadata.pkl")
    # check that pickle file exists
    assert os.path.isfile(pickle_file_path), f"No file found at {pickle_file_path}"
    with open(pickle_file_path, "rb") as metadata_file:
        metadata_dict = pickle.load(metadata_file)
    current_metadata = {}
    for metadata_key, func in self.metrics_getter.items():
        expected_value = metadata_dict[metadata_key]
        current_value = func(trainer, pl_module)
        current_metadata[metadata_key] = current_value

    # TODO (SKH): Ideally this would collect _all_ failures instead of failing on the first one.
    for metadata_key in current_metadata:
        expected_value = metadata_dict[metadata_key]
        current_value = current_metadata[metadata_key]
        assert (
            expected_value == current_value
        ), f"Value mismatch for key {metadata_key}: stored_value={expected_value}, current_value={current_value}"
    # Cleanup
    os.remove(pickle_file_path)

compute_biobert_loss_singlegpu(model, dl)

Computes the loss for BioBert models on a single GPU.

This will not function in multi-GPU settings or with models that do not conform to BioBert.

Parameters:

model (Module): The Biobert model. Required.
dl (DataLoader): The data loader. Required.

Returns:

float: The mean loss.

See Also: BioBertModel

Source code in bionemo/testing/testing_callbacks.py
def compute_biobert_loss_singlegpu(model, dl: DataLoader):
    """Computes the loss for BioBert models on a single GPU.

    This will not function in multi-gpu settings nor with models that do not conform to BioBert.

    Args:
        model (torch.nn.Module): The Biobert model.
        dl (torch.utils.data.DataLoader): The data loader.

    Returns:
        float: The mean loss.

    See Also:
    - :class: BioBertModel
    """
    n, loss = 0, 0.0
    model.eval()
    # Use the model's data_step to collate a batch from the dataloader iterator.
    batch = model.data_step(iter(dl))
    result = model(
        input_ids=batch["text"].cuda(),  # 'tokens' also a valid input for MockGPTDataModule
        attention_mask=batch["attention_mask"].cuda(),
    )
    loss_mask = batch["loss_mask"].cuda()
    # NOTE: the "token_logits" key is not guaranteed to be present for every model output.
    logits = result["token_logits"]
    target = batch["labels"].cuda()
    loss += F.cross_entropy(logits[loss_mask].float(), target[loss_mask], reduction="sum")
    n += loss_mask.sum()

    mean_loss: float = (loss / n).detach().cpu().numpy().item()
    model.train()
    return mean_loss
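
For example, this function can serve as one entry in a metrics_getter dictionary so that the single-GPU loss is saved before stopping and re-checked after resuming. The use of trainer.datamodule.val_dataloader() below is an assumption about how the trainer is configured, not something prescribed by the library.

# Hypothetical metrics_getter entry that records the single-GPU BioBert loss.
metrics_getter = {
    "val_loss": lambda trainer, pl_module: compute_biobert_loss_singlegpu(
        pl_module, trainer.datamodule.val_dataloader()  # assumes a datamodule exposing val_dataloader()
    ),
}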