From 44554a00685b7e60eeed5883986749405e538806 Mon Sep 17 00:00:00 2001
From: Wang Yijun
Date: Tue, 22 Jul 2025 23:24:00 +0800
Subject: [PATCH] Add tokenization_kwargs to encode for embedding model
 truncation (#21033)

---
 vllm/engine/async_llm_engine.py |  6 ++++++
 vllm/entrypoints/llm.py         | 15 ++++++++++++---
 vllm/v1/engine/async_llm.py     |  2 ++
 3 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 3d7d28055dd00..06ae2a2f18f27 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -438,6 +438,7 @@ class _AsyncLLMEngine(LLMEngine):
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
         data_parallel_rank: Optional[int] = None,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> None:
         """
         Async version of
@@ -468,6 +469,7 @@ class _AsyncLLMEngine(LLMEngine):
             prompt,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
+            tokenization_kwargs=tokenization_kwargs,
         )
 
         if isinstance(params, SamplingParams) and \
@@ -862,6 +864,7 @@ class AsyncLLMEngine(EngineClient):
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
         data_parallel_rank: Optional[int] = None,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
         if not self.is_running:
             if self.start_engine_loop:
@@ -889,6 +892,7 @@ class AsyncLLMEngine(EngineClient):
             prompt_adapter_request=prompt_adapter_request,
             priority=priority,
             data_parallel_rank=data_parallel_rank,
+            tokenization_kwargs=tokenization_kwargs,
         )
 
         return stream.generator()
@@ -996,6 +1000,7 @@ class AsyncLLMEngine(EngineClient):
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """Generate outputs for a request from a pooling model.
 
@@ -1070,6 +1075,7 @@ class AsyncLLMEngine(EngineClient):
                 lora_request=lora_request,
                 trace_headers=trace_headers,
                 priority=priority,
+                tokenization_kwargs=tokenization_kwargs,
             ):
                 yield LLMEngine.validate_output(output, PoolingRequestOutput)
         except asyncio.CancelledError:
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 78f9d32d811d3..c4f1b3b86619a 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -965,6 +965,7 @@ class LLM:
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         pooling_task: PoolingTask = "encode",
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[PoolingRequestOutput]:
         ...
 
@@ -981,6 +982,7 @@ class LLM:
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         pooling_task: PoolingTask = "encode",
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[PoolingRequestOutput]:
         ...
 
@@ -997,6 +999,7 @@ class LLM:
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         pooling_task: PoolingTask = "encode",
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[PoolingRequestOutput]:
         ...
 
@@ -1014,6 +1017,7 @@ class LLM:
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         pooling_task: PoolingTask = "encode",
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[PoolingRequestOutput]:
         ...
 
@@ -1031,6 +1035,7 @@ class LLM:
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         pooling_task: PoolingTask = "encode",
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[PoolingRequestOutput]:
         ...
 
@@ -1046,6 +1051,7 @@ class LLM:
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         pooling_task: PoolingTask = "encode",
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[PoolingRequestOutput]:
         ...
 
@@ -1066,6 +1072,7 @@ class LLM:
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         pooling_task: PoolingTask = "encode",
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[PoolingRequestOutput]:
         """Apply pooling to the hidden states corresponding to the input
         prompts.
@@ -1131,9 +1138,11 @@ class LLM:
         for pooling_param in pooling_params:
             pooling_param.verify(pooling_task, model_config)
 
-        tokenization_kwargs = dict[str, Any]()
-        _validate_truncation_size(model_config.max_model_len,
-                                  truncate_prompt_tokens, tokenization_kwargs)
+        if tokenization_kwargs is None:
+            tokenization_kwargs = dict[str, Any]()
+            _validate_truncation_size(model_config.max_model_len,
+                                      truncate_prompt_tokens,
+                                      tokenization_kwargs)
 
         self._validate_and_add_requests(
             prompts=parsed_prompts,
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index b8ba36f3502f7..79b5d5ae4a23e 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -437,6 +437,7 @@ class AsyncLLM(EngineClient):
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """
         Main function called by the API server to kick off a request
@@ -465,6 +466,7 @@ class AsyncLLM(EngineClient):
             lora_request=lora_request,
             trace_headers=trace_headers,
             priority=priority,
+            tokenization_kwargs=tokenization_kwargs,
         )
 
         # The output_handler task pushes items into the queue.
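
Usage sketch (not part of the patch itself): with this change, callers of
LLM.encode can control tokenizer-side truncation directly. The model name
and token limit below are illustrative assumptions; truncation/max_length
are the standard Hugging Face tokenizer kwargs that vLLM forwards during
input preprocessing.

    from vllm import LLM

    # Assumed embedding model; any pooling/embedding checkpoint works.
    llm = LLM(model="BAAI/bge-m3", task="embed")

    # Truncate each prompt to at most 128 tokens rather than erroring
    # out when an input exceeds the model's maximum length.
    outputs = llm.encode(
        ["a very long document ..."],
        tokenization_kwargs={"truncation": True, "max_length": 128},
    )
    for out in outputs:
        print(len(out.outputs.data))  # embedding dimensionality

When tokenization_kwargs is omitted, the previous behavior is preserved:
the kwargs dict is built from model_config.max_model_len and
truncate_prompt_tokens via _validate_truncation_size, as shown in the
@@ -1131,9 +1138,11 @@ hunk above.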