# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from contextlib import contextmanager

from typing import Optional
from functools import partial
from collections import OrderedDict

import torch
from torch.amp import autocast

import transformer_engine as te
from transformer_engine.pytorch.attention import InferenceParams, RotaryPositionEmbedding
from transformer_engine.common.recipe import Format, DelayedScaling
from transformer_engine.pytorch.fp8 import get_default_fp8_recipe
import transformers
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM, GemmaConfig, GemmaModel

import torch.nn.functional as F

"""
Top-level description of the classes defined in this file and used in the tutorial.
----------------------------------------------------------------------

HuggingFace Gemma Model implementation hierarchy:
----------------------------------
GemmaDecoderLayer:
├── self_attn:
│   ├── norm: (GemmaRMSNorm)
│   ├── qkv_proj: (nn.Linear)
│   ├── attention: (SDPA, FlashAttention, etc.)
│   └── o_proj: (nn.Linear)
├── ffn:
│   ├── norm: (GemmaRMSNorm)
│   ├── gate_proj: (nn.Linear)
│   ├── up_proj: (nn.Linear)
│   └── down_proj: (nn.Linear)

GemmaModel:
├── embed_tokens         : Token embedding layer
├── layers               : GemmaDecoderLayer × N
├── norm                 : GemmaRMSNorm
└── rotary_emb           : GemmaRotaryEmbedding

GemmaForCausalLM:
├── model                : instance of GemmaModel
├── lm_head              : (nn.Linear) hidden states to vocabulary logits for generation
└── generate             : generate method (input prompt -> GemmaForCausalLM -> next tokens)

How `generate()` works in HF's GemmaForCausalLM:
    1. prefill (input prompt -> model -> lm_head -> logits -> next token)
    2. loop until max_new_tokens:
        - next token -> model -> lm_head -> logits -> next token
    3. return all tokens

NOTE: Notice how the prefill step and the generation loop are both just parts of the `generate()` method.
      This is a common pattern in HF models.


TransformerEngine's Gemma Model Hierarchy:
----------------------------------------
HF's `GemmaDecoderLayer` is monkey-patched with `TEGemmaDecoderLayer` before `GemmaForCausalLM` is initialized. This way,
the model is still downloaded from HuggingFace and most of the code still runs through HF's `GemmaForCausalLM`, but the
underlying "transformer layer" blocks are actually from TransformerEngine.

TEGemmaDecoderLayer (inherits from te.TransformerLayer):
├── te.MultiHeadAttention:
│   ├── linear_qkv: (te.LayerNormLinear)
│   ├── attention: (te.DotProductAttention)
│   └── out_proj: (te.Linear)
├── te.LayerNormMLP:
│   ├── fc1: (te.LayerNormLinear)
│   ├── fc2: (te.Linear)
│   └── activation: (te.GeGLU)

To be able to use `model.generate()`, an entry point is needed. `TEGemmaForCausalLM` is the entry point which
subclasses HF's `GemmaForCausalLM` and adds a few attributes and methods.

TEGemmaForCausalLM (inherits from HF's GemmaForCausalLM)
├─ model                    : inherited from HF's GemmaForCausalLM but with monkey-patched TEGemmaDecoderLayer × N
├─ lm_head                  : directly inherited from HF's GemmaForCausalLM
├─ te_rope_emb              : precomputed `RotaryPositionEmbedding` frequencies (shared by all layers; also CUDA Graphs friendly)
├─ hidden_states_buffer     : shape [b, max_ctx, h]                             (static)
├─ generation_buffer        : shape [b, 1, h] (view of `hidden_states_buffer`)  (static)
├─ inference_params         : TransformerEngine KV cache
├─ model_context_phase      : GemmaModelWrapper  → uses (model, lm_head, inference_params) for full-sequence prefill
├─ model_generation_phase   : GemmaGenerationWrapper → uses (model, lm_head, inference_params) for single-token decode
└─ generate                 : generate method (input prompt -> TEGemmaForCausalLM -> next tokens)

Notice how "prefill" and "loop until next tokens" are specialized to wrapper subroutines - "model_context_phase" and
"model_generation_phase" respectively which makes it easier to use CUDA Graphs. Just one more abstraction is needed:

TEGemmaForCausalLMCudaGraphs (inherits from TEGemmaForCausalLM)
├─ model                    : unchanged (HF's GemmaModel with monkey-patched TEGemmaDecoderLayer × N)
├─ lm_head                  : unchanged
├─ hidden_states_buffer     : unchanged
├─ generation_buffer        : unchanged
├─ inference_params         : unchanged
├─ record                   : utility function to record the graphed callable
├─ model_context_phase      : replaced by `record` with a graphed callable for the context/prefill phase
├─ model_generation_phase   : replaced by `record` with a graphed callable for the generation phase
└─ generate                 : unchanged

How `generate()` works in TEGemmaForCausalLM/TEGemmaForCausalLMCudaGraphs:
    1. model_context_phase (input prompt -> model -> lm_head -> logits -> next token)
    2. model_generation_phase:
        - loop until max_new_tokens:
            - next token -> model -> lm_head -> logits -> next token
    3. return all tokens

NOTE: In the tutorial, `record` is called when initializing the model.

Additional notes and clarifications
-----------------------------------
- Wrappers, not submodules:
  `model_context_phase` and `model_generation_phase` are convenience wrappers over the same
  `model` (GemmaModel) and `lm_head`. They own no parameters; they standardize buffer usage,
  masks (context uses "padding_causal", generation uses "padding"), rotary embeddings, and
  KV-cache (`InferenceParams`) flow for TE-optimized inference.

- Buffer relationship:
  `hidden_states_buffer` has shape [b, max_ctx, h]. `generation_buffer` is a contiguous view
  of size [b, 1, h] carved from its start to avoid non-contiguous indexing. Generation updates
  `generation_buffer` in-place with next-token embeddings.

- Padding policy:
  Inputs may arrive left-padded (HF-style). Before TE execution, padding is shifted to the end
  to match TE attention mask expectations and to keep shapes contiguous for capture/replay.

- CUDA Graphs specifics:
  `record()` captures two separate callables (context/prefill and generation) with fixed shapes and
  stable pointers, then replaces the wrappers with these GraphedCallables. Under graphs, the
  functional behavior is identical; only allocation/pointer churn and CPU overhead are removed.
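
- Illustrative usage, a minimal sketch that is not part of the tutorial's runner
  code. It assumes a `GemmaConfig` extended with the extra fields referenced in
  this file (e.g. `fuse_qkv_params`, `fp8`, `is_paged`, `generation_cuda_graphs`,
  `max_seq_length`), a HF `tokenizer`, a list of `prompts`, and that checkpoint
  weights are loaded into the TE modules elsewhere in the tutorial:

      config = GemmaConfig.from_pretrained("google/gemma-7b")
      config.fuse_qkv_params = True
      config.fp8 = False
      config.is_paged = False
      config.generation_cuda_graphs = False
      config.max_seq_length = 1024

      model = TEGemmaForCausalLM(config)
      tokens = tokenizer(prompts, return_tensors="pt", padding=True)
      output = model.generate(
          input_ids=tokens["input_ids"].cuda(), max_new_tokens=32
      )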
"""


class TEGemmaDecoderLayer(te.pytorch.TransformerLayer):
    """
    Wrapper class over TE's `TransformerLayer`. It matches the constructor and
    forward signatures of HF's `GemmaDecoderLayer`, which makes it easy to swap
    in for it in the code.

    Args:
        config: GemmaConfig
        args: positional args (for compatibility with `GemmaDecoderLayer`)
        kwargs: keyword args (for compatibility with `GemmaDecoderLayer`)
    """

    def __init__(self, config: GemmaConfig, layer_idx: int, *args, **kwargs):

        self.gemma_config = config

        super().__init__(
            hidden_size=config.hidden_size,
            ffn_hidden_size=config.intermediate_size,
            num_attention_heads=config.num_attention_heads,
            bias=False,
            layernorm_epsilon=config.rms_norm_eps,
            hidden_dropout=0,
            attention_dropout=0,
            fuse_qkv_params=config.fuse_qkv_params,
            normalization="RMSNorm",
            activation="geglu",
            attn_input_format="bshd",
            num_gqa_groups=config.num_key_value_heads,
            kv_channels=self.gemma_config.head_dim,
            layer_number=(
                layer_idx + 1
            ),  # Layer numbers in TE start from 1, not from 0 as in HF.
            zero_centered_gamma=True,
        )

    def forward(self, *args, **kwargs):  # We need to additionally pass positional encoding.

        # filter out HF specific args
        keys_to_remove = [
            "position_ids",
            "past_key_value",
            "output_attentions",
            "use_cache",
            "cache_position",
        ]
        for key in keys_to_remove:
            kwargs.pop(key, None)

        rope_emb = kwargs.pop("rope_emb", None)

        # Return tuple to be compatible with HF.
        return (super().forward(*args, rotary_pos_emb=rope_emb, **kwargs),)


class GemmaModelWrapper(torch.nn.Module):
    """
    Wraps the HuggingFace `GemmaModel` (together with the LM head) in a module
    whose forward pass uses static, in-place buffer updates and is therefore
    compatible with CUDA Graphs.
    """

    def __init__(
        self,
        model: GemmaModel,
        dtype: torch.dtype,
        lm_head: torch.nn.Module,
    ):
        super().__init__()
        self.model = model
        self.normalizer = torch.tensor(self.model.config.hidden_size**0.5, dtype=dtype)
        self.lm_head = lm_head

    def set_inference_params(self, inference_params):
        self.inference_params = inference_params

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor = None,
        attn_mask_type: str = "arbitrary",
        rope_emb: torch.Tensor = None,
    ):
        with torch.no_grad():
            # static operation - for CUDA graphs
            hidden_states.data[:] = hidden_states.data[:] * self.normalizer

            for i, decoder_layer in enumerate(self.model.layers):
                hidden_states.data[:] = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    self_attn_mask_type=attn_mask_type,
                    inference_params=self.inference_params,
                    rope_emb=rope_emb,
                )[
                    0
                ]  # static copy - for CUDA graphs

        hidden_states.copy_(self.model.norm(hidden_states))  # static copy - for CUDA graphs
        logits = self.lm_head(hidden_states)

        # This is not needed for generation but is needed for training
        # or finetuning.
        if self.training:
            logits = logits.float()

        return logits


class GemmaGenerationWrapper(torch.nn.Module):
    """
    Takes the hidden states (embeddings) of a batch of single tokens, runs the
    forward pass, and returns the batch of next tokens. Also compatible with
    CUDA Graphs. Not a subclass of `GemmaModel` since the model layers are
    simply reused here.
    """

    def __init__(
        self,
        model: GemmaModel,
        lm_head: torch.nn.Module,
        dtype: torch.dtype,
    ):
        super().__init__()
        self.model = model
        self.gemma_layers = GemmaModelWrapper(model, dtype, lm_head)

    def set_inference_params(self, inference_params):
        self.inference_params = inference_params
        self.gemma_layers.set_inference_params(inference_params)

    def forward(
        self,
        hidden_states: torch.Tensor,
        mask: torch.Tensor = None,
        attn_mask_type: str = "arbitrary",
        rope_emb: torch.Tensor = None,
    ):
        logits = self.gemma_layers(
            hidden_states, attention_mask=mask, attn_mask_type=attn_mask_type, rope_emb=rope_emb
        )

        assert logits.shape[0] == hidden_states.shape[0]  # b
        assert logits.shape[1] == hidden_states.shape[1]  # seq_len

        # Fetch the logits for the last token
        logits = logits[:, -1, :]
        next_tokens = torch.argmax(logits, dim=1)

        # static copy for CUDA graphs
        hidden_states.copy_(self.model.embed_tokens(next_tokens).unsqueeze(1))

        return next_tokens


@contextmanager
def replace_decoder(te_decoder_cls):
    """
    Context manager that monkey-patches `GemmaDecoderLayer` with the custom
    `TEGemmaDecoderLayer` class and restores the original class on exit.
    """
    original_gemma_decoder_cls = transformers.models.gemma.modeling_gemma.GemmaDecoderLayer
    transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = te_decoder_cls
    try:
        yield
    finally:
        transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = original_gemma_decoder_cls


class TEGemmaForCausalLM(GemmaForCausalLM):
    """
    Causal LM built on HF's `GemmaForCausalLM`. The underlying
    `GemmaDecoderLayer` class is monkey-patched with `TEGemmaDecoderLayer`
    before the causal LM is initialized, so every transformer layer is a TE one.

    Args:
        config: Gemma model config that HF uses to initialize the model.
    """

    def __init__(self, config: GemmaConfig):

        dtype = torch.bfloat16
        with replace_decoder(te_decoder_cls=TEGemmaDecoderLayer):
            super().__init__(config)

        self.config = config
        self.to(dtype).cuda()
        self.hidden_size = config.hidden_size

        self._model_context_phase = GemmaModelWrapper(self.model, dtype, self.lm_head)

        self._model_generation_phase = GemmaGenerationWrapper(
            lm_head=self.lm_head,
            model=self.model,
            dtype=dtype,
        )

        if self.config.fp8:
            self.fp8_recipe = get_default_fp8_recipe()

        # The rotary position embedding is identical for all layers, so it is
        # created once here. Creating it only once also keeps it compatible
        # with CUDA Graphs.
        self.te_rope_emb = RotaryPositionEmbedding(self.config.head_dim)(
            max_seq_len=self.config.max_position_embeddings
        ).cuda()

    @staticmethod
    def _padding_to_end(inputs, lengths, max_seq_len=None):
        """
        Takes a tensor whose sequences are padded at the beginning and updates
        it in place so that they are padded at the end instead.

        Parameters
        ----------
        inputs : Tensor, tensor with shape [b, s] containing token ids,
                 padded at the beginning of each sequence.
        lengths: Tensor, tensor with shape [b] with the lengths of the
                 sequences.
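
        Example (illustrative; 0 is the padding token id):

            before: inputs = [[0, 0, 5, 6],     lengths = [2, 3]
                              [0, 7, 8, 9]]
            after:  inputs = [[5, 6, 0, 0],
                              [7, 8, 9, 0]]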

        """
        max_seq_len = torch.max(lengths) if max_seq_len is None else max_seq_len
        batch_size, max_seq_len = inputs.shape
        new_input_ids = inputs.clone()
        for i in range(batch_size):
            new_input_ids[i, : lengths[i]] = inputs[i, (max_seq_len - lengths[i]) : max_seq_len]
            new_input_ids[i, lengths[i] :] = inputs[i, 0 : (max_seq_len - lengths[i])]

        # Update `inputs` in place with the end-padded sequences, keeping the
        # sequence dimension unchanged.
        actual_max_seq_len = max_seq_len
        inputs.data = new_input_ids[:, :actual_max_seq_len]

    def _create_or_fetch_hidden_states_buffer(self, input_ids: torch.Tensor):
        """
        Returns a tensor of shape [b, s, hd] where `b` is the batch size,
        `s` is the sequence length, and `hd` is the hidden size.

        This function is overridden in TEGemmaForCausalLMCudaGraphs.
        """

        tensor = torch.empty(
            (input_ids.shape[0], input_ids.shape[1], self.hidden_size),
            device="cuda",
            dtype=torch.float32,
        )
        return tensor

    def _create_or_fetch_inference_params(self, *args, **kwargs):
        """
        Creates an InferenceParams object.

        This function is overridden in TEGemmaForCausalLMCudaGraphs.
        """

        infer_params = InferenceParams(*args, **kwargs)
        return infer_params

    def _get_generation_buffer(self, hidden_states_buffer, data_to_copy=None):
        """
        Returns a tensor of shape [b, 1, hd] where `b` is the batch size and
        `hd` is the hidden size.

        The generation buffer is carved out of the beginning of the hidden
        states buffer. This function returns a contiguous view of it and, if
        provided, copies `data_to_copy` into it.
        """
        # hidden_states_buffer has shape [b, s, hd]
        # generation_buffer will have shape [b, 1, hd]
        # Notice that `hidden_states_buffer[:, 0, :].unsqueeze(1)` would return
        # a non-contiguous view, which we want to avoid.
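        # Note (an observation about this implementation, assuming b <= s): the
        # flattened slice below covers the first b*hd elements of the [b, s, hd]
        # buffer, i.e. it aliases hidden_states_buffer[0, :b, :] rather than
        # sequence position 0 of every batch element. This is harmless because
        # the prefill hidden states are no longer needed once generation starts;
        # what matters is that the view is contiguous and its memory pointer
        # stays fixed.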
        output = hidden_states_buffer.view(-1)[
            : hidden_states_buffer.shape[0] * hidden_states_buffer.shape[2]
        ]
        if data_to_copy is not None:
            output.copy_(data_to_copy.reshape(-1))
        generation_buffer = output.view(
            (hidden_states_buffer.shape[0], 1, hidden_states_buffer.shape[2])
        )
        return generation_buffer

    def setup_and_run_context_phase(
        self, input_ids: torch.Tensor, inference_params: InferenceParams
    ):
        """
        Runs the context/prefill phase of the model.

        The buffer-creating helper `_create_or_fetch_hidden_states_buffer`
        called here is overridden in TEGemmaForCausalLMCudaGraphs so that a
        static buffer is reused.
        """

        hidden_states = self._create_or_fetch_hidden_states_buffer(input_ids)
        hidden_states.copy_(self.model.embed_tokens(input_ids))

        # Update offsets before every forward pass (including context/prefill
        # phase) to make cache work properly.
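        # NOTE: the sequence lengths are computed by counting non-zero tokens,
        # i.e. the padding token id is assumed to be 0 here (as in `forward`).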
        lengths = input_ids.ne(0).sum(dim=1)
        inference_params.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths.tolist())))
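        # The OrderedDict passed to `pre_step` maps each sequence index in the
        # batch to the number of tokens it contributes in this step: the full
        # prompt length for prefill, and 1 per sequence for every generation
        # step.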

        logits = self._model_context_phase(
            hidden_states,
            attention_mask=None,
            attn_mask_type="padding_causal",
            rope_emb=self.te_rope_emb,
        )

        logits = logits[torch.arange(logits.size(0)), lengths - 1, :]
        next_tokens = torch.argmax(logits, dim=1)

        # `hidden_states` has shape [b, s, hd]. Reuse the beginning of this
        # buffer for generation: copy the embedding of the next token into it
        # and return a [b, 1, hd] view of it.
        hidden_states = self._get_generation_buffer(
            hidden_states, self.model.embed_tokens(next_tokens)
        )
        return hidden_states, next_tokens

    @torch.no_grad()
    def generate(
        self,
        input_ids: Optional[torch.Tensor] = None,
        pad_token_id: int = 0,
        max_new_tokens: int = 0,
        *args,
        **kwargs,
    ):
        """
        Generates next tokens auto-regressively for a batch of input tokens.
        """
        self.eval()

        # Both autocasts are needed: FP8 for operations that can run in lower
        # precision and BF16 for those that cannot.
        with autocast("cuda", dtype=torch.bfloat16, cache_enabled=False), te.pytorch.fp8_autocast(
            enabled=self.config.fp8, fp8_recipe=self.fp8_recipe if self.config.fp8 else None
        ):
            lengths = torch.sum(input_ids.ne(pad_token_id), dim=-1).squeeze()
            # If padding is at the beginning, then shift it to the end
            TEGemmaForCausalLM._padding_to_end(
                input_ids,
                lengths,
                max_seq_len=(
                    self.config.cuda_graphs_static_max_context_len
                    if self.config.generation_cuda_graphs
                    else None
                ),
            )

            batch_size = input_ids.shape[0]
            # For the benchmark generation runs, this is set explicitly in the config.
            max_input_sequence_len = self.config.max_seq_length

            # InferenceParams is a cache, where keys and values of previous
            # tokens are stored. Moreover it stores the current running lengths
            # of the sequences in the current batch.
            # A helper function is used to create the inference params object
            # because this `generate` method is common for TEGemmaForCausalLM
            # and TEGemmaForCausalLMCudaGraphs. In case of CudaGraphs, this
            # function is overridden to simply return the inference params object
            # that is already created in TEGemmaForCausalLMCudaGraphs'
            # constructor.
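            # With `page_size=16`, `total_num_pages` below is sized so that the
            # paged KV cache (when `is_paged` is enabled) can hold every
            # sequence in the batch at the maximum sequence length.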
            inference_params = self._create_or_fetch_inference_params(
                max_batch_size=batch_size,
                max_sequence_length=max_input_sequence_len,
                num_heads_kv=self.config.num_key_value_heads,
                head_dim_v=self.config.head_dim,
                head_dim_k=self.config.head_dim,
                dtype=torch.bfloat16,
                is_paged=self.config.is_paged,
                page_size=16,
                total_num_pages=batch_size * max_input_sequence_len // 16,
            )

            # Set the inference params for both the context/prefill phase and
            # generation phase objects.
            self._model_context_phase.set_inference_params(inference_params)
            self._model_generation_phase.set_inference_params(inference_params)

            # Context/prefill phase.
            hidden_states, next_tokens = self.setup_and_run_context_phase(
                input_ids, inference_params
            )

            # Generation phase.
            lengths_tensor = torch.ones((next_tokens.shape[0],), dtype=int)
            inference_params.pre_step(
                OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist()))
            )
            output_tokens = [next_tokens]

            for _ in range(max_new_tokens):
                next_tokens = self._model_generation_phase(
                    hidden_states,
                    mask=None,
                    attn_mask_type="padding",
                    rope_emb=self.te_rope_emb,
                )

                # Increase sequence offsets by one because we generated one token
                # for every sequence.
                lengths_tensor = torch.ones((next_tokens.shape[0],), dtype=int)
                inference_params.pre_step(
                    OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist()))
                )

                # `next_tokens` is a static output tensor, so we need to clone
                # it because it gets changed every iteration.
                output_tokens.append(next_tokens.clone())

            result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1)
            return result

    def forward(self, *args, **kwargs):
        """
        Forward pass for the model. This is used during the calibration step,
        when a forward pass is needed to collect FP8 calibration statistics.
        """

        self._model_context_phase.set_inference_params(None)
        hidden_states = self.model.embed_tokens(kwargs["input_ids"])
        logits = self._model_context_phase(
            hidden_states,
            attention_mask=(
                kwargs["input_ids"] == 0
            ),  # The padding token id is hardcoded to 0; this mask only applies to bshd/sbhd layouts.
            attn_mask_type="padding_causal",
        )
        return logits


class TEGemmaForCausalLMCudaGraphs(TEGemmaForCausalLM):
    """
    TEGemmaForCausalLMCudaGraphs extends TEGemmaForCausalLM and uses CUDA
    Graphs to speed up the generation process. One trade-off has to be made:
    batch_size, max_seq_len and max_context_seq_len need to be static, because
    generation must run without changing the pointers to the tensors that were
    recorded in the graph.
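
    Typical usage, an illustrative sketch (the extra `cuda_graphs_*` config
    fields are the same assumptions as in TEGemmaForCausalLM; in the tutorial
    `record` is invoked right after the model is set up):

        model = TEGemmaForCausalLMCudaGraphs(config)
        model.record()  # capture the prefill and generation graphs once
        output = model.generate(input_ids=tokens, max_new_tokens=32)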
    """

    def __init__(self, config: GemmaConfig):
        super().__init__(config)

        self.config = config

        # Preparation of the static buffer to hold the hidden states that are
        # passed from one layer to the next.
        self.hidden_states_buffer = torch.empty(
            (
                self.config.cuda_graphs_static_batch_size,
                self.config.cuda_graphs_static_max_context_len,
                self.config.hidden_size,
            )
        ).cuda()

        # This is in fact part of the buffer for hidden_states. Refer to the
        # `_get_generation_buffer` function for more details.
        self.generation_buffer = self._get_generation_buffer(
            self.hidden_states_buffer,
        )

        # InferenceParams contains the keys and values cache. Refer to the
        # original call in TEGemmaForCausalLM's `generate` method for more
        # details.
        self.inference_params = InferenceParams(
            max_batch_size=self.config.cuda_graphs_static_batch_size,
            max_sequence_length=self.config.cuda_graphs_static_max_context_len,
            num_heads_kv=self.config.num_key_value_heads,
            head_dim_v=self.config.head_dim,
            head_dim_k=self.config.head_dim,
            dtype=torch.bfloat16,
            is_paged=self.config.is_paged,
            page_size=16,
            total_num_pages=self.config.cuda_graphs_static_batch_size
            * self.config.cuda_graphs_static_max_context_len
            // 16,
        )

        self._model_generation_phase.set_inference_params(self.inference_params)
        self._model_context_phase.set_inference_params(self.inference_params)

    def record(self):
        """
        Here "the trick" happens. `_model_context_phase` and
        `_model_generation_phase` from TEGemmaForCausalLM are replaced with
        their recorded version. Once the graphs are recorded, they can be
        replayed with minimal usage of CPU and that leads to speedup.
        """
        # Record the model with training=False, because it will be used in
        # generation.
        self.eval()

        # Setup the recording for context/prefill phase.
        input_shape = (
            self.config.cuda_graphs_static_batch_size,
            self.config.cuda_graphs_static_max_context_len,
        )

        # Arbitrary hardcoded sequence lengths, used only to put the KV cache
        # into a valid state for graph capture; the real lengths are supplied
        # through `pre_step` calls inside `generate`.
        lengths = torch.tensor([9] * self.config.cuda_graphs_static_batch_size).to(
            device="cuda", dtype=torch.int32
        )
        self.inference_params.pre_step(
            OrderedDict(zip(list(range(len(lengths))), lengths.tolist()))
        )

        # Record the graph for context/prefill phase.
        self._model_context_phase = self.record_graph(
            self._model_context_phase,
            self.hidden_states_buffer,
            attn_mask_type="padding_causal",
            rope_emb=self.te_rope_emb,
        )

        # Setup the recording for generation phase.
        input_shape = (self.config.cuda_graphs_static_batch_size, 1)
        lengths = torch.tensor(input_shape[0] * [1], device="cuda", dtype=torch.int32)
        self.inference_params.pre_step(
            OrderedDict(zip(list(range(len(lengths))), lengths.tolist()))
        )

        # Record the graph for generation phase.
        self._model_generation_phase = self.record_graph(
            self._model_generation_phase,
            self.generation_buffer,
            attn_mask_type="padding",
            rope_emb=self.te_rope_emb,
        )

    def _create_or_fetch_hidden_states_buffer(self, *args, **kwargs):
        """
        Overridden to make `hidden_states` static, i.e. to keep its memory
        pointer unchanged between invocations.

        Returns the static buffer for the hidden states which is already
        created in the constructor. This is the same buffer as the one used in
        the context/prefill phase.
        """
        return self.hidden_states_buffer

    def _create_or_fetch_inference_params(self, *args, **kwargs):
        """
        Overridden to make `inference_params` static, i.e. to keep its memory
        pointer unchanged between invocations.

        Resets and returns the static `inference_params` object which is
        already created in the constructor.
        """
        self.inference_params.reset()
        return self.inference_params

    @torch.no_grad()
    def record_graph(self, function, input_tensor, **sample_kwargs):
        """
        Records the CUDA graph for the given function. The function is invoked
        on the argument `(input_tensor,)` and all launched kernels are recorded.
        It returns the graphed callable, which can later be replayed with
        minimal CPU overhead.
        """
        fp8_recipe = get_default_fp8_recipe()

        # BF16 autocast is applied here; FP8 execution is enabled inside
        # `make_graphed_callables` via `fp8_enabled` and `fp8_recipe`.
        with autocast("cuda", dtype=torch.bfloat16, cache_enabled=False):
            graphed_function = te.pytorch.make_graphed_callables(
                function,
                (input_tensor,),
                fp8_enabled=self.config.fp8,
                fp8_recipe=fp8_recipe,
                allow_unused_input=True,
                num_warmup_iters=5,
                sample_kwargs=sample_kwargs,
            )
        return graphed_function
