Improve documentation of ModelConfig.try_get_generation_config to prevent future confusion (#21526)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-08-01 17:32:27 +01:00 committed by GitHub
parent 2d7b09b998
commit 326a1b001d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1575,7 +1575,18 @@ class ModelConfig:
return self.multimodal_config
def try_get_generation_config(self) -> dict[str, Any]:
if self.generation_config in ("auto", "vllm"):
"""
This method attempts to retrieve the non-default values of the
generation config for this model.
The generation config can contain information about special tokens, as
well as sampling parameters. Which is why this method exists separately
to `get_diff_sampling_param`.
Returns:
A dictionary containing the non-default generation config.
"""
if self.generation_config in {"auto", "vllm"}:
config = try_get_generation_config(
self.hf_config_path or self.model,
trust_remote_code=self.trust_remote_code,
@ -1594,13 +1605,18 @@ class ModelConfig:
def get_diff_sampling_param(self) -> dict[str, Any]:
"""
This method returns a dictionary containing the parameters
that differ from the default sampling parameters. If
`generation_config` is `"vllm"`, an empty dictionary is returned.
This method returns a dictionary containing the non-default sampling
parameters with `override_generation_config` applied.
The default sampling parameters are:
- vLLM's neutral defaults if `self.generation_config="vllm"`
- the model's defaults if `self.generation_config="auto"`
- as defined in `generation_config.json` if
`self.generation_config="path/to/generation_config/dir"`
Returns:
dict[str, Any]: A dictionary with the differing sampling
parameters, if `generation_config` is `"vllm"` an empty dictionary.
A dictionary containing the non-default sampling parameters.
"""
if self.generation_config == "vllm":
config = {}