core.models.retro.model
Retro Model.
Module Contents
Classes
RetroModel | Retro Model.
API
- class core.models.retro.model.RetroModel(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- transformer_layer_spec: megatron.core.transformer.spec_utils.ModuleSpec,
- vocab_size: int,
- max_sequence_length: int,
- pre_process: bool = True,
- post_process: bool = True,
- fp16_lm_cross_entropy: bool = False,
- parallel_output: bool = True,
- share_embeddings_and_output_weights: bool = False,
- position_embedding_type: Literal['learned_absolute', 'rope', 'mrope', 'yarn', 'none'] = 'learned_absolute',
- rotary_percent: float = 1.0,
- rotary_base: int = 10000,
- rope_scaling: bool = False,
- rope_scaling_factor: float = 8.0,
- scatter_embedding_sequence_parallel: bool = True,
- seq_len_interpolation_factor: Optional[float] = None,
- mtp_block_spec: Optional[megatron.core.transformer.spec_utils.ModuleSpec] = None,
- pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
- vp_stage: Optional[int] = None,
)
Bases: megatron.core.models.gpt.GPTModel
Retro Model.
A Retro model mostly reuses the GPTModel interface. The only difference is the embedding of the ‘context’, which Retro uses to process neighbor tokens. This embedded context is then forwarded to the transformer block.
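To make the "reuse GPTModel, add a context embedding" idea concrete, here is a toy pure-Python sketch. `ToyGPT` and `ToyRetro` are hypothetical classes, not the Megatron-Core API, and the sketch assumes the neighbor (context) tokens are embedded with the same embedding table as the input tokens. The subclass changes nothing except embedding the context and passing it to the block.

```python
from typing import Dict, List, Optional

Vectors = List[List[float]]


class ToyGPT:
    """Hypothetical GPT-like model; not the Megatron-Core GPTModel API."""

    def __init__(self, table: Dict[int, List[float]]):
        # Token-embedding lookup table: token ID -> embedding vector.
        self.table = table

    def embed(self, ids: List[int]) -> Vectors:
        return [self.table[i] for i in ids]

    def block(self, hidden: Vectors, context: Optional[Vectors] = None) -> Dict[str, Optional[Vectors]]:
        # A real Retro decoder would attend to `context`; this toy block
        # simply returns its inputs so they can be inspected.
        return {"hidden": hidden, "context": context}

    def forward(self, input_ids: List[int]) -> Dict[str, Optional[Vectors]]:
        return self.block(self.embed(input_ids))


class ToyRetro(ToyGPT):
    """Reuses ToyGPT unchanged, except that neighbor (context) tokens are also embedded."""

    def forward(self, input_ids: List[int], context_input_ids: Optional[List[int]] = None) -> Dict[str, Optional[Vectors]]:
        # Assumption: the context is embedded with the same table as the input tokens.
        context = self.embed(context_input_ids) if context_input_ids is not None else None
        # The embedded context is forwarded to the block alongside the hidden states.
        return self.block(self.embed(input_ids), context=context)


model = ToyRetro({0: [0.0, 1.0], 1: [1.0, 0.0], 2: [0.5, 0.5]})
out = model.forward([0, 1], context_input_ids=[2, 2])
print(out["context"])  # [[0.5, 0.5], [0.5, 0.5]]
```

Because `ToyRetro` only overrides `forward`, every other behavior is inherited, which mirrors the "mostly reuses the GPTModel interface" design described above.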
Initialization
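The constructor lists `rotary_percent`, `rotary_base`, and `seq_len_interpolation_factor` without describing them. The following pure-Python sketch shows standard rotary position embedding (RoPE) frequency computation, assuming these parameters carry their usual meaning: the fraction of each head's channels that is rotated, the frequency base, and a divisor applied to positions (position interpolation). `kv_channels`, `rope_inv_freq`, and `rope_angles` are hypothetical names, and `rope_scaling`/`rope_scaling_factor` are not modeled.

```python
from typing import List, Optional


def rope_inv_freq(kv_channels: int, rotary_percent: float = 1.0, rotary_base: int = 10000) -> List[float]:
    """Inverse frequencies for standard RoPE, rotating only part of each head's channels."""
    # Number of channels that receive rotary embeddings.
    dim = int(kv_channels * rotary_percent)
    # One inverse frequency per channel pair: rotary_base ** (-i / dim) for even i.
    return [1.0 / (rotary_base ** (i / dim)) for i in range(0, dim, 2)]


def rope_angles(position: int, inv_freq: List[float], seq_len_interpolation_factor: Optional[float] = None) -> List[float]:
    """Rotation angle for each channel pair at a given token position."""
    pos = float(position)
    if seq_len_interpolation_factor is not None:
        # Position interpolation: compress positions by the given factor.
        pos /= seq_len_interpolation_factor
    return [pos * f for f in inv_freq]


inv = rope_inv_freq(kv_channels=64, rotary_percent=0.5)
print(len(inv))                      # 16 frequencies for 32 rotated channels
print(rope_angles(8, inv, 2.0)[0])  # 4.0
```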
- forward(
- input_ids: torch.Tensor,
- position_ids: torch.Tensor,
- attention_mask: torch.Tensor,
- context_input_ids: torch.Tensor = None,
- context_position_ids: torch.Tensor = None,
- context_mask: torch.Tensor = None,
- decoder_input: torch.Tensor = None,
- labels: torch.Tensor = None,
- inference_context: megatron.core.inference.contexts.BaseInferenceContext = None,
- *,
- inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
)
RetroModel forward method.
Forward input tokens and mask, along with neighbor tokens and mask, through the Retro model.
- Parameters:
input_ids (Tensor) – Input token IDs.
position_ids (Tensor) – Input position IDs.
attention_mask (Tensor) – Input attention mask.
context_input_ids (Tensor) – Context (i.e., neighbor) token IDs.
context_position_ids (Tensor) – Context (i.e., neighbor) position IDs.
context_mask (Tensor) – Context (i.e., neighbor) attention mask.
decoder_input (Tensor) – When using pipeline parallelism, input_ids and position_ids are used only on the first stage; on all other stages, decoder_input is provided via communication from the previous stage.
labels (Tensor) – The labels of dimension [batch size, seq length].
inference_context (BaseInferenceContext) – Inference context.
- Returns:
Output tensor of the forward pass.
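The `decoder_input` rule above can be sketched as a small dispatch function. This is a hypothetical pure-Python illustration of the behavior as the parameter description states it, not Megatron-Core code: `stage_input`, `toy_embed`, and the `is_first_stage` flag are assumptions (the flag plays the role the first pipeline stage plays in the description).

```python
from typing import Callable, List, Optional, Sequence

Hidden = List[List[float]]


def stage_input(
    is_first_stage: bool,
    input_ids: Optional[Sequence[int]],
    position_ids: Optional[Sequence[int]],
    decoder_input: Optional[Hidden],
    embed: Callable[[Sequence[int], Sequence[int]], Hidden],
) -> Hidden:
    """Pick the hidden states a pipeline stage feeds into its transformer layers."""
    if is_first_stage:
        # First stage: build hidden states from token and position IDs.
        return embed(input_ids, position_ids)
    # Other stages ignore input_ids/position_ids; their input is the hidden
    # states received from the previous stage.
    if decoder_input is None:
        raise ValueError("non-first pipeline stages need decoder_input")
    return decoder_input


def toy_embed(ids: Sequence[int], pos: Sequence[int]) -> Hidden:
    # Hypothetical embedding: pair each token ID with its position.
    return [[float(t), float(p)] for t, p in zip(ids, pos)]


print(stage_input(True, [7, 9], [0, 1], None, toy_embed))          # [[7.0, 0.0], [9.0, 1.0]]
print(stage_input(False, [7, 9], [0, 1], [[0.1, 0.2]], toy_embed))  # [[0.1, 0.2]]
```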
- sharded_state_dict(
- prefix: str = '',
- sharded_offsets: tuple = (),
- metadata: Optional[Dict] = None,
)
Get sharded state dict.
- Parameters:
prefix (str) – Module name prefix.
sharded_offsets (tuple) – Offsets of local shard within global tensor.
metadata (Optional[Dict]) – Shard metadata.
- Returns:
The model's sharded state dict (ShardedStateDict).
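`sharded_offsets` is described only as the offsets of the local shard within the global tensor. As a worked illustration, the pure-Python sketch below assumes each offset is an `(axis, fragment_index, num_fragments)` triple and that shards along an axis are equal in size. This convention and the `locate_shard` helper are hypothetical; consult the Megatron-Core `dist_checkpointing` documentation for the exact format.

```python
from typing import List, Sequence, Tuple

# Assumed convention: one (axis, fragment_index, num_fragments) triple per sharded axis.
Offset = Tuple[int, int, int]


def locate_shard(local_shape: Sequence[int], sharded_offsets: Sequence[Offset]) -> Tuple[List[int], List[Tuple[int, int]]]:
    """Return the global shape and the [start, stop) range the local shard covers on each axis."""
    global_shape = list(local_shape)
    starts = [0] * len(local_shape)
    for axis, index, num_fragments in sharded_offsets:
        # The global tensor is num_fragments equal local shards laid end to end along `axis`.
        global_shape[axis] = local_shape[axis] * num_fragments
        starts[axis] = local_shape[axis] * index
    ranges = [(start, start + size) for start, size in zip(starts, local_shape)]
    return global_shape, ranges


# A [4, 4] local shard that is the second of two fragments along axis 0.
print(locate_shard((4, 4), [(0, 1, 2)]))  # ([8, 4], [(4, 8), (0, 4)])
```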