Support tokenization_kwargs override (#29794)
Signed-off-by: piood <2477084691@qq.com>
Commit 43e7593031 (parent c46b932df2)
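In short, the commit lets callers override how prompts are tokenized for pooling models by passing a tokenization_kwargs dict through the public entry points down to the tokenizer. A minimal usage sketch, assuming this patch is applied; the checkpoint name and prompts are placeholders, while the runner setting and kwargs mirror the values used in the updated SigLIP test:

    from vllm import LLM

    llm = LLM(model="google/siglip2-base-patch16-224", runner="pooling")

    outputs = llm.embed(
        ["a photo of a cat", "a photo of a dog"],
        tokenization_kwargs={
            "padding": "max_length",  # siglip2 was trained with this padding setting
            "max_length": 64,
        },
    )
    for out in outputs:
        print(len(out.outputs.embedding))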
@@ -405,6 +405,7 @@ class HfRunner:
         images: PromptImageInput | None = None,
         videos: PromptVideoInput | None = None,
         audios: PromptAudioInput | None = None,
+        tokenization_kwargs: dict[str, Any] | None = None,
     ) -> list[BatchFeature | BatchEncoding | dict[str, torch.Tensor]]:
         if images is not None:
             assert len(prompts) == len(images)

@@ -418,10 +419,18 @@ class HfRunner:
         all_inputs: list[BatchFeature | BatchEncoding | dict[str, torch.Tensor]] = []
         for i, prompt in enumerate(prompts):
             if isinstance(prompt, str):
-                processor_kwargs: dict[str, Any] = {
-                    "text": prompt,
-                    "return_tensors": "pt",
-                }
+                # Create a copy to avoid modifying the original dict
+                processor_kwargs = (
+                    tokenization_kwargs.copy()
+                    if tokenization_kwargs is not None
+                    else {}
+                )
+                processor_kwargs.update(
+                    {
+                        "text": prompt,
+                        "return_tensors": "pt",
+                    }
+                )
                 if images is not None and (image := images[i]) is not None:
                     processor_kwargs["images"] = image
                 if videos is not None and (video := videos[i]) is not None:
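The hunk above copies the caller's tokenization_kwargs before layering the per-prompt fields on top, so repeated calls to get_inputs never mutate the dict the test passed in. A standalone sketch of the same pattern, with illustrative names rather than the helper's real ones:

    from typing import Any

    def build_processor_kwargs(
        prompt: str, tokenization_kwargs: dict[str, Any] | None
    ) -> dict[str, Any]:
        # Copy first so the caller's dict survives across loop iterations.
        kwargs = tokenization_kwargs.copy() if tokenization_kwargs is not None else {}
        kwargs.update({"text": prompt, "return_tensors": "pt"})
        return kwargs

    shared = {"padding": "max_length", "max_length": 64}
    a = build_processor_kwargs("a photo of a cat", shared)
    b = build_processor_kwargs("a photo of a dog", shared)
    assert shared == {"padding": "max_length", "max_length": 64}  # caller's dict unchanged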
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from typing import Any
+
 import pytest
 from transformers import SiglipModel
 

@@ -35,7 +37,11 @@ def _run_test(
     model: str,
     *,
     dtype: str,
+    tokenization_kwargs: dict[str, Any] | None = None,
 ) -> None:
+    if tokenization_kwargs is None:
+        tokenization_kwargs = {}
+
     with vllm_runner(
         model,
         runner="pooling",

@@ -44,10 +50,14 @@ def _run_test(
         max_model_len=64,
         gpu_memory_utilization=0.7,
     ) as vllm_model:
-        vllm_outputs = vllm_model.embed(input_texts, images=input_images)
+        vllm_outputs = vllm_model.embed(
+            input_texts, images=input_images, tokenization_kwargs=tokenization_kwargs
+        )
 
     with hf_runner(model, dtype=dtype, auto_cls=SiglipModel) as hf_model:
-        all_inputs = hf_model.get_inputs(input_texts, images=input_images)
+        all_inputs = hf_model.get_inputs(
+            input_texts, images=input_images, tokenization_kwargs=tokenization_kwargs
+        )
 
         all_outputs = []
         for inputs in all_inputs:

@@ -94,6 +104,10 @@ def test_models_text(
         input_images,  # type: ignore
         model,
         dtype=dtype,
+        tokenization_kwargs={
+            "padding": "max_length",
+            "max_length": 64,
+        },  # siglip2 was trained with this padding setting.
     )
 
 
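The test pins these kwargs because, as the inline comment notes, siglip2 was trained with fixed-length padding, so the vLLM outputs only line up with the Hugging Face reference when both sides tokenize the same way. A rough sketch of what those two kwargs mean at the Hugging Face tokenizer level; the checkpoint name is a placeholder, not taken from the diff:

    from transformers import AutoTokenizer

    # padding="max_length" pads every prompt to exactly max_length tokens,
    # rather than to the longest prompt in the batch (padding=True).
    tok = AutoTokenizer.from_pretrained("google/siglip2-base-patch16-224")
    enc = tok(
        ["a photo of a cat", "a photo of a dog"],
        padding="max_length",
        max_length=64,
        return_tensors="pt",
    )
    print(enc["input_ids"].shape)  # torch.Size([2, 64]) regardless of prompt length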
@@ -1076,6 +1076,7 @@ class LLM:
             params=pooling_params,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
+            tokenization_kwargs=tokenization_kwargs,
         )
 
         outputs = self._run_engine(use_tqdm=use_tqdm)

@@ -1113,6 +1114,7 @@ class LLM:
         use_tqdm: bool | Callable[..., tqdm] = True,
         pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
         lora_request: list[LoRARequest] | LoRARequest | None = None,
+        tokenization_kwargs: dict[str, Any] | None = None,
     ) -> list[EmbeddingRequestOutput]:
         """
         Generate an embedding vector for each prompt.

@@ -1150,6 +1152,7 @@ class LLM:
             pooling_params=pooling_params,
             lora_request=lora_request,
             pooling_task="embed",
+            tokenization_kwargs=tokenization_kwargs,
         )
 
         return [EmbeddingRequestOutput.from_base(item) for item in items]

@@ -1161,6 +1164,7 @@ class LLM:
         use_tqdm: bool | Callable[..., tqdm] = True,
         pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
         lora_request: list[LoRARequest] | LoRARequest | None = None,
+        tokenization_kwargs: dict[str, Any] | None = None,
     ) -> list[ClassificationRequestOutput]:
         """
         Generate class logits for each prompt.

@@ -1196,6 +1200,7 @@ class LLM:
             pooling_params=pooling_params,
             lora_request=lora_request,
             pooling_task="classify",
+            tokenization_kwargs=tokenization_kwargs,
         )
 
         return [ClassificationRequestOutput.from_base(item) for item in items]

@@ -1209,6 +1214,7 @@ class LLM:
         use_tqdm: bool | Callable[..., tqdm] = True,
         pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
         lora_request: list[LoRARequest] | LoRARequest | None = None,
+        tokenization_kwargs: dict[str, Any] | None = None,
     ) -> list[PoolingRequestOutput]:
         """
         Generate rewards for each prompt.

@@ -1236,6 +1242,7 @@ class LLM:
             pooling_params=pooling_params,
             truncate_prompt_tokens=truncate_prompt_tokens,
             pooling_task="token_classify",
+            tokenization_kwargs=tokenization_kwargs,
         )
 
     def _embedding_score(

@@ -1247,6 +1254,7 @@ class LLM:
         use_tqdm: bool | Callable[..., tqdm] = True,
         pooling_params: PoolingParams | None = None,
         lora_request: list[LoRARequest] | LoRARequest | None = None,
+        tokenization_kwargs: dict[str, Any] | None = None,
     ) -> list[ScoringRequestOutput]:
         encoded_output: list[PoolingRequestOutput] = self.encode(
             text_1 + text_2,

@@ -1255,6 +1263,7 @@ class LLM:
             lora_request=lora_request,
             pooling_params=pooling_params,
             pooling_task="embed",
+            tokenization_kwargs=tokenization_kwargs,
         )
 
         encoded_output_1: list[PoolingRequestOutput] = encoded_output[0 : len(text_1)]

@@ -1279,6 +1288,7 @@ class LLM:
         use_tqdm: bool | Callable[..., tqdm] = True,
         pooling_params: PoolingParams | None = None,
         lora_request: list[LoRARequest] | LoRARequest | None = None,
+        tokenization_kwargs: dict[str, Any] | None = None,
     ) -> list[ScoringRequestOutput]:
         model_config = self.model_config
 

@@ -1294,7 +1304,8 @@ class LLM:
         pooling_params.verify("score", model_config)
         pooling_params_list = list[PoolingParams]()
 
-        tokenization_kwargs: dict[str, Any] = {}
+        local_kwargs = tokenization_kwargs or {}
+        tokenization_kwargs = local_kwargs.copy()
 
         _validate_truncation_size(
             model_config.max_model_len, truncate_prompt_tokens, tokenization_kwargs
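Before this change the scoring path always started from an empty dict, so any caller-supplied tokenization settings were silently dropped; now the caller's kwargs seed the dict, and working on a copy keeps the caller's original untouched when truncation settings are merged in by _validate_truncation_size. A small sketch of that merge order; the helper name is made up, and the exact keys written by _validate_truncation_size are an assumption here:

    from typing import Any

    def merge_truncation_kwargs(
        caller_kwargs: dict[str, Any] | None, truncate_prompt_tokens: int | None
    ) -> dict[str, Any]:
        # Seed from the caller's kwargs, but work on a copy so the original survives.
        merged = (caller_kwargs or {}).copy()
        if truncate_prompt_tokens is not None:
            # Assumption: truncation settings are layered on top of the caller's dict,
            # mirroring what the diff hands to _validate_truncation_size.
            merged["truncation"] = True
            merged["max_length"] = truncate_prompt_tokens
        return merged

    user_kwargs = {"padding": "max_length", "max_length": 64}
    merged = merge_truncation_kwargs(user_kwargs, truncate_prompt_tokens=32)
    assert user_kwargs["max_length"] == 64  # the caller's dict is unchanged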
@@ -1557,6 +1568,7 @@ class LLM:
         use_tqdm: bool | Callable[..., tqdm] = True,
         lora_request: Sequence[LoRARequest] | LoRARequest | None,
         priority: list[int] | None = None,
+        tokenization_kwargs: dict[str, Any] | None = None,
     ) -> None:
         if isinstance(prompts, (str, dict)):
             # Convert a single prompt to a list.

@@ -1602,6 +1614,7 @@ class LLM:
                     if isinstance(lora_request, Sequence)
                     else lora_request,
                     priority=priority[i] if priority else 0,
+                    tokenization_kwargs=tokenization_kwargs,
                 )
                 added_request_ids.append(request_id)
         except Exception as e:

@@ -1665,9 +1678,12 @@ class LLM:
         *,
         lora_request: LoRARequest | None,
         priority: int,
+        tokenization_kwargs: dict[str, Any] | None = None,
     ) -> tuple[EngineCoreRequest, dict[str, Any]]:
         """Use the Processor to process inputs for LLMEngine."""
-        tokenization_kwargs: dict[str, Any] = {}
+
+        local_kwargs = tokenization_kwargs or {}
+        tokenization_kwargs = local_kwargs.copy()
         _validate_truncation_size(
             self.model_config.max_model_len,
             params.truncate_prompt_tokens,

@@ -1690,6 +1706,7 @@ class LLM:
         params: SamplingParams | PoolingParams,
         lora_request: LoRARequest | None = None,
         priority: int = 0,
+        tokenization_kwargs: dict[str, Any] | None = None,
     ) -> str:
         prompt_text, _, _ = get_prompt_components(prompt)
         request_id = str(next(self.request_counter))

@@ -1700,6 +1717,7 @@ class LLM:
             params,
             lora_request=lora_request,
             priority=priority,
+            tokenization_kwargs=tokenization_kwargs,
         )
 
         self.llm_engine.add_request(
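Taken together, the final hunks thread the override through the request-building internals: _validate_and_add_requests forwards it to _add_request, which passes it to _process_inputs, where it now seeds the kwargs handed to _validate_truncation_size instead of being ignored. From the caller's side, the other pooling entry points shown in the diff (classify and the reward-style path) accept the same keyword; a hedged sketch, reusing the llm instance from the first example and assuming a checkpoint that supports the task, with placeholder prompts:

    results = llm.classify(
        ["this movie was great"],
        tokenization_kwargs={"padding": "max_length", "max_length": 64},
    )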