[core] add extra_args to SamplingParams (#13300)

Signed-off-by: Aviv Keshet <akeshet@scaledcognition.com>
This commit is contained in:
Aviv Keshet 2025-03-07 22:41:18 -08:00 committed by GitHub
parent 9f3bc0f58c
commit 4aae667668
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -184,6 +184,9 @@ class SamplingParams(
allowed_token_ids: If provided, the engine will construct a logits
processor which only retains scores for the given token ids.
Defaults to None.
extra_args: Arbitrary additional args, that can be used by custom
sampling implementations. Not used by any in-tree sampling
implementations.
"""
n: int = 1
@ -227,6 +230,7 @@ class SamplingParams(
guided_decoding: Optional[GuidedDecodingParams] = None
logit_bias: Optional[dict[int, float]] = None
allowed_token_ids: Optional[list[int]] = None
extra_args: Optional[dict[str, Any]] = None
@staticmethod
def from_optional(
@ -259,6 +263,7 @@ class SamplingParams(
guided_decoding: Optional[GuidedDecodingParams] = None,
logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
allowed_token_ids: Optional[list[int]] = None,
extra_args: Optional[dict[str, Any]] = None,
) -> "SamplingParams":
if logit_bias is not None:
# Convert token_id to integer
@ -300,6 +305,7 @@ class SamplingParams(
guided_decoding=guided_decoding,
logit_bias=logit_bias,
allowed_token_ids=allowed_token_ids,
extra_args=extra_args,
)
def __post_init__(self) -> None:
@ -509,7 +515,8 @@ class SamplingParams(
"spaces_between_special_tokens="
f"{self.spaces_between_special_tokens}, "
f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
f"guided_decoding={self.guided_decoding})")
f"guided_decoding={self.guided_decoding}, "
f"extra_args={self.extra_args})")
class BeamSearchParams(