Support tokenization_kwargs override (#29794)

Signed-off-by: piood <2477084691@qq.com>
Yu Jiaqi authored on 2025-12-06 17:12:53 +08:00, committed by GitHub
parent c46b932df2
commit 43e7593031
GPG Key ID: B5690EEEBB952194
3 changed files with 49 additions and 8 deletions


@@ -405,6 +405,7 @@ class HfRunner:
         images: PromptImageInput | None = None,
         videos: PromptVideoInput | None = None,
         audios: PromptAudioInput | None = None,
+        tokenization_kwargs: dict[str, Any] | None = None,
     ) -> list[BatchFeature | BatchEncoding | dict[str, torch.Tensor]]:
         if images is not None:
             assert len(prompts) == len(images)
@@ -418,10 +419,18 @@ class HfRunner:
         all_inputs: list[BatchFeature | BatchEncoding | dict[str, torch.Tensor]] = []
         for i, prompt in enumerate(prompts):
             if isinstance(prompt, str):
-                processor_kwargs: dict[str, Any] = {
-                    "text": prompt,
-                    "return_tensors": "pt",
-                }
+                # Create a copy to avoid modifying the original dict
+                processor_kwargs = (
+                    tokenization_kwargs.copy()
+                    if tokenization_kwargs is not None
+                    else {}
+                )
+                processor_kwargs.update(
+                    {
+                        "text": prompt,
+                        "return_tensors": "pt",
+                    }
+                )
                 if images is not None and (image := images[i]) is not None:
                     processor_kwargs["images"] = image
                 if videos is not None and (video := videos[i]) is not None:
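The merge order in `get_inputs` matters: the caller's kwargs are copied first, then the per-call `text` and `return_tensors` keys are overlaid, so one shared dict can be reused across prompts without being mutated. A minimal standalone sketch of that pattern (the helper name is illustrative, not part of HfRunner):

from typing import Any

def build_processor_kwargs(
    prompt: str, tokenization_kwargs: dict[str, Any] | None = None
) -> dict[str, Any]:
    # Copy the user-supplied kwargs first, then overlay the per-call keys.
    kwargs = tokenization_kwargs.copy() if tokenization_kwargs is not None else {}
    kwargs.update({"text": prompt, "return_tensors": "pt"})
    return kwargs

user_kwargs = {"padding": "max_length", "max_length": 64}
merged = build_processor_kwargs("a photo of a cat", user_kwargs)
assert merged["padding"] == "max_length"  # user option preserved
assert "text" not in user_kwargs          # caller's dict left untouched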


@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any
+
 import pytest
 from transformers import SiglipModel
@@ -35,7 +37,11 @@ def _run_test(
     model: str,
     *,
     dtype: str,
+    tokenization_kwargs: dict[str, Any] | None = None,
 ) -> None:
+    if tokenization_kwargs is None:
+        tokenization_kwargs = {}
+
     with vllm_runner(
         model,
         runner="pooling",
@@ -44,10 +50,14 @@ def _run_test(
         max_model_len=64,
         gpu_memory_utilization=0.7,
     ) as vllm_model:
-        vllm_outputs = vllm_model.embed(input_texts, images=input_images)
+        vllm_outputs = vllm_model.embed(
+            input_texts, images=input_images, tokenization_kwargs=tokenization_kwargs
+        )

     with hf_runner(model, dtype=dtype, auto_cls=SiglipModel) as hf_model:
-        all_inputs = hf_model.get_inputs(input_texts, images=input_images)
+        all_inputs = hf_model.get_inputs(
+            input_texts, images=input_images, tokenization_kwargs=tokenization_kwargs
+        )

         all_outputs = []
         for inputs in all_inputs:
@@ -94,6 +104,10 @@ def test_models_text(
         input_images,  # type: ignore
         model,
         dtype=dtype,
+        tokenization_kwargs={
+            "padding": "max_length",
+            "max_length": 64,
+        },  # siglip2 was trained with this padding setting.
     )
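The override exists because, as the inline comment notes, siglip2 was trained with `padding="max_length"` and `max_length=64`. Roughly what those kwargs mean on the Hugging Face side (model id and prompt are illustrative; this sketch only makes the padding effect concrete and is not part of the test):

from transformers import AutoTokenizer

# Illustrative sketch: pad a single prompt to a fixed 64-token sequence,
# mirroring the tokenization_kwargs passed in the test above.
tokenizer = AutoTokenizer.from_pretrained("google/siglip2-base-patch16-224")
batch = tokenizer(
    "a photo of a cat",
    padding="max_length",
    max_length=64,
    return_tensors="pt",
)
print(batch["input_ids"].shape)  # expected: torch.Size([1, 64])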


@@ -1076,6 +1076,7 @@ class LLM:
             params=pooling_params,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
+            tokenization_kwargs=tokenization_kwargs,
         )

         outputs = self._run_engine(use_tqdm=use_tqdm)
@@ -1113,6 +1114,7 @@ class LLM:
         use_tqdm: bool | Callable[..., tqdm] = True,
         pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
         lora_request: list[LoRARequest] | LoRARequest | None = None,
+        tokenization_kwargs: dict[str, Any] | None = None,
     ) -> list[EmbeddingRequestOutput]:
         """
         Generate an embedding vector for each prompt.
@@ -1150,6 +1152,7 @@ class LLM:
             pooling_params=pooling_params,
             lora_request=lora_request,
             pooling_task="embed",
+            tokenization_kwargs=tokenization_kwargs,
         )

         return [EmbeddingRequestOutput.from_base(item) for item in items]
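With `tokenization_kwargs` threaded through `embed()`, offline users can pass tokenizer options directly. A hedged usage sketch (model name, prompt, and engine arguments are illustrative, chosen to mirror the siglip test above):

from vllm import LLM

# Illustrative sketch: forward tokenizer options through LLM.embed().
llm = LLM(
    model="google/siglip2-base-patch16-224",
    runner="pooling",
    max_model_len=64,
)
outputs = llm.embed(
    ["a photo of a cat"],
    tokenization_kwargs={"padding": "max_length", "max_length": 64},
)
print(len(outputs[0].outputs.embedding))  # embedding dimension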
@@ -1161,6 +1164,7 @@ class LLM:
         use_tqdm: bool | Callable[..., tqdm] = True,
         pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
         lora_request: list[LoRARequest] | LoRARequest | None = None,
+        tokenization_kwargs: dict[str, Any] | None = None,
     ) -> list[ClassificationRequestOutput]:
         """
         Generate class logits for each prompt.
@@ -1196,6 +1200,7 @@ class LLM:
             pooling_params=pooling_params,
             lora_request=lora_request,
             pooling_task="classify",
+            tokenization_kwargs=tokenization_kwargs,
         )

         return [ClassificationRequestOutput.from_base(item) for item in items]
@@ -1209,6 +1214,7 @@ class LLM:
         use_tqdm: bool | Callable[..., tqdm] = True,
         pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
         lora_request: list[LoRARequest] | LoRARequest | None = None,
+        tokenization_kwargs: dict[str, Any] | None = None,
     ) -> list[PoolingRequestOutput]:
         """
         Generate rewards for each prompt.
@@ -1236,6 +1242,7 @@ class LLM:
             pooling_params=pooling_params,
             truncate_prompt_tokens=truncate_prompt_tokens,
             pooling_task="token_classify",
+            tokenization_kwargs=tokenization_kwargs,
         )

     def _embedding_score(
@@ -1247,6 +1254,7 @@ class LLM:
         use_tqdm: bool | Callable[..., tqdm] = True,
         pooling_params: PoolingParams | None = None,
         lora_request: list[LoRARequest] | LoRARequest | None = None,
+        tokenization_kwargs: dict[str, Any] | None = None,
     ) -> list[ScoringRequestOutput]:
         encoded_output: list[PoolingRequestOutput] = self.encode(
             text_1 + text_2,
@@ -1255,6 +1263,7 @@ class LLM:
             lora_request=lora_request,
             pooling_params=pooling_params,
             pooling_task="embed",
+            tokenization_kwargs=tokenization_kwargs,
         )

         encoded_output_1: list[PoolingRequestOutput] = encoded_output[0 : len(text_1)]
@@ -1279,6 +1288,7 @@ class LLM:
         use_tqdm: bool | Callable[..., tqdm] = True,
         pooling_params: PoolingParams | None = None,
         lora_request: list[LoRARequest] | LoRARequest | None = None,
+        tokenization_kwargs: dict[str, Any] | None = None,
     ) -> list[ScoringRequestOutput]:
         model_config = self.model_config
@@ -1294,7 +1304,8 @@ class LLM:
         pooling_params.verify("score", model_config)
         pooling_params_list = list[PoolingParams]()

-        tokenization_kwargs: dict[str, Any] = {}
+        local_kwargs = tokenization_kwargs or {}
+        tokenization_kwargs = local_kwargs.copy()
         _validate_truncation_size(
             model_config.max_model_len, truncate_prompt_tokens, tokenization_kwargs
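The `or {}` plus `.copy()` idiom keeps the caller's dict untouched when `_validate_truncation_size` later injects truncation settings into `tokenization_kwargs`. A standalone sketch of the pattern (the helper and the injected keys are illustrative, not vLLM internals):

from typing import Any

def merge_tokenization_kwargs(
    user_kwargs: dict[str, Any] | None, truncate_prompt_tokens: int | None
) -> dict[str, Any]:
    # Copy so that engine-side keys never leak back into the caller's dict.
    merged = (user_kwargs or {}).copy()
    if truncate_prompt_tokens is not None:
        merged["truncation"] = True
        merged["max_length"] = truncate_prompt_tokens
    return merged

user_opts = {"padding": "max_length"}
merged = merge_tokenization_kwargs(user_opts, truncate_prompt_tokens=32)
assert user_opts == {"padding": "max_length"}  # caller's dict unchanged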
@@ -1557,6 +1568,7 @@ class LLM:
         use_tqdm: bool | Callable[..., tqdm] = True,
         lora_request: Sequence[LoRARequest] | LoRARequest | None,
         priority: list[int] | None = None,
+        tokenization_kwargs: dict[str, Any] | None = None,
     ) -> None:
         if isinstance(prompts, (str, dict)):
             # Convert a single prompt to a list.
@@ -1602,6 +1614,7 @@ class LLM:
                     if isinstance(lora_request, Sequence)
                     else lora_request,
                     priority=priority[i] if priority else 0,
+                    tokenization_kwargs=tokenization_kwargs,
                 )
                 added_request_ids.append(request_id)
         except Exception as e:
@@ -1665,9 +1678,12 @@ class LLM:
         *,
         lora_request: LoRARequest | None,
         priority: int,
+        tokenization_kwargs: dict[str, Any] | None = None,
     ) -> tuple[EngineCoreRequest, dict[str, Any]]:
         """Use the Processor to process inputs for LLMEngine."""
-        tokenization_kwargs: dict[str, Any] = {}
+        local_kwargs = tokenization_kwargs or {}
+        tokenization_kwargs = local_kwargs.copy()
         _validate_truncation_size(
             self.model_config.max_model_len,
             params.truncate_prompt_tokens,
@@ -1690,6 +1706,7 @@ class LLM:
         params: SamplingParams | PoolingParams,
         lora_request: LoRARequest | None = None,
         priority: int = 0,
+        tokenization_kwargs: dict[str, Any] | None = None,
     ) -> str:
         prompt_text, _, _ = get_prompt_components(prompt)
         request_id = str(next(self.request_counter))
@@ -1700,6 +1717,7 @@ class LLM:
             params,
             lora_request=lora_request,
             priority=priority,
+            tokenization_kwargs=tokenization_kwargs,
         )

         self.llm_engine.add_request(