bridge.training.utils.wandb_utils

Module Contents

Functions

on_save_checkpoint_success

Callback executed after a checkpoint is successfully saved.

on_load_checkpoint_success

Callback executed after a checkpoint is successfully loaded.

_get_wandb_artifact_tracker_filename

The wandb artifact tracker file records the wandb entity and project of the latest artifact.

_get_artifact_name_and_version

API

bridge.training.utils.wandb_utils.on_save_checkpoint_success(
checkpoint_path: str,
save_dir: str,
iteration: int,
wandb_writer: Optional[Any],
) -> None

Callback executed after a checkpoint is successfully saved.

If a wandb writer is provided, logs the checkpoint as a wandb artifact that references the local file path. It also saves a tracker file containing the wandb entity/project, which on_load_checkpoint_success reads when the checkpoint is later loaded.

Parameters:
  • checkpoint_path – The path to the specific checkpoint file/directory saved.

  • save_dir – The base directory where checkpoints are being saved.

  • iteration – The training iteration at which the checkpoint was saved.

  • wandb_writer – The wandb writer instance (e.g., wandb.run). If None, this function is a no-op.
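A minimal sketch of the save flow described above, runnable without wandb installed. StubRun and StubArtifact are hypothetical stand-ins for wandb.run and wandb.Artifact. The artifact naming rule (base directory name as the artifact name, checkpoint directory name as the version alias), the tracker filename latest_wandb_artifact_path.txt, and the plain-text "entity/project" tracker format are all assumptions made for illustration, not the documented behavior of on_save_checkpoint_success.

```python
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional

# Hypothetical tracker filename; the real name is not documented here.
TRACKER_NAME = "latest_wandb_artifact_path.txt"


@dataclass
class StubArtifact:
    """Hypothetical stand-in for wandb.Artifact that only records references."""
    name: str
    type: str
    metadata: dict = field(default_factory=dict)
    references: list = field(default_factory=list)

    def add_reference(self, uri: str) -> None:
        self.references.append(uri)


@dataclass
class StubRun:
    """Hypothetical stand-in for wandb.run that keeps logged artifacts in memory."""
    entity: str
    project: str
    logged: list = field(default_factory=list)

    def log_artifact(self, artifact: StubArtifact, aliases: Optional[list] = None) -> None:
        self.logged.append((artifact, aliases or []))


def save_callback_sketch(checkpoint_path: str, save_dir: str, iteration: int,
                         wandb_writer: Optional[Any]) -> None:
    if wandb_writer is None:
        return  # documented behavior: no writer means no-op
    # Assumption: artifact name = base directory name, version alias = checkpoint directory name.
    name, version = Path(save_dir).name, Path(checkpoint_path).name
    artifact = StubArtifact(name, type="model", metadata={"iteration": iteration})
    # Reference the checkpoint by its local path instead of copying its contents.
    artifact.add_reference(f"file://{checkpoint_path}")
    wandb_writer.log_artifact(artifact, aliases=[version])
    # Assumption: the tracker stores "entity/project" as plain text.
    (Path(save_dir) / TRACKER_NAME).write_text(f"{wandb_writer.entity}/{wandb_writer.project}")


# Usage: simulate one successful save with the stub run.
with tempfile.TemporaryDirectory() as tmp:
    save_dir = Path(tmp) / "my_run"
    ckpt = save_dir / "iter_0000100"
    ckpt.mkdir(parents=True)
    run = StubRun(entity="my-team", project="my-project")
    save_callback_sketch(str(ckpt), str(save_dir), 100, run)
    tracker_text = (save_dir / TRACKER_NAME).read_text()

artifact, aliases = run.logged[0]
print(artifact.name, aliases)  # my_run ['iter_0000100']
print(tracker_text)            # my-team/my-project
```

Referencing the checkpoint by path keeps the artifact lightweight, while the tracker file is what lets a later load recover which entity/project the artifact was logged under.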

bridge.training.utils.wandb_utils.on_load_checkpoint_success(
checkpoint_path: str,
load_dir: str,
wandb_writer: Optional[Any],
) -> None

Callback executed after a checkpoint is successfully loaded.

If a wandb writer is provided, attempts to mark the corresponding wandb artifact as used. It reads the entity/project from the tracker file saved during the checkpoint save process.

Parameters:
  • checkpoint_path – The path to the specific checkpoint file/directory loaded.

  • load_dir – The base directory from which the checkpoint was loaded.

  • wandb_writer – The wandb writer instance (e.g., wandb.run). If None, or if artifact tracking fails, this function is a no-op.
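A minimal sketch of the load flow, again runnable without wandb. StubRun is a hypothetical stand-in for wandb.run, and the tracker filename, its "entity/project" content, and the name/version rule are illustrative assumptions. It shows how an entity/project prefix read from the tracker file can qualify a wandb artifact reference of the form entity/project/name:version before use_artifact is called, and how any failure is caught so that loading continues.

```python
import tempfile
from pathlib import Path
from typing import Any, Optional

# Hypothetical tracker filename, assumed to match what the save side wrote.
TRACKER_NAME = "latest_wandb_artifact_path.txt"


class StubRun:
    """Hypothetical stand-in for wandb.run that records use_artifact calls."""

    def __init__(self) -> None:
        self.used: list = []

    def use_artifact(self, artifact_or_name: str) -> None:
        self.used.append(artifact_or_name)


def load_callback_sketch(checkpoint_path: str, load_dir: str,
                         wandb_writer: Optional[Any]) -> None:
    if wandb_writer is None:
        return  # documented behavior: no writer means no-op
    try:
        # Assumption: same name/version convention as on the save side.
        name, version = Path(load_dir).name, Path(checkpoint_path).name
        tracker = Path(load_dir) / TRACKER_NAME
        # Qualify with "entity/project/" only when the tracker file exists.
        prefix = f"{tracker.read_text().strip()}/" if tracker.is_file() else ""
        wandb_writer.use_artifact(f"{prefix}{name}:{version}")
    except Exception as exc:
        # Tracking failures are reported but never interrupt checkpoint loading.
        print(f"could not mark {checkpoint_path} as used in wandb: {exc}")


with tempfile.TemporaryDirectory() as tmp:
    load_dir = Path(tmp) / "my_run"
    ckpt = load_dir / "iter_0000100"
    ckpt.mkdir(parents=True)
    (load_dir / TRACKER_NAME).write_text("my-team/my-project\n")
    run = StubRun()
    load_callback_sketch(str(ckpt), str(load_dir), run)
    # A writer without use_artifact takes the error path instead of raising.
    load_callback_sketch(str(ckpt), str(load_dir), object())

print(run.used)  # ['my-team/my-project/my_run:iter_0000100']
```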

bridge.training.utils.wandb_utils._get_wandb_artifact_tracker_filename(save_dir: str) -> pathlib.Path

The wandb artifact tracker file records the wandb entity and project of the latest artifact.

bridge.training.utils.wandb_utils._get_artifact_name_and_version(
save_dir: pathlib.Path,
checkpoint_path: pathlib.Path,
) -> tuple[str, str]
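This entry carries no description. As an illustration only, assuming the artifact name is taken from the base checkpoint directory and the version from the individual checkpoint directory, a helper with this signature could look like the following. The function name artifact_name_and_version_sketch and the naming rule are hypothetical.

```python
from pathlib import Path


def artifact_name_and_version_sketch(save_dir: Path, checkpoint_path: Path) -> tuple[str, str]:
    """Hypothetical rule: base directory name -> artifact name, checkpoint directory name -> version."""
    return save_dir.name, checkpoint_path.name


name, version = artifact_name_and_version_sketch(
    Path("/checkpoints/llama_run"),
    Path("/checkpoints/llama_run/iter_0000500"),
)
# The pair maps onto a wandb artifact reference of the form name:version.
print(f"{name}:{version}")  # llama_run:iter_0000500
```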