bridge.models.hf_pretrained.causal_lm
#
Module Contents#
Classes#
- PreTrainedCausalLM – A generic class for Pretrained Causal Language Models with lazy loading.
- GenerateKwargs – TypedDict for generate method parameters.
- EncodeKwargs – TypedDict for encode method parameters.
- DecodeKwargs – TypedDict for decode method parameters.
Data#
API#
- bridge.models.hf_pretrained.causal_lm.CausalLMType#
'TypeVar(...)'
- class bridge.models.hf_pretrained.causal_lm.PreTrainedCausalLM(
- model_name_or_path: Optional[Union[str, pathlib.Path]] = None,
- device: Optional[Union[str, torch.device]] = None,
- torch_dtype: Optional[torch.dtype] = None,
- trust_remote_code: bool = False,
- **kwargs,
Bases: megatron.bridge.models.hf_pretrained.base.PreTrainedBase, typing.Generic[bridge.models.hf_pretrained.causal_lm.CausalLMType]

A generic class for Pretrained Causal Language Models with lazy loading.
Allows type-safe access to specific model implementations like LlamaForCausalLM.
Examples:

Basic usage with lazy loading:

```python
from mbridge.pretrained import PreTrainedCausalLM

# Create instance - no model loading happens yet
model = PreTrainedCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

# Components are loaded on first access
config = model.config        # Loads config
tokenizer = model.tokenizer  # Loads tokenizer

# Generate text - model is loaded here
inputs = model.encode("Hello, how are you?")
outputs = model.generate(**inputs, max_length=50)
print(model.decode(outputs[0], skip_special_tokens=True))
```
Using specific model types with type hints:

```python
import torch

from transformers import LlamaForCausalLM
from mbridge.pretrained import PreTrainedCausalLM

# Type-safe access to Llama-specific features
llama_model: PreTrainedCausalLM[LlamaForCausalLM] = PreTrainedCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    torch_dtype=torch.float16,
    device="cuda",
)

# Access Llama-specific attributes
model_instance = llama_model.model  # Type is LlamaForCausalLM
```
Loading with custom configurations:

```python
# Load model with specific settings
model = PreTrainedCausalLM.from_pretrained(
    "gpt2",
    device="cuda:0",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    load_in_8bit=True,
)

# Override generation config
from transformers import GenerationConfig

model.generation_config = GenerationConfig(
    max_length=100,
    temperature=0.7,
    top_p=0.9,
    do_sample=True,
)
```
Manual component management:

```python
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM

# Create empty instance
model = PreTrainedCausalLM()

# Manually set components
model.config = AutoConfig.from_pretrained("microsoft/phi-2")
model.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
model.model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")

# Save all components
model.save_artifacts("./my_model")
```
Batch processing example:

```python
# Process multiple prompts
prompts = [
    "The capital of France is",
    "Machine learning is",
    "Python programming language was created by",
]

# Encode all prompts
inputs = model.encode(prompts, padding=True, truncation=True)

# Generate completions
outputs = model.generate(**inputs, max_new_tokens=20)

# Decode results
for i, output in enumerate(outputs):
    print(f"Prompt {i+1}: {model.decode(output, skip_special_tokens=True)}")
```
Initialization
Initialize a Pretrained Causal LM with lazy loading.
- Parameters:
model_name_or_path – HuggingFace model identifier or local path
device – Device to load model on (e.g., ‘cuda’, ‘cpu’)
torch_dtype – Data type to load model in (e.g., torch.float16)
trust_remote_code – Whether to trust remote code when loading
**kwargs – Additional arguments passed to from_pretrained methods
- ARTIFACTS#
['tokenizer']
- OPTIONAL_ARTIFACTS#
['generation_config']
- _load_model() → bridge.models.hf_pretrained.causal_lm.CausalLMType#
Load the model.
- _load_config() → transformers.AutoConfig#
Load the model config.
- _load_tokenizer() → transformers.PreTrainedTokenizer#
Load the tokenizer.
- _load_generation_config() → Optional[transformers.GenerationConfig]#
Load the generation config.
- property generation_config: Optional[transformers.GenerationConfig]#
Lazy load and return the generation config.
- property tokenizer: transformers.PreTrainedTokenizer#
Lazy load and return the tokenizer.
- property model_name_or_path: Optional[Union[str, pathlib.Path]]#
Return the model name or path.
- property has_model: bool#
Check if model has been loaded.
- property model: bridge.models.hf_pretrained.causal_lm.CausalLMType#
Lazy load and return the underlying model.
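The lazy-loading behaviour of this property can be observed together with has_model. A minimal sketch, assuming the documented module path is importable and using "gpt2" purely as an example checkpoint:

```python
from bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM

lm = PreTrainedCausalLM.from_pretrained("gpt2")
print(lm.has_model)  # False - no weights loaded yet

_ = lm.model         # first access triggers the actual load
print(lm.has_model)  # True
```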
- classmethod from_pretrained(
- model_name_or_path: Union[str, pathlib.Path],
- device: Optional[Union[str, torch.device]] = None,
- torch_dtype: Optional[torch.dtype] = None,
- trust_remote_code: bool = False,
- **kwargs,
Create a PreTrainedCausalLM instance for lazy loading.
- Parameters:
model_name_or_path – HuggingFace model identifier or local path
device – Device to load model on
torch_dtype – Data type to load model in
trust_remote_code – Whether to trust remote code
**kwargs – Additional arguments for from_pretrained methods
- Returns:
PreTrainedCausalLM instance configured for lazy loading
- generate(
- input_ids: Optional[torch.LongTensor] = None,
- **kwargs: Unpack[bridge.models.hf_pretrained.causal_lm.GenerateKwargs],
Generate text using the underlying language model.
This method forwards all arguments to the model’s generate method, supporting all generation strategies provided by the transformers library.
Common parameters include:
- inputs (torch.LongTensor, optional): Input token IDs. If not provided, generation starts from the beginning-of-sequence token.
- max_length (int, optional): Maximum length of the generated sequence. Defaults to the model's max_length configuration.
- min_length (int, optional): Minimum length of the generated sequence.
- max_new_tokens (int, optional): Maximum number of tokens to generate, ignoring the number of tokens in the prompt.
- do_sample (bool, optional): Whether to use sampling. Defaults to False (greedy decoding).
- temperature (float, optional): Temperature for sampling. Higher values produce more random outputs. Typical range: 0.1-2.0.
- top_p (float, optional): Nucleus sampling threshold. Only tokens with cumulative probability up to top_p are considered. Range: 0.0-1.0.
- top_k (int, optional): Only consider the top k tokens for sampling.
- num_beams (int, optional): Number of beams for beam search. 1 means no beam search.
- repetition_penalty (float, optional): Penalty for repeating tokens. Values > 1.0 discourage repetition.
- pad_token_id (int, optional): ID of the padding token.
- eos_token_id (int or List[int], optional): ID(s) of end-of-sequence token(s).
- use_cache (bool, optional): Whether to use past key values to speed up generation. Defaults to True.
- Returns:
Generated token IDs. If return_dict_in_generate=True, returns a GenerateOutput object containing generated sequences and additional information like scores.
- Return type:
torch.LongTensor or transformers.generation.utils.GenerateOutput
Examples:

```python
# Basic generation
model = PreTrainedCausalLM.from_pretrained("gpt2")
inputs = model.encode("Hello, how are")
outputs = model.generate(inputs["input_ids"], max_length=20)
print(model.decode(outputs[0]))

# Generation with sampling
outputs = model.generate(
    inputs["input_ids"],
    max_length=50,
    do_sample=True,
    temperature=0.8,
    top_p=0.9,
)

# Beam search
outputs = model.generate(
    inputs["input_ids"],
    max_length=50,
    num_beams=5,
    early_stopping=True,
)
```
Note: For detailed documentation of all parameters, see the transformers library documentation for generation methods.
- __call__(*args, **kwargs)#
Forward call to model.
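As a rough sketch (assuming the documented module path is importable), the call simply delegates to the underlying model's forward pass, so the usual Hugging Face output object is returned:

```python
import torch

from bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM

lm = PreTrainedCausalLM.from_pretrained("gpt2")
batch = lm.encode("Hello world")

with torch.no_grad():
    # Equivalent to lm.model(**batch): the call is forwarded to the model
    outputs = lm(**batch)

print(outputs.logits.shape)  # (batch_size, sequence_length, vocab_size)
```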
- encode(
- text: Union[str, List[str]],
- **kwargs: Unpack[bridge.models.hf_pretrained.causal_lm.EncodeKwargs],
Encode text into token IDs using the model’s tokenizer.
This method tokenizes input text and returns tensors ready for model input. The output is automatically moved to the same device as the model.
- Parameters:
text (str or List[str]) – Input text to encode. Can be a single string or a list of strings for batch encoding.
**kwargs – Additional arguments passed to the tokenizer. Common options:
- padding (bool or str, optional): Padding strategy. True or 'longest' pads to the longest sequence in the batch; 'max_length' pads to max_length; False or 'do_not_pad' applies no padding (default).
- truncation (bool or str, optional): Truncation strategy. True or 'longest_first' truncates to max_length; 'only_first' truncates only the first sequence (for pairs); False applies no truncation.
- max_length (int, optional): Maximum length of returned sequences. Defaults to the model's max_length.
- add_special_tokens (bool, optional): Whether to add special tokens (e.g., [CLS], [SEP]). Defaults to True.
- return_attention_mask (bool, optional): Whether to return an attention mask. Defaults to True.
- return_token_type_ids (bool, optional): Whether to return token type IDs (for models like BERT). Defaults to True if the model expects them.
- Returns:
Dictionary containing:
- input_ids: Token IDs tensor of shape (batch_size, sequence_length)
- attention_mask: Attention mask tensor of the same shape (if applicable)
- token_type_ids: Token type IDs tensor (if applicable)
Additional keys may be present depending on the tokenizer.
- Return type:
Dict[str, torch.Tensor]
Examples:

```python
model = PreTrainedCausalLM.from_pretrained("gpt2")

# Single text encoding
tokens = model.encode("Hello world!")
print(tokens["input_ids"].shape)  # torch.Size([1, 3])

# Batch encoding with padding
texts = ["Hello!", "How are you doing today?"]
tokens = model.encode(texts, padding=True)
print(tokens["input_ids"].shape)  # torch.Size([2, 6])

# Encoding with truncation
tokens = model.encode(
    "This is a very long text that might exceed the maximum length",
    truncation=True,
    max_length=10,
)
```
Note: The returned tensors are on the same device as the model, ready for immediate use in forward passes or generation.
- decode(
- token_ids: Union[int, List[int], torch.Tensor],
- **kwargs: Unpack[bridge.models.hf_pretrained.causal_lm.DecodeKwargs],
Decode token IDs back into text using the model’s tokenizer.
This method converts token IDs (from model output or encode method) back into human-readable text.
- Parameters:
token_ids (int, List[int], or torch.Tensor) –
Token IDs to decode. Can be:
Single token ID (int)
List of token IDs
1D tensor of token IDs
2D tensor (will decode the first sequence)
**kwargs – Additional arguments passed to the tokenizer's decode method:
- skip_special_tokens (bool, optional): Whether to remove special tokens (e.g., [PAD], [CLS], [SEP]) from the output. Defaults to True.
- clean_up_tokenization_spaces (bool, optional): Whether to clean up tokenization artifacts (extra spaces, etc.). Defaults to True.
- Returns:
Decoded text string.
- Return type:
str
Examples:

```python
model = PreTrainedCausalLM.from_pretrained("gpt2")

# Encode and decode round-trip
text = "Hello, world!"
tokens = model.encode(text)
decoded = model.decode(tokens["input_ids"][0])
print(decoded)  # "Hello, world!"

# Decode generated tokens
inputs = model.encode("The weather is")
outputs = model.generate(inputs["input_ids"], max_length=10)
decoded = model.decode(outputs[0])
print(decoded)  # "The weather is nice today..."

# Decode without special tokens
token_ids = [101, 7592, 1010, 2088, 999, 102]  # BERT-style tokens
decoded = model.decode(token_ids, skip_special_tokens=True)
print(decoded)  # "Hello, world!"

# Decode keeping special tokens
decoded = model.decode(token_ids, skip_special_tokens=False)
print(decoded)  # "[CLS] Hello, world! [SEP]"
```
Note: If a 2D tensor is provided (batch of sequences), only the first sequence is decoded. For batch decoding, use tokenizer.batch_decode() directly or iterate over the sequences.
- to(device: Union[str, torch.device])#
Move model to specified device.
- half()#
Convert model to half precision (float16).
- float()#
Convert model to full precision (float32).
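A short sketch combining to, half, and float with the dtype property; it assumes a CUDA device is available and uses "gpt2" only as an illustrative checkpoint:

```python
from bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM

lm = PreTrainedCausalLM.from_pretrained("gpt2")

lm.to("cuda")    # move the underlying model to the GPU
lm.half()        # convert weights to float16
print(lm.dtype)  # torch.float16

lm.float()       # back to float32
print(lm.dtype)  # torch.float32
```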
- save_pretrained(save_directory: Union[str, pathlib.Path])#
Save all components (model, tokenizer, config, generation_config) to a directory.
This method saves:
Model weights and config
Tokenizer files
Generation config (if available)
- Parameters:
save_directory – Path to directory where components will be saved
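A minimal save-and-reload sketch; the directory path is illustrative:

```python
from bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM

lm = PreTrainedCausalLM.from_pretrained("gpt2")

# Writes model weights/config, tokenizer files, and generation config (if any)
lm.save_pretrained("./saved_gpt2")

# The saved directory can be loaded again like any local checkpoint
reloaded = PreTrainedCausalLM.from_pretrained("./saved_gpt2")
```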
- property dtype: Optional[torch.dtype]#
Get model’s dtype if loaded.
- property num_parameters: Optional[int]#
Get total number of parameters if model is loaded.
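Both properties are typed Optional and appear to return None until the model has been loaded; a small sketch under that assumption:

```python
from bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM

lm = PreTrainedCausalLM.from_pretrained("gpt2")
print(lm.dtype, lm.num_parameters)  # expected: None None while not yet loaded

_ = lm.model  # trigger loading
print(lm.dtype)           # e.g. torch.float32
print(lm.num_parameters)  # roughly 124M for GPT-2 small
```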
- __repr__() → str#
Return a string representation of the PreTrainedCausalLM instance.
- class bridge.models.hf_pretrained.causal_lm.GenerateKwargs#
Bases:
typing.TypedDict
TypedDict for generate method parameters.
Initialization
Initialize self. See help(type(self)) for accurate signature.
- attention_mask: Optional[torch.Tensor]#
None
- max_length: Optional[int]#
None
- max_new_tokens: Optional[int]#
None
- min_length: Optional[int]#
None
- do_sample: Optional[bool]#
None
- temperature: Optional[float]#
None
- top_k: Optional[int]#
None
- top_p: Optional[float]#
None
- repetition_penalty: Optional[float]#
None
- pad_token_id: Optional[int]#
None
- eos_token_id: Optional[Union[int, List[int]]]#
None
- bos_token_id: Optional[int]#
None
- num_beams: Optional[int]#
None
- num_return_sequences: Optional[int]#
None
- early_stopping: Optional[bool]#
None
- use_cache: Optional[bool]#
None
- return_dict_in_generate: Optional[bool]#
None
- output_scores: Optional[bool]#
None
- output_attentions: Optional[bool]#
None
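A sketch of collecting these keys into a dict and unpacking it into generate. It assumes the TypedDict is declared with total=False, as is typical for Unpack-style kwargs, and the values are purely illustrative:

```python
from bridge.models.hf_pretrained.causal_lm import GenerateKwargs, PreTrainedCausalLM

gen_kwargs: GenerateKwargs = {
    "max_new_tokens": 32,
    "do_sample": True,
    "temperature": 0.8,
    "top_p": 0.9,
}

lm = PreTrainedCausalLM.from_pretrained("gpt2")
batch = lm.encode("Once upon a time")
outputs = lm.generate(batch["input_ids"], **gen_kwargs)
print(lm.decode(outputs[0], skip_special_tokens=True))
```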
- class bridge.models.hf_pretrained.causal_lm.EncodeKwargs#
Bases:
typing.TypedDict
TypedDict for encode method parameters.
Initialization
Initialize self. See help(type(self)) for accurate signature.
- padding: Union[bool, str]#
None
- truncation: Union[bool, str]#
None
- max_length: Optional[int]#
None
- add_special_tokens: bool#
None
- return_attention_mask: bool#
None
- return_token_type_ids: Optional[bool]#
None
- return_tensors: str#
None
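Similarly, a sketch of passing these options to encode (again assuming total=False; the values are illustrative):

```python
from bridge.models.hf_pretrained.causal_lm import EncodeKwargs, PreTrainedCausalLM

enc_kwargs: EncodeKwargs = {
    "padding": True,
    "truncation": True,
    "max_length": 64,
}

lm = PreTrainedCausalLM.from_pretrained("gpt2")
batch = lm.encode(["First prompt", "Second prompt"], **enc_kwargs)
print(batch["input_ids"].shape)
```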