Fix performance when --generation-config is not None (#14223)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-03-04 20:59:22 +01:00 (committed by GitHub)
Commit: 9badee53de
Parent: beebf4742a
4 changed files with 23 additions and 25 deletions

@@ -244,6 +244,7 @@ class LLM:
             engine_args, usage_context=UsageContext.LLM_CLASS)

         self.request_counter = Counter()
+        self.default_sampling_params: Union[dict[str, Any], None] = None

     @staticmethod
     def get_engine_class() -> type[LLMEngine]:
@@ -268,10 +269,11 @@ class LLM:
         tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)

     def get_default_sampling_params(self) -> SamplingParams:
-        diff_sampling_param = (
-            self.llm_engine.model_config.get_diff_sampling_param())
-        if diff_sampling_param:
-            return SamplingParams.from_optional(**diff_sampling_param)
+        if self.default_sampling_params is None:
+            self.default_sampling_params = (
+                self.llm_engine.model_config.get_diff_sampling_param())
+        if self.default_sampling_params:
+            return SamplingParams.from_optional(**self.default_sampling_params)
         return SamplingParams()

     @overload
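For reference, the LLM entrypoint change above swaps a per-call config lookup for a lazily filled cache. Below is a minimal standalone sketch of that lazy-caching pattern; the class names, the fake config values, and the stand-in model-config object are all illustrative, not vLLM's actual objects.

from typing import Any, Union


class _FakeModelConfig:
    """Illustrative stand-in for model_config; pretend this lookup is expensive."""

    def get_diff_sampling_param(self) -> dict[str, Any]:
        # In vLLM this would be derived from the model's generation config;
        # the values here are made up for the example.
        return {"temperature": 0.7, "top_p": 0.9}


class CachedDefaults:
    """Minimal sketch of the lazy-caching pattern used in the hunk above."""

    def __init__(self, model_config: _FakeModelConfig) -> None:
        self.model_config = model_config
        # Filled on first use, then reused by every later call.
        self.default_sampling_params: Union[dict[str, Any], None] = None

    def get_default_sampling_params(self) -> dict[str, Any]:
        if self.default_sampling_params is None:
            self.default_sampling_params = (
                self.model_config.get_diff_sampling_param())
        return self.default_sampling_params


defaults = CachedDefaults(_FakeModelConfig())
print(defaults.get_default_sampling_params())  # computed on the first call
print(defaults.get_default_sampling_params())  # served from the cache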

@@ -105,10 +105,11 @@ class OpenAIServingChat(OpenAIServing):
                                "been registered") from e

         self.enable_prompt_tokens_details = enable_prompt_tokens_details
-        diff_sampling_param = self.model_config.get_diff_sampling_param()
-        if diff_sampling_param:
+        self.default_sampling_params = (
+            self.model_config.get_diff_sampling_param())
+        if self.default_sampling_params:
             logger.info("Overwriting default chat sampling param with: %s",
-                        diff_sampling_param)
+                        self.default_sampling_params)

     async def create_chat_completion(
         self,
@@ -210,17 +211,14 @@ class OpenAIServingChat(OpenAIServing):
                 sampling_params: Union[SamplingParams, BeamSearchParams]
                 default_max_tokens = self.max_model_len - len(
                     engine_prompt["prompt_token_ids"])
-                # Build default sampling params
-                default_sampling_params = (
-                    self.model_config.get_diff_sampling_param())
                 if request.use_beam_search:
                     sampling_params = request.to_beam_search_params(
-                        default_max_tokens, default_sampling_params)
+                        default_max_tokens, self.default_sampling_params)
                 else:
                     sampling_params = request.to_sampling_params(
                         default_max_tokens,
                         self.model_config.logits_processor_pattern,
-                        default_sampling_params)
+                        self.default_sampling_params)

                 self._log_inputs(request_id,
                                  request_prompts[i],
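The OpenAI serving classes apply the same idea eagerly: the generation-config diff is read once in the constructor and every request reuses the cached dict instead of calling get_diff_sampling_param() again. A hedged sketch of that pattern, with invented names rather than vLLM's real serving API:

import logging
from typing import Any

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def load_generation_config_diff() -> dict[str, Any]:
    """Stand-in for model_config.get_diff_sampling_param(); imagine file I/O here."""
    return {"temperature": 0.2, "top_k": 40}


class ServingHandler:
    """Illustrative handler, not vLLM's real serving class."""

    def __init__(self) -> None:
        # Before the change: this lookup ran inside every request.
        # After the change: it runs once at startup and the result is reused.
        self.default_sampling_params = load_generation_config_diff()
        if self.default_sampling_params:
            logger.info("Overwriting default sampling params with: %s",
                        self.default_sampling_params)

    def handle_request(self, overrides: dict[str, Any]) -> dict[str, Any]:
        # Per-request values win; the cached defaults fill in the rest.
        return {**self.default_sampling_params, **overrides}


handler = ServingHandler()
print(handler.handle_request({"max_tokens": 16}))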

@@ -51,11 +51,12 @@ class OpenAIServingCompletion(OpenAIServing):
                          models=models,
                          request_logger=request_logger,
                          return_tokens_as_token_ids=return_tokens_as_token_ids)
-        diff_sampling_param = self.model_config.get_diff_sampling_param()
-        if diff_sampling_param:
+        self.default_sampling_params = (
+            self.model_config.get_diff_sampling_param())
+        if self.default_sampling_params:
             logger.info(
                 "Overwriting default completion sampling param with: %s",
-                diff_sampling_param)
+                self.default_sampling_params)

     async def create_completion(
         self,
@@ -119,17 +120,14 @@ class OpenAIServingCompletion(OpenAIServing):
                 sampling_params: Union[SamplingParams, BeamSearchParams]
                 default_max_tokens = self.max_model_len - len(
                     engine_prompt["prompt_token_ids"])
-                # Build default sampling params
-                default_sampling_params = (
-                    self.model_config.get_diff_sampling_param())
                 if request.use_beam_search:
                     sampling_params = request.to_beam_search_params(
-                        default_max_tokens, default_sampling_params)
+                        default_max_tokens, self.default_sampling_params)
                 else:
                     sampling_params = request.to_sampling_params(
                         default_max_tokens,
                         self.model_config.logits_processor_pattern,
-                        default_sampling_params)
+                        self.default_sampling_params)

                 request_id_item = f"{request_id}-{i}"

@@ -161,11 +161,12 @@ class OpenAIServingTranscription(OpenAIServing):
                          request_logger=request_logger,
                          return_tokens_as_token_ids=return_tokens_as_token_ids)

-        diff_sampling_param = self.model_config.get_diff_sampling_param()
-        if diff_sampling_param:
+        self.default_sampling_params = (
+            self.model_config.get_diff_sampling_param())
+        if self.default_sampling_params:
             logger.info(
                 "Overwriting default completion sampling param with: %s",
-                diff_sampling_param)
+                self.default_sampling_params)

     async def _preprocess_transcription(
         self,
@@ -273,9 +274,8 @@
         try:
             # TODO(rob): subtract len of tokenized prompt.
             default_max_tokens = self.model_config.max_model_len
-            default_params = self.model_config.get_diff_sampling_param()
             sampling_params = request.to_sampling_params(
-                default_max_tokens, default_params)
+                default_max_tokens, self.default_sampling_params)

             self._log_inputs(
                 request_id,
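As a usage note, the public behaviour should be unchanged. Assuming generation_config="auto" is accepted by the LLM constructor (the offline counterpart of the --generation-config server flag named in the commit title) and using an arbitrary example model, the cost of deriving the model's generation-config defaults is now paid once rather than on every call:

from vllm import LLM

# generation_config="auto" is assumed here; the model name is just an example
# of a checkpoint that ships a generation_config.json on the Hub.
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", generation_config="auto")

# The first call reads and caches the model's generation-config defaults;
# subsequent calls reuse the cached value instead of re-deriving it.
params = llm.get_default_sampling_params()
print(params)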