core.inference.sampling_params#

Module Contents#

Classes#

SamplingParams

Inference parameters sent along with the prompts. This class contains request-level attributes that control the sampling techniques used when generating text. This is distinct from megatron.core.inference.contexts.BaseInferenceContext, which sets model-level inference attributes such as the maximum sequence length and contains the KV cache.

API#

class core.inference.sampling_params.SamplingParams#

Inference parameters sent along with the prompts. This class contains request-level attributes that control the sampling techniques used when generating text. This is distinct from megatron.core.inference.contexts.BaseInferenceContext, which sets model-level inference attributes such as the maximum sequence length and contains the KV cache.

For an explanation of these parameters, refer to this blog post: https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and-temperature-parameters-ed6a31313910
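As a rough illustration of what temperature, top_k and top_p control, the sketch below filters a toy list of logits in plain Python. It is a generic textbook-style version, not Megatron's sampling code: the function name `filter_probs` is hypothetical, and treating `top_k=0` and `top_p=0.0` as "filter disabled" is an assumption suggested by the defaults listed here.

```python
import math


def filter_probs(logits, temperature=1.0, top_k=0, top_p=0.0):
    """Return {token_id: probability} after temperature, top-k and top-p filtering.

    Hypothetical helper for illustration only; not part of Megatron Core.
    """
    # Temperature rescales the logits before the softmax:
    # values below 1.0 sharpen the distribution, values above 1.0 flatten it.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]  # subtract the max for stability
    total = sum(exps)
    probs = [e / total for e in exps]

    # Rank token ids from most to least probable.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    # Top-k: keep only the k most probable tokens (assumed disabled when 0).
    if top_k > 0:
        order = order[:top_k]

    # Top-p: keep the smallest prefix whose cumulative probability reaches p
    # (assumed disabled when 0.0).
    if top_p > 0.0:
        kept, cumulative = [], 0.0
        for i in order:
            kept.append(i)
            cumulative += probs[i]
            if cumulative >= top_p:
                break
        order = kept

    # Renormalize over the surviving tokens.
    mass = sum(probs[i] for i in order)
    return {i: probs[i] / mass for i in order}


toy_logits = [2.0, 1.0, 0.0, -1.0]
unfiltered = filter_probs(toy_logits)
top_two = filter_probs(toy_logits, top_k=2)
nucleus = filter_probs(toy_logits, top_p=0.5)
```

A sampler would then draw the next token from the returned distribution; with `top_k=1` this reduces to greedy decoding.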

temperature: float#

1.0

top_k: int#

0

top_p: float#

0.0

return_log_probs: bool#

False

skip_prompt_log_probs: bool#

False

return_segments: bool#

False

num_tokens_to_generate: int#

30

num_tokens_total: Optional[int]#

None

termination_id: Optional[int]#

None

top_n_logprobs: int#

0

return_prompt_top_n_logprobs: bool#

False

add_BOS: bool#

False

__post_init__()#

Ensure backward compatibility for return_prompt_top_n_logprobs.

Sets return_prompt_top_n_logprobs based on skip_prompt_log_probs and top_n_logprobs:

  • return_prompt_top_n_logprobs = not skip_prompt_log_probs and top_n_logprobs > 0
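The rule above can be sketched with a small stand-in dataclass. `PromptLogprobFlags` is a hypothetical class that models only the three fields involved; it is not the Megatron class and only mirrors the documented expression.

```python
from dataclasses import dataclass


@dataclass
class PromptLogprobFlags:  # hypothetical stand-in, not SamplingParams itself
    skip_prompt_log_probs: bool = False
    top_n_logprobs: int = 0
    return_prompt_top_n_logprobs: bool = False

    def __post_init__(self):
        # Documented rule: prompt top-n log probs are returned only when
        # prompt log probs are not skipped and top-n log probs are requested.
        self.return_prompt_top_n_logprobs = (
            not self.skip_prompt_log_probs and self.top_n_logprobs > 0
        )


defaults = PromptLogprobFlags()
with_top_n = PromptLogprobFlags(top_n_logprobs=5)
skipping_prompt = PromptLogprobFlags(skip_prompt_log_probs=True, top_n_logprobs=5)
```

Under this sketch, only `with_top_n` ends up with `return_prompt_top_n_logprobs` set to `True`.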

_sync_prompt_logprobs_fields()#

Synchronize return_prompt_top_n_logprobs with skip_prompt_log_probs.

add_attributes(attribute_value_pair: dict)#

Utility to add more attributes to the sampling params.

Use this method to pass in a custom dictionary that adds more sampling parameter attributes:

  c = SamplingParams()
  c.add_attributes({'min_length': 4, 'eod_id': 153})

Parameters:
  • attribute_value_pair (dict) – A dictionary containing attributes as the key names and their values as the values.
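One plausible way such a method can work is to copy each dictionary entry onto the instance with `setattr`. The sketch below assumes that behavior and uses a hypothetical stand-in class, `ExtensibleParams`, so it runs without Megatron installed; the real implementation may differ.

```python
class ExtensibleParams:  # hypothetical stand-in, not SamplingParams itself
    def __init__(self):
        self.temperature = 1.0

    def add_attributes(self, attribute_value_pair: dict):
        # Assumption: each key becomes an attribute name on the instance
        # and each value becomes that attribute's value.
        for key, value in attribute_value_pair.items():
            setattr(self, key, value)


params = ExtensibleParams()
params.add_attributes({"min_length": 4, "eod_id": 153})
```

After the call, `params.min_length` and `params.eod_id` are readable like any other attribute, and existing attributes such as `temperature` are left untouched.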

serialize() → dict#

Return a dictionary that is msgpack-serializable.

classmethod deserialize(data: dict) → core.inference.sampling_params.SamplingParams#

Construct SamplingParams from a msgpack-compatible dictionary.
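A serialize/deserialize round trip of this kind can be sketched with a stand-in dataclass. `MiniSamplingParams` is hypothetical and carries only a few of the fields listed above. Because msgpack is a third-party package, the sketch uses the standard-library `json` module as a proxy for checking that the dictionary holds only plain, wire-friendly types; how the real class builds its dictionary is not shown in this page.

```python
import json
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class MiniSamplingParams:  # hypothetical stand-in, not SamplingParams itself
    temperature: float = 1.0
    top_k: int = 0
    top_p: float = 0.0
    num_tokens_to_generate: int = 30
    termination_id: Optional[int] = None

    def serialize(self) -> dict:
        # A flat dict of floats, ints and None can be packed by msgpack or JSON.
        return asdict(self)

    @classmethod
    def deserialize(cls, data: dict) -> "MiniSamplingParams":
        # Rebuild the object from the same key/value pairs.
        return cls(**data)


original = MiniSamplingParams(temperature=0.7, top_k=40)
payload = original.serialize()

# Simulate sending the payload over the wire and reading it back.
restored = MiniSamplingParams.deserialize(json.loads(json.dumps(payload)))
```

Keeping the payload a plain dictionary means the receiving side needs nothing more than the class definition to reconstruct an equal object.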