bridge.models.hf_pretrained.causal_lm#

Module Contents#

Classes#

PreTrainedCausalLM

A generic class for Pretrained Causal Language Models with lazy loading.

GenerateKwargs

TypedDict for generate method parameters.

EncodeKwargs

TypedDict for encode method parameters.

DecodeKwargs

TypedDict for decode method parameters.

Data#

API#

bridge.models.hf_pretrained.causal_lm.CausalLMType#

'TypeVar(…)'

class bridge.models.hf_pretrained.causal_lm.PreTrainedCausalLM(
model_name_or_path: Optional[Union[str, pathlib.Path]] = None,
device: Optional[Union[str, torch.device]] = None,
torch_dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
**kwargs,
)#

Bases: megatron.bridge.models.hf_pretrained.base.PreTrainedBase, typing.Generic[bridge.models.hf_pretrained.causal_lm.CausalLMType]

A generic class for Pretrained Causal Language Models with lazy loading.

Allows type-safe access to specific model implementations like LlamaForCausalLM.

.. rubric:: Examples

Basic usage with lazy loading:

from mbridge.pretrained import PreTrainedCausalLM

Create instance - no model loading happens yet

model = PreTrainedCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

Components are loaded on first access

config = model.config        # Loads config
tokenizer = model.tokenizer  # Loads tokenizer

Generate text - model is loaded here

inputs = model.encode("Hello, how are you?")
outputs = model.generate(**inputs, max_length=50)
print(model.decode(outputs[0], skip_special_tokens=True))

Using specific model types with type hints:

from transformers import LlamaForCausalLM
from mbridge.pretrained import PreTrainedCausalLM

Type-safe access to Llama-specific features

import torch

llama_model: PreTrainedCausalLM[LlamaForCausalLM] = PreTrainedCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    torch_dtype=torch.float16,
    device="cuda",
)

Access Llama-specific attributes

model_instance = llama_model.model # Type is LlamaForCausalLM

Loading with custom configurations:

Load model with specific settings

model = PreTrainedCausalLM.from_pretrained(
    "gpt2",
    device="cuda:0",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    load_in_8bit=True,
)

Override generation config

from transformers import GenerationConfig

model.generation_config = GenerationConfig(
    max_length=100,
    temperature=0.7,
    top_p=0.9,
    do_sample=True,
)

Manual component management:

Create empty instance

model = PreTrainedCausalLM()

Manually set components

from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM

model.config = AutoConfig.from_pretrained("microsoft/phi-2")
model.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
model.model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")

Save all components

model.save_artifacts("./my_model")

Batch processing example:

Process multiple prompts

prompts = [
    "The capital of France is",
    "Machine learning is",
    "Python programming language was created by",
]

Encode all prompts

inputs = model.encode(prompts, padding=True, truncation=True)

Generate completions

outputs = model.generate(**inputs, max_new_tokens=20)

Decode results

for i, output in enumerate(outputs):
    print(f"Prompt {i+1}: {model.decode(output, skip_special_tokens=True)}")

Initialization

Initialize a Pretrained Causal LM with lazy loading.

Parameters:
  • model_name_or_path – HuggingFace model identifier or local path

  • device – Device to load model on (e.g., ‘cuda’, ‘cpu’)

  • torch_dtype – Data type to load model in (e.g., torch.float16)

  • trust_remote_code – Whether to trust remote code when loading

  • **kwargs – Additional arguments passed to from_pretrained methods
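
A minimal construction sketch (using the gpt2 checkpoint as in the examples above); per the lazy-loading design, nothing is loaded until a component is first accessed:

import torch

from mbridge.pretrained import PreTrainedCausalLM

# Build the wrapper only; config, tokenizer, and weights stay unloaded for now.
lm = PreTrainedCausalLM(
    model_name_or_path="gpt2",
    device="cuda",
    torch_dtype=torch.bfloat16,
)

print(lm.has_model)  # False - weights load on first access to lm.model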

ARTIFACTS#

['tokenizer']

OPTIONAL_ARTIFACTS#

['generation_config']

_load_model() bridge.models.hf_pretrained.causal_lm.CausalLMType#

Load the model.

_load_config() transformers.AutoConfig#

Load the model config.

_load_tokenizer() transformers.PreTrainedTokenizer#

Load the tokenizer.

_load_generation_config() Optional[transformers.GenerationConfig]#

Load the generation config.

property generation_config: Optional[transformers.GenerationConfig]#

Lazy load and return the generation config.

property tokenizer: transformers.PreTrainedTokenizer#

Lazy load and return the tokenizer.

property model_name_or_path: Optional[Union[str, pathlib.Path]]#

Return the model name or path.

property has_model: bool#

Check if model has been loaded.

property model: bridge.models.hf_pretrained.causal_lm.CausalLMType#

Lazy load and return the underlying model.

classmethod from_pretrained(
model_name_or_path: Union[str, pathlib.Path],
device: Optional[Union[str, torch.device]] = None,
torch_dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
**kwargs,
) PreTrainedCausalLM[CausalLMType]#

Create a PreTrainedCausalLM instance for lazy loading.

Parameters:
  • model_name_or_path – HuggingFace model identifier or local path

  • device – Device to load model on

  • torch_dtype – Data type to load model in

  • trust_remote_code – Whether to trust remote code

  • **kwargs – Additional arguments for from_pretrained methods

Returns:

PreTrainedCausalLM instance configured for lazy loading
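
A short sketch of the lazy behavior (gpt2 used for illustration); components load only when first accessed:

import torch

from mbridge.pretrained import PreTrainedCausalLM

model = PreTrainedCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16)
print(model.has_model)   # False - nothing loaded yet

tokenizer = model.tokenizer  # loads only the tokenizer
lm = model.model             # loads the weights on first access
print(model.has_model)   # True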

generate(
input_ids: Optional[torch.LongTensor] = None,
**kwargs: Unpack[bridge.models.hf_pretrained.causal_lm.GenerateKwargs],
) Union[torch.LongTensor, transformers.generation.utils.GenerateOutput]#

Generate text using the underlying language model.

This method forwards all arguments to the model’s generate method, supporting all generation strategies provided by the transformers library.

Common parameters include:

  • inputs (torch.LongTensor, optional) – Input token IDs. If not provided, generation starts from the beginning-of-sequence token.

  • max_length (int, optional) – Maximum length of the generated sequence. Defaults to the model's max_length configuration.

  • min_length (int, optional) – Minimum length of the generated sequence.

  • max_new_tokens (int, optional) – Maximum number of tokens to generate, ignoring the number of tokens in the prompt.

  • do_sample (bool, optional) – Whether to use sampling. Defaults to False (greedy decoding).

  • temperature (float, optional) – Temperature for sampling. Higher values produce more random outputs. Typical range: 0.1-2.0.

  • top_p (float, optional) – Nucleus sampling threshold. Only tokens with cumulative probability up to top_p are considered. Range: 0.0-1.0.

  • top_k (int, optional) – Only consider the top k tokens for sampling.

  • num_beams (int, optional) – Number of beams for beam search. 1 means no beam search.

  • repetition_penalty (float, optional) – Penalty for repeating tokens. Values > 1.0 discourage repetition.

  • pad_token_id (int, optional) – ID of the padding token.

  • eos_token_id (int or List[int], optional) – ID(s) of the end-of-sequence token(s).

  • use_cache (bool, optional) – Whether to use past key values to speed up generation. Defaults to True.

Returns:

Generated token IDs. If return_dict_in_generate=True, returns a GenerateOutput object containing generated sequences and additional information like scores.

Return type:

torch.LongTensor or transformers.generation.utils.GenerateOutput

.. rubric:: Examples

Basic generation

model = PreTrainedCausalLM.from_pretrained("gpt2")
inputs = model.encode("Hello, how are")
outputs = model.generate(inputs["input_ids"], max_length=20)
print(model.decode(outputs[0]))

Generation with sampling

outputs = model.generate(
    inputs["input_ids"],
    max_length=50,
    do_sample=True,
    temperature=0.8,
    top_p=0.9,
)

Generation with beam search

outputs = model.generate(
    inputs["input_ids"],
    max_length=50,
    num_beams=5,
    early_stopping=True,
)

.. note::

For detailed documentation of all parameters, see the transformers library documentation for generation methods.

__call__(*args, **kwargs)#

Forward call to model.
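
A brief forward-pass sketch; it assumes the underlying HF causal LM returns an output object with a logits attribute, as transformers causal LMs normally do:

import torch

from mbridge.pretrained import PreTrainedCausalLM

model = PreTrainedCausalLM.from_pretrained("gpt2")
inputs = model.encode("Hello, world!")

with torch.no_grad():
    outputs = model(**inputs)  # forwarded to the underlying model's __call__

print(outputs.logits.shape)  # (batch_size, sequence_length, vocab_size)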

encode(
text: Union[str, List[str]],
**kwargs: Unpack[bridge.models.hf_pretrained.causal_lm.EncodeKwargs],
) Dict[str, torch.Tensor]#

Encode text into token IDs using the model’s tokenizer.

This method tokenizes input text and returns tensors ready for model input. The output is automatically moved to the same device as the model.

Parameters:
  • text (str or List[str]) – Input text to encode. Can be a single string or a list of strings for batch encoding.

  • **kwargs – Additional arguments passed to the tokenizer. Common options:

    • padding (bool or str, optional) – Padding strategy. True or 'longest' pads to the longest sequence in the batch; 'max_length' pads to max_length; False or 'do_not_pad' applies no padding (default).

    • truncation (bool or str, optional) – Truncation strategy. True or 'longest_first' truncates to max_length; 'only_first' truncates only the first sequence (for pairs); False applies no truncation.

    • max_length (int, optional) – Maximum length of returned sequences. Defaults to the model's max_length.

    • add_special_tokens (bool, optional) – Whether to add special tokens (e.g., [CLS], [SEP]). Defaults to True.

    • return_attention_mask (bool, optional) – Whether to return the attention mask. Defaults to True.

    • return_token_type_ids (bool, optional) – Whether to return token type IDs (for models like BERT). Defaults to True if the model expects them.

Returns:

Dictionary containing:

  • input_ids – Token IDs tensor of shape (batch_size, sequence_length)

  • attention_mask – Attention mask tensor of the same shape (if applicable)

  • token_type_ids – Token type IDs tensor (if applicable)

Additional keys may be present depending on the tokenizer.

Return type:

Dict[str, torch.Tensor]

.. rubric:: Examples

model = PreTrainedCausalLM.from_pretrained("gpt2")

Single text encoding

tokens = model.encode("Hello world!")
print(tokens["input_ids"].shape)  # torch.Size([1, 3])

Batch encoding with padding

texts = ["Hello!", "How are you doing today?"]
tokens = model.encode(texts, padding=True)
print(tokens["input_ids"].shape)  # torch.Size([2, 6])

Encoding with truncation

tokens = model.encode(
    "This is a very long text that might exceed the maximum length",
    truncation=True,
    max_length=10,
)

.. note::

The returned tensors are on the same device as the model, ready for immediate use in forward passes or generation.

decode(
token_ids: Union[int, List[int], torch.Tensor],
**kwargs: Unpack[bridge.models.hf_pretrained.causal_lm.DecodeKwargs],
) str#

Decode token IDs back into text using the model’s tokenizer.

This method converts token IDs (from model output or encode method) back into human-readable text.

Parameters:
  • token_ids (int, List[int], or torch.Tensor) –

    Token IDs to decode. Can be:

    • Single token ID (int)

    • List of token IDs

    • 1D tensor of token IDs

    • 2D tensor (will decode the first sequence)

  • **kwargs – Additional arguments passed to the tokenizer's decode method:

    • skip_special_tokens (bool, optional) – Whether to remove special tokens (e.g., [PAD], [CLS], [SEP]) from the output. Defaults to True.

    • clean_up_tokenization_spaces (bool, optional) – Whether to clean up tokenization artifacts (extra spaces, etc.). Defaults to True.

Returns:

Decoded text string.

Return type:

str

.. rubric:: Examples

model = PreTrainedCausalLM.from_pretrained("gpt2")

Encode and decode round-trip

text = "Hello, world!"
tokens = model.encode(text)
decoded = model.decode(tokens["input_ids"][0])
print(decoded)  # "Hello, world!"

Decode generated tokens

inputs = model.encode("The weather is")
outputs = model.generate(inputs["input_ids"], max_length=10)
decoded = model.decode(outputs[0])
print(decoded)  # "The weather is nice today..."

Decode without special tokens

token_ids = [101, 7592, 1010, 2088, 999, 102]  # BERT-style tokens
decoded = model.decode(token_ids, skip_special_tokens=True)
print(decoded)  # "Hello, world!"

Decode keeping special tokens

decoded = model.decode(token_ids, skip_special_tokens=False)
print(decoded)  # "[CLS] Hello, world! [SEP]"

.. note::

If a 2D tensor is provided (batch of sequences), only the first sequence is decoded. For batch decoding, use tokenizer.batch_decode() directly or iterate over the sequences.
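
As the note suggests, whole-batch decoding can go through the tokenizer directly; a small sketch, assuming the tokenizer has a padding token configured:

model = PreTrainedCausalLM.from_pretrained("gpt2")
inputs = model.encode(["Hello!", "How are you?"], padding=True)  # padding requires a pad token
outputs = model.generate(**inputs, max_new_tokens=10)

# decode() only handles one sequence; use the tokenizer for the whole batch
texts = model.tokenizer.batch_decode(outputs, skip_special_tokens=True)
for text in texts:
    print(text)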

to(device: Union[str, torch.device])#

Move model to specified device.

half()#

Convert model to half precision (float16).

float()#

Convert model to full precision (float32).
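
A brief sketch combining these helpers (a CUDA device is assumed where noted):

model = PreTrainedCausalLM.from_pretrained("gpt2")
_ = model.model     # trigger lazy loading before converting in place

model.to("cuda")    # move the underlying model to GPU (assumes CUDA is available)
model.half()        # cast parameters to float16
print(model.dtype)  # torch.float16

model.float()       # cast back to float32 for full-precision work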

save_pretrained(save_directory: Union[str, pathlib.Path])#

Save all components (model, tokenizer, config, generation_config) to a directory.

This method saves:

  • Model weights and config

  • Tokenizer files

  • Generation config (if available)

Parameters:

save_directory – Path to directory where components will be saved
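
A minimal save-and-reload sketch; the directory name is arbitrary:

model = PreTrainedCausalLM.from_pretrained("gpt2")
model.save_pretrained("./exported-gpt2")

# The saved directory can be loaded again like any local checkpoint
reloaded = PreTrainedCausalLM.from_pretrained("./exported-gpt2")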

property dtype: Optional[torch.dtype]#

Get model’s dtype if loaded.

property num_parameters: Optional[int]#

Get total number of parameters if model is loaded.
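
Both properties are Optional and, per their docstrings, report values only once the model has been loaded; a sketch under that assumption:

model = PreTrainedCausalLM.from_pretrained("gpt2")
print(model.dtype, model.num_parameters)  # None None - not loaded yet

_ = model.model              # trigger loading
print(model.dtype)           # e.g. torch.float32
print(model.num_parameters)  # total parameter count (roughly 124M for GPT-2)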

__repr__() str#

Return a string representation of the PreTrainedCausalLM instance.

class bridge.models.hf_pretrained.causal_lm.GenerateKwargs#

Bases: typing.TypedDict

TypedDict for generate method parameters.

Initialization

Initialize self. See help(type(self)) for accurate signature.

attention_mask: Optional[torch.Tensor]#

None

max_length: Optional[int]#

None

max_new_tokens: Optional[int]#

None

min_length: Optional[int]#

None

do_sample: Optional[bool]#

None

temperature: Optional[float]#

None

top_k: Optional[int]#

None

top_p: Optional[float]#

None

repetition_penalty: Optional[float]#

None

pad_token_id: Optional[int]#

None

eos_token_id: Optional[Union[int, List[int]]]#

None

bos_token_id: Optional[int]#

None

num_beams: Optional[int]#

None

num_return_sequences: Optional[int]#

None

early_stopping: Optional[bool]#

None

use_cache: Optional[bool]#

None

return_dict_in_generate: Optional[bool]#

None

output_scores: Optional[bool]#

None

output_attentions: Optional[bool]#

None
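
These keys mirror the keyword arguments that generate() forwards to the underlying model, so a plain dict of the same shape can be reused across calls; a sketch (the option values below are illustrative):

from mbridge.pretrained import PreTrainedCausalLM

model = PreTrainedCausalLM.from_pretrained("gpt2")

# Reusable generation options; keys follow the GenerateKwargs fields listed above
sampling_opts = {
    "max_new_tokens": 50,
    "do_sample": True,
    "temperature": 0.8,
    "top_p": 0.9,
    "repetition_penalty": 1.1,
}

inputs = model.encode("Once upon a time")
outputs = model.generate(inputs["input_ids"], **sampling_opts)
print(model.decode(outputs[0], skip_special_tokens=True))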

class bridge.models.hf_pretrained.causal_lm.EncodeKwargs#

Bases: typing.TypedDict

TypedDict for encode method parameters.

Initialization

Initialize self. See help(type(self)) for accurate signature.

padding: Union[bool, str]#

None

truncation: Union[bool, str]#

None

max_length: Optional[int]#

None

add_special_tokens: bool#

None

return_attention_mask: bool#

None

return_token_type_ids: Optional[bool]#

None

return_tensors: str#

None

class bridge.models.hf_pretrained.causal_lm.DecodeKwargs#

Bases: typing.TypedDict

TypedDict for decode method parameters.

Initialization

Initialize self. See help(type(self)) for accurate signature.

skip_special_tokens: bool#

None

clean_up_tokenization_spaces: bool#

None
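
Similarly, EncodeKwargs and DecodeKwargs describe the options accepted by encode() and decode(); a round-trip sketch using a few of them:

from mbridge.pretrained import PreTrainedCausalLM

model = PreTrainedCausalLM.from_pretrained("gpt2")

# Option dicts follow the EncodeKwargs and DecodeKwargs fields listed above
encode_opts = {"truncation": True, "max_length": 32, "add_special_tokens": True}
decode_opts = {"skip_special_tokens": True, "clean_up_tokenization_spaces": True}

tokens = model.encode("Hello, world!", **encode_opts)
print(model.decode(tokens["input_ids"][0], **decode_opts))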