[Performance] Cache loaded custom logitsprocs to avoid overheads (#28462)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py 2025-11-12 08:49:29 +08:00 committed by GitHub
parent 48c879369f
commit 3f770f4427
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -5,7 +5,7 @@ import inspect
import itertools
from abc import abstractmethod
from collections.abc import Sequence
from functools import partial
from functools import lru_cache, partial
from typing import TYPE_CHECKING
import torch
@ -216,11 +216,17 @@ def build_logitsprocs(
)
cached_load_custom_logitsprocs = lru_cache(_load_custom_logitsprocs)
def validate_logits_processors_parameters(
logits_processors: Sequence[str | type[LogitsProcessor]] | None,
sampling_params: SamplingParams,
):
for logits_procs in _load_custom_logitsprocs(logits_processors):
logits_processors = (
tuple(logits_processors) if logits_processors is not None else None
)
for logits_procs in cached_load_custom_logitsprocs(logits_processors):
logits_procs.validate_params(sampling_params)