core.inference.sampling_params
Module Contents
Classes
SamplingParams: Inference parameters sent along with the prompts. This class contains request-level attributes that control the sampling techniques used when generating text. This is distinct from megatron.core.inference.contexts.BaseInferenceContext, which sets model-level inference attributes, such as the maximum sequence length, and contains the KV cache.
API
- class core.inference.sampling_params.SamplingParams
Inference parameters sent along with the prompts. This class contains request-level attributes that control the sampling techniques used when generating text. This is distinct from megatron.core.inference.contexts.BaseInferenceContext, which sets model-level inference attributes, such as the maximum sequence length, and contains the KV cache.
For an explanation of these parameters, refer to this blog post: https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and-temperature-parameters-ed6a31313910
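The blog post linked above covers how temperature, top-k, and top-p reshape the next-token distribution. The following is a minimal, framework-free sketch of that idea on a plain list of logits. It is not Megatron's sampling code; the function name `sample_token` is hypothetical, and it assumes that `top_k == 0` and `top_p == 0.0` disable their filters, mirroring the defaults listed on this page.

```python
import math
import random
from typing import List, Optional


def sample_token(
    logits: List[float],
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 0.0,
    rng: Optional[random.Random] = None,
) -> int:
    """Pick one token index from raw logits (hypothetical helper).

    Assumption: top_k == 0 and top_p == 0.0 mean "filter disabled".
    """
    if temperature <= 0:
        raise ValueError("this sketch requires temperature > 0")
    if rng is None:
        rng = random.Random()

    # Temperature scaling: < 1.0 sharpens, > 1.0 flattens the distribution.
    scaled = [x / temperature for x in logits]

    # Softmax, subtracting the max logit for numerical stability.
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Token indices ordered from most to least probable.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    # Top-k: keep only the k most probable tokens.
    if top_k > 0:
        order = order[:top_k]

    # Top-p (nucleus): keep the smallest prefix whose mass reaches top_p.
    if top_p > 0.0:
        kept, mass = [], 0.0
        for i in order:
            kept.append(i)
            mass += probs[i]
            if mass >= top_p:
                break
        order = kept

    # Draw one surviving token in proportion to its probability;
    # random.choices() normalizes the weights itself.
    weights = [probs[i] for i in order]
    return rng.choices(order, weights=weights, k=1)[0]
```

With `top_k=1` only the most probable token survives, so the call becomes deterministic; raising `temperature` above 1.0 shifts probability toward lower-ranked tokens before either filter runs.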
- temperature: float = 1.0
- top_k: int = 0
- top_p: float = 0.0
- return_log_probs: bool = False
- skip_prompt_log_probs: bool = False
- return_segments: bool = False
- num_tokens_to_generate: int = 30
- num_tokens_total: Optional[int] = None
- termination_id: Optional[int] = None
- top_n_logprobs: int = 0
- return_prompt_top_n_logprobs: bool = False
- add_BOS: bool = False
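The attribute list above reads like dataclass fields with defaults, and `__post_init__` is the dataclass initialization hook, so a dataclass is a reasonable guess at the class's shape. The hypothetical `SamplingParamsSketch` below mirrors the documented names, types, and defaults; it is an illustration under that assumption, not an import of the real class.

```python
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class SamplingParamsSketch:
    """Hypothetical stand-in mirroring the documented fields and defaults."""

    temperature: float = 1.0
    top_k: int = 0
    top_p: float = 0.0
    return_log_probs: bool = False
    skip_prompt_log_probs: bool = False
    return_segments: bool = False
    num_tokens_to_generate: int = 30
    num_tokens_total: Optional[int] = None
    termination_id: Optional[int] = None
    top_n_logprobs: int = 0
    return_prompt_top_n_logprobs: bool = False
    add_BOS: bool = False


# A request overrides only what it needs; every other field keeps its default.
# top_k=1 keeps just the single most probable token at each step.
greedy = SamplingParamsSketch(top_k=1, num_tokens_to_generate=64, return_log_probs=True)

# Declared field names, in definition order.
field_names = [f.name for f in fields(SamplingParamsSketch)]
```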
- __post_init__()
Ensure backward compatibility for return_prompt_top_n_logprobs.
Sets return_prompt_top_n_logprobs based on skip_prompt_log_probs and top_n_logprobs:
return_prompt_top_n_logprobs = not skip_prompt_log_probs and top_n_logprobs > 0
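A minimal sketch of the backward-compatibility rule stated above, assuming the derived value simply overwrites the field after construction. It models only the documented formula; how the real class treats a caller who sets `return_prompt_top_n_logprobs` explicitly is not described on this page. The class name `PromptLogprobFlags` is hypothetical and keeps only the three fields the rule uses.

```python
from dataclasses import dataclass


@dataclass
class PromptLogprobFlags:
    """Hypothetical mini-class holding only the fields the rule touches."""

    skip_prompt_log_probs: bool = False
    top_n_logprobs: int = 0
    return_prompt_top_n_logprobs: bool = False

    def __post_init__(self) -> None:
        # Derive the legacy flag from the two fields that now control it,
        # following the documented formula.
        self.return_prompt_top_n_logprobs = (
            not self.skip_prompt_log_probs and self.top_n_logprobs > 0
        )


# Print a few combinations: (skip_prompt_log_probs, top_n_logprobs, derived flag).
for skip, n in [(False, 0), (False, 5), (True, 5)]:
    flags = PromptLogprobFlags(skip_prompt_log_probs=skip, top_n_logprobs=n)
    print(skip, n, flags.return_prompt_top_n_logprobs)
```

Prompt top-n logprobs are returned only when prompt logprobs are not skipped and at least one top-n logprob is requested.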
- _sync_prompt_logprobs_fields()
Synchronize return_prompt_top_n_logprobs with skip_prompt_log_probs.
- add_attributes(attribute_value_pair: dict)
Utility to add more attributes to sampling params.
Use this method to pass in a custom dictionary that adds more sampling parameter attributes. For example:
c = SamplingParams()
c.add_attributes({'min_length': 4, 'eod_id': 153})
- Parameters:
attribute_value_pair (dict) – A dictionary mapping attribute names (keys) to their values.
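One plausible implementation of `add_attributes`, assuming it calls `setattr` once per dictionary entry. `ParamsWithExtras` is a hypothetical stand-in that keeps a single documented field for brevity; the usage repeats the `min_length`/`eod_id` example from the description.

```python
from dataclasses import dataclass


@dataclass
class ParamsWithExtras:
    """Hypothetical stand-in; only one documented field is kept for brevity."""

    temperature: float = 1.0

    def add_attributes(self, attribute_value_pair: dict) -> None:
        # Attach each key as an instance attribute: new names are created,
        # existing names are overwritten.
        for key, value in attribute_value_pair.items():
            setattr(self, key, value)


c = ParamsWithExtras()
c.add_attributes({"min_length": 4, "eod_id": 153})
print(c.min_length, c.eod_id)  # 4 153
```

In a dataclass, attributes added this way are plain instance attributes: they are readable like any other attribute but do not show up in `dataclasses.fields()` or in the generated `repr`.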
- serialize() -> dict
Return a dictionary that is msgpack-serializable.
- classmethod deserialize(data: dict)
Construct SamplingParams from a msgpack-compatible dictionary.
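A round-trip sketch of `serialize()` and `deserialize()`, assuming serialization copies the instance's attribute dictionary and deserialization starts from defaults and then restores each stored key. `SerializableParams` is hypothetical, and the standard-library `json` module stands in for msgpack so the example runs without extra packages; a flat dict of ints, floats, bools, and `None` can be encoded by both.

```python
import json
from dataclasses import dataclass
from typing import Optional


@dataclass
class SerializableParams:
    """Hypothetical stand-in with a subset of the documented fields."""

    temperature: float = 1.0
    top_k: int = 0
    termination_id: Optional[int] = None

    def serialize(self) -> dict:
        # Copy the instance dict so attributes added at runtime travel too.
        return dict(vars(self))

    @classmethod
    def deserialize(cls, data: dict) -> "SerializableParams":
        # Start from the defaults, then restore every stored key.
        obj = cls()
        for key, value in data.items():
            setattr(obj, key, value)
        return obj


original = SerializableParams(temperature=0.7, top_k=50)
wire = json.dumps(original.serialize())  # json stands in for msgpack here
restored = SerializableParams.deserialize(json.loads(wire))
print(restored == original)  # True
```

Dataclass equality compares only declared fields, so the final check confirms that those survived the trip; any runtime-added attributes are carried by `vars(self)` but are not part of that comparison.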