Skip to content

Dtypes

get_autocast_dtype(precision)

Returns the torch dtype corresponding to the given precision.

Parameters:

Name Type Description Default
precision PrecisionTypes

The precision type.

required

Returns:

Type Description
dtype

torch.dtype: The torch dtype corresponding to the given precision.

Raises:

Type Description
ValueError

If the precision is not supported.

Source code in bionemo/core/utils/dtypes.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
def get_autocast_dtype(precision: PrecisionTypes) -> torch.dtype:
    """Map a precision specifier onto the torch dtype it stands for.

    Args:
        precision: The precision specifier. Recognized values are the strings
            "fp16", "bf16", "fp32", "16-mixed", "fp16-mixed", "bf16-mixed" and
            "fp32-mixed", and the integers 16 and 32.

    Returns:
        torch.dtype: ``torch.float16``, ``torch.bfloat16`` or ``torch.float32``,
        depending on the specifier.

    Raises:
        ValueError: If ``precision`` equals none of the recognized values.
    """
    # TODO move this to a utilities folder, or find/import the function that does this in NeMo
    # Ordered (specifier, dtype) pairs. They are scanned with ``==`` and the
    # first match wins, so comparison semantics match a plain if/elif ladder.
    known_precisions = (
        ("fp16", torch.float16),
        ("bf16", torch.bfloat16),
        ("fp32", torch.float32),
        ("16-mixed", torch.float16),
        ("fp16-mixed", torch.float16),
        ("bf16-mixed", torch.bfloat16),
        ("fp32-mixed", torch.float32),
        (16, torch.float16),
        (32, torch.float32),
    )
    for candidate, resolved_dtype in known_precisions:
        if precision == candidate:
            return resolved_dtype
    raise ValueError(f"Unsupported precision: {precision}")