diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 122e2ed86cb64..fc585ee9e54b9 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -244,6 +244,7 @@ class LLM:
             engine_args, usage_context=UsageContext.LLM_CLASS)
 
         self.request_counter = Counter()
+        self.default_sampling_params: Union[dict[str, Any], None] = None
 
     @staticmethod
     def get_engine_class() -> type[LLMEngine]:
@@ -268,10 +269,11 @@ class LLM:
         tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
 
     def get_default_sampling_params(self) -> SamplingParams:
-        diff_sampling_param = (
-            self.llm_engine.model_config.get_diff_sampling_param())
-        if diff_sampling_param:
-            return SamplingParams.from_optional(**diff_sampling_param)
+        if self.default_sampling_params is None:
+            self.default_sampling_params = (
+                self.llm_engine.model_config.get_diff_sampling_param())
+        if self.default_sampling_params:
+            return SamplingParams.from_optional(**self.default_sampling_params)
         return SamplingParams()
 
     @overload
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 98e9ea0fc61a2..f4aaee3607803 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -105,10 +105,11 @@ class OpenAIServingChat(OpenAIServing):
                                 "been registered") from e
 
         self.enable_prompt_tokens_details = enable_prompt_tokens_details
-        diff_sampling_param = self.model_config.get_diff_sampling_param()
-        if diff_sampling_param:
+        self.default_sampling_params = (
+            self.model_config.get_diff_sampling_param())
+        if self.default_sampling_params:
             logger.info("Overwriting default chat sampling param with: %s",
-                        diff_sampling_param)
+                        self.default_sampling_params)
 
     async def create_chat_completion(
         self,
@@ -210,17 +211,14 @@ class OpenAIServingChat(OpenAIServing):
                 sampling_params: Union[SamplingParams, BeamSearchParams]
                 default_max_tokens = self.max_model_len - len(
                     engine_prompt["prompt_token_ids"])
-                # Build default sampling params
-                default_sampling_params = (
-                    self.model_config.get_diff_sampling_param())
                 if request.use_beam_search:
                     sampling_params = request.to_beam_search_params(
-                        default_max_tokens, default_sampling_params)
+                        default_max_tokens, self.default_sampling_params)
                 else:
                     sampling_params = request.to_sampling_params(
                         default_max_tokens,
                         self.model_config.logits_processor_pattern,
-                        default_sampling_params)
+                        self.default_sampling_params)
 
                 self._log_inputs(request_id,
                                  request_prompts[i],
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index ed09af84f64ba..b2ad28c0a33cd 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -51,11 +51,12 @@ class OpenAIServingCompletion(OpenAIServing):
                          models=models,
                          request_logger=request_logger,
                          return_tokens_as_token_ids=return_tokens_as_token_ids)
-        diff_sampling_param = self.model_config.get_diff_sampling_param()
-        if diff_sampling_param:
+        self.default_sampling_params = (
+            self.model_config.get_diff_sampling_param())
+        if self.default_sampling_params:
             logger.info(
                 "Overwriting default completion sampling param with: %s",
-                diff_sampling_param)
+                self.default_sampling_params)
 
     async def create_completion(
         self,
@@ -119,17 +120,14 @@ class OpenAIServingCompletion(OpenAIServing):
                 sampling_params: Union[SamplingParams, BeamSearchParams]
                 default_max_tokens = self.max_model_len - len(
                     engine_prompt["prompt_token_ids"])
-                # Build default sampling params
-                default_sampling_params = (
-                    self.model_config.get_diff_sampling_param())
                 if request.use_beam_search:
                     sampling_params = request.to_beam_search_params(
-                        default_max_tokens, default_sampling_params)
+                        default_max_tokens, self.default_sampling_params)
                 else:
                     sampling_params = request.to_sampling_params(
                         default_max_tokens,
                         self.model_config.logits_processor_pattern,
-                        default_sampling_params)
+                        self.default_sampling_params)
 
                 request_id_item = f"{request_id}-{i}"
 
diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py
index 77f016a5e0a4a..402a0bb7a6b0d 100644
--- a/vllm/entrypoints/openai/serving_transcription.py
+++ b/vllm/entrypoints/openai/serving_transcription.py
@@ -161,11 +161,12 @@ class OpenAIServingTranscription(OpenAIServing):
                          request_logger=request_logger,
                          return_tokens_as_token_ids=return_tokens_as_token_ids)
 
-        diff_sampling_param = self.model_config.get_diff_sampling_param()
-        if diff_sampling_param:
+        self.default_sampling_params = (
+            self.model_config.get_diff_sampling_param())
+        if self.default_sampling_params:
             logger.info(
                 "Overwriting default completion sampling param with: %s",
-                diff_sampling_param)
+                self.default_sampling_params)
 
     async def _preprocess_transcription(
         self,
@@ -273,9 +274,8 @@ class OpenAIServingTranscription(OpenAIServing):
         try:
             # TODO(rob): subtract len of tokenized prompt.
             default_max_tokens = self.model_config.max_model_len
-            default_params = self.model_config.get_diff_sampling_param()
             sampling_params = request.to_sampling_params(
-                default_max_tokens, default_params)
+                default_max_tokens, self.default_sampling_params)
 
             self._log_inputs(
                 request_id,