core.models.vision.radio

Module Contents

Classes

RADIOViTModel

RADIO ViT vision model.

Functions

fp8_pad_hook

FP8 requires class token length to be a multiple of 16 (for this model).

API

class core.models.vision.radio.RADIOViTModel(
transformer_config: megatron.core.transformer.transformer_config.TransformerConfig,
transformer_layer_spec: megatron.core.transformer.spec_utils.ModuleSpec,
ln_pre_impl: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None,
ln_post_impl: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None,
use_mask_token: bool = False,
add_class_token: bool = True,
class_token_len: int = 8,
patch_dim: int = 16,
img_h: int = 224,
img_w: int = 224,
max_img_h: int = 2048,
max_img_w: int = 2048,
pos_dropout: int = 0,
has_cpe: bool = True,
embedder_bias: bool = False,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
vp_stage: Optional[int] = None,
)

Bases: megatron.core.models.common.vision_module.vision_module.VisionModule

RADIO ViT vision model.

Parameters:
  • transformer_config (TransformerConfig) – Transformer config.

  • transformer_layer_spec (ModuleSpec) – Specifies module to use for transformer layers.

  • ln_pre_impl (ModuleSpec or type) – Specifies the layer norm type to use for ln_pre.

  • ln_post_impl (ModuleSpec or type) – Specifies the layer norm type to use for ln_post.

  • use_mask_token (bool, optional) – Whether to use the RADIO mask token. Defaults to False.

  • add_class_token (bool, optional) – Include a class token. Defaults to True.

  • class_token_len (int) – Class token length. Defaults to 8.

  • patch_dim (int) – Image patch size.

  • img_h (int) – Input image height.

  • img_w (int) – Input image width.

  • max_img_h (int) – Max input image height.

  • max_img_w (int) – Max input image width.

  • pos_dropout (int) – Positional encoding dropout value. Defaults to 0.

  • has_cpe (bool) – Whether to use conditional positional encoding. Defaults to True.

  • embedder_bias (bool) – Whether the embedder linear layer has a bias. Defaults to False.
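
To make the relationship between the image-size and patch-size parameters concrete, the following is a minimal sketch of the patch-grid arithmetic that a ViT-style embedder implies. The helper `patch_grid` is hypothetical and is not part of this module, and the divisibility check is an assumption about what the model expects rather than documented behavior.

```python
from typing import Tuple


def patch_grid(img_h: int, img_w: int, patch_dim: int) -> Tuple[int, int]:
    """Return (rows, cols) of non-overlapping patch_dim x patch_dim patches.

    Hypothetical helper for illustration; not part of core.models.vision.radio.
    """
    # Assumption: the image must tile exactly into patches.
    if img_h % patch_dim or img_w % patch_dim:
        raise ValueError("image height and width must be divisible by patch_dim")
    return img_h // patch_dim, img_w // patch_dim


# Values taken from the signature defaults above.
rows, cols = patch_grid(224, 224, 16)
max_rows, max_cols = patch_grid(2048, 2048, 16)

print(rows, cols, rows * cols)  # 14 14 196
print(max_rows, max_cols)       # 128 128
```

With the defaults, a 224×224 image yields 196 patches, while the 2048×2048 maximum corresponds to a 128×128 grid of patch positions.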

Initialization

set_input_tensor(input_tensor: torch.Tensor) → None

Sets the input tensor for the model.

Parameters:

input_tensor (Tensor) – The input tensor for the model.

forward(
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) → torch.Tensor

Forward function of the RADIO ViT Model. This function passes the input tensors through the embedding layer and then the transformer.

Parameters:
  • x (torch.Tensor) – Input data of shape [batch, img_h, img_w].

  • attention_mask (torch.Tensor with dtype=bool) – Attention mask to use.

Returns:

Output after the final transformer block, of shape [b, s, h].

Return type:

torch.Tensor
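
As a rough guide to the `[b, s, h]` shape, the sketch below estimates the sequence length from the constructor arguments. It assumes that the class tokens are concatenated with the patch tokens and stay in the output, and that no extra padding (for example for FP8) is applied; neither point is confirmed by this page. The function `expected_output_shape` and the `hidden_size` value are hypothetical; in practice the hidden size would come from `transformer_config`.

```python
from typing import Tuple


def expected_output_shape(
    batch: int,
    img_h: int,
    img_w: int,
    hidden_size: int,
    patch_dim: int = 16,
    class_token_len: int = 8,
    add_class_token: bool = True,
) -> Tuple[int, int, int]:
    """Estimate [b, s, h] for a ViT-style encoder (illustrative only)."""
    num_patches = (img_h // patch_dim) * (img_w // patch_dim)
    # Assumption: class tokens are counted in the sequence dimension.
    seq_len = num_patches + (class_token_len if add_class_token else 0)
    return batch, seq_len, hidden_size


# hidden_size=1280 is an arbitrary example value, not a documented default.
print(expected_output_shape(2, 224, 224, hidden_size=1280))  # (2, 204, 1280)
```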

apply_pos_enc(
patches: torch.Tensor,
patch_idxs: Optional[torch.Tensor] = None,
input_size: Optional[Tuple[int, int]] = None,
) → torch.Tensor

Apply positional encoding to patches.

get_pos_enc(
batch_size: int,
patch_idxs: Optional[torch.Tensor] = None,
input_size: Optional[Tuple[int, int]] = None,
) → torch.Tensor

Get the positional encoding for a given input size.

_get_pos_embeddings(
batch_size: int,
input_dims: Tuple[int, int],
)

Get RADIO absolute positional embeddings.
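
This page does not describe how the embeddings are adapted when `input_dims` differs from the maximum grid. One common approach in ViT-style models is to resize a learned position grid by bilinear interpolation, and the sketch below illustrates that idea in plain Python on a grid of scalars. It is an assumption, not the module's actual implementation: `resize_bilinear` is a hypothetical helper, and a real model would operate on embedding vectors with a tensor library.

```python
from typing import List, Tuple

Grid = List[List[float]]


def resize_bilinear(grid: Grid, out_h: int, out_w: int) -> Grid:
    """Resize a 2D grid of scalars by bilinear interpolation (half-pixel centers)."""
    in_h, in_w = len(grid), len(grid[0])

    def source_coords(out_idx: int, in_size: int, out_size: int) -> Tuple[int, int, float]:
        # Map the center of an output cell into input coordinates, then clamp.
        pos = (out_idx + 0.5) * in_size / out_size - 0.5
        pos = min(max(pos, 0.0), in_size - 1.0)
        lo = int(pos)                  # floor, since pos >= 0
        hi = min(lo + 1, in_size - 1)  # neighbor, clamped at the border
        return lo, hi, pos - lo

    out: Grid = []
    for i in range(out_h):
        y0, y1, wy = source_coords(i, in_h, out_h)
        row = []
        for j in range(out_w):
            x0, x1, wx = source_coords(j, in_w, out_w)
            top = grid[y0][x0] * (1 - wx) + grid[y0][x1] * wx
            bottom = grid[y1][x0] * (1 - wx) + grid[y1][x1] * wx
            row.append(top * (1 - wy) + bottom * wy)
        out.append(row)
    return out


# A 4x4 grid whose value equals its row index, resized to 2x2.
pos_grid = [[float(r)] * 4 for r in range(4)]
small = resize_bilinear(pos_grid, 2, 2)
print(small)  # [[0.5, 0.5], [2.5, 2.5]]
```

Resizing to the original size returns the grid unchanged, which is a quick sanity check for any interpolation routine of this kind.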

core.models.vision.radio.fp8_pad_hook(
module,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)

FP8 requires class token length to be a multiple of 16 (for this model).

Original model checkpoint may not be padded for FP8 so pad it here.
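
The padding arithmetic described above can be sketched as follows. This is a plain-Python illustration on nested lists, not the hook's actual code: `padded_len` and `pad_class_token` are hypothetical helpers, and both the fill value (zeros) and the position of the padding (appended after the original rows) are assumptions.

```python
from typing import List


def padded_len(n: int, multiple: int = 16) -> int:
    """Round n up to the next multiple of `multiple`."""
    return -(-n // multiple) * multiple


def pad_class_token(
    token_rows: List[List[float]], multiple: int = 16, fill: float = 0.0
) -> List[List[float]]:
    """Pad a [class_token_len, hidden] table so its length is a multiple of 16.

    Assumption: padding rows are zeros appended after the original rows.
    """
    width = len(token_rows[0])
    missing = padded_len(len(token_rows), multiple) - len(token_rows)
    return token_rows + [[fill] * width for _ in range(missing)]


# An unpadded checkpoint with class_token_len=8 and a toy hidden size of 2.
checkpoint_token = [[1.0, 2.0] for _ in range(8)]
padded = pad_class_token(checkpoint_token)
print(len(checkpoint_token), len(padded))  # 8 16
```

The argument list of `fp8_pad_hook` has the shape of a PyTorch `load_state_dict` pre-hook that also receives the module, so a transformation like this would presumably be applied to the relevant `state_dict` entry before the weights are loaded.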