bridge.training.utils.weight_decay_utils

Module Contents

Functions

get_no_weight_decay_cond

Get the no-weight-decay condition function.

API

bridge.training.utils.weight_decay_utils.get_no_weight_decay_cond(
    no_weight_decay_cond_type: str,
    default_skip_embedding_weight_decay: bool,
) -> Callable[[str, torch.Tensor], bool]

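Get the no-weight-decay condition function. The returned callable takes a parameter name and the parameter tensor and returns True when that parameter should be excluded from weight decay. no_weight_decay_cond_type selects which rule is used, and default_skip_embedding_weight_decay indicates whether embedding parameters are skipped by default.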
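The value returned by this function is a predicate over (parameter name, parameter). The sketch below shows what such a predicate might look like and how it could be used to split parameters into decayed and non-decayed groups. It assumes a common convention that is not confirmed by this page: biases, 1-D parameters such as norm weights, and optionally embeddings are excluded from weight decay. The names make_no_weight_decay_cond, split_param_groups and FakeParam are hypothetical. FakeParam stands in for torch.Tensor so the example runs without PyTorch.

```python
from dataclasses import dataclass
from typing import Callable, Iterable, List, Tuple


@dataclass
class FakeParam:
    """Hypothetical stand-in for torch.Tensor; only the shape is inspected."""
    shape: Tuple[int, ...]


# Same shape as the documented return type: (param_name, param) -> bool,
# where True means "do NOT apply weight decay to this parameter".
NoWDCond = Callable[[str, FakeParam], bool]


def make_no_weight_decay_cond(skip_embedding_weight_decay: bool) -> NoWDCond:
    """Hypothetical factory illustrating one typical no-weight-decay rule."""

    def cond(name: str, param: FakeParam) -> bool:
        # Biases and 1-D parameters (e.g. LayerNorm/RMSNorm weights) skip decay.
        if name.endswith(".bias") or len(param.shape) == 1:
            return True
        # Optionally exclude embedding tables as well.
        return skip_embedding_weight_decay and "embedding" in name

    return cond


def split_param_groups(
    named_params: Iterable[Tuple[str, FakeParam]], no_wd_cond: NoWDCond
) -> Tuple[List[str], List[str]]:
    """Partition parameter names into (decayed, not_decayed) lists."""
    decay: List[str] = []
    no_decay: List[str] = []
    for name, param in named_params:
        (no_decay if no_wd_cond(name, param) else decay).append(name)
    return decay, no_decay


params = [
    ("embedding.word_embeddings.weight", FakeParam((32000, 4096))),
    ("decoder.layers.0.self_attention.linear_qkv.weight", FakeParam((12288, 4096))),
    ("decoder.layers.0.self_attention.linear_qkv.bias", FakeParam((12288,))),
    ("decoder.layers.0.input_layernorm.weight", FakeParam((4096,))),
]

decay, no_decay = split_param_groups(params, make_no_weight_decay_cond(True))
print(decay)     # ['decoder.layers.0.self_attention.linear_qkv.weight']
print(no_decay)  # the embedding, bias and layernorm weight names
```

With real PyTorch modules, the same predicate could be applied to model.named_parameters(), and the two resulting groups could be passed to a torch.optim optimizer as separate parameter groups, with weight_decay=0.0 on the non-decayed group.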