mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-19 01:57:01 +08:00
[Frontend] Consolidate tokenizer init code (#26276)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
77c95f72f7
commit
391612e78b
@ -7,7 +7,6 @@ from vllm.config import ModelConfig
|
|||||||
from vllm.inputs import zip_enc_dec_prompts
|
from vllm.inputs import zip_enc_dec_prompts
|
||||||
from vllm.inputs.parse import parse_raw_prompts
|
from vllm.inputs.parse import parse_raw_prompts
|
||||||
from vllm.inputs.preprocess import InputPreprocessor
|
from vllm.inputs.preprocess import InputPreprocessor
|
||||||
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.cpu_test
|
pytestmark = pytest.mark.cpu_test
|
||||||
|
|
||||||
@ -107,8 +106,7 @@ def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
|
|||||||
)
|
)
|
||||||
def test_preprocessor_text_no_mm_inputs(model_id, prompt):
|
def test_preprocessor_text_no_mm_inputs(model_id, prompt):
|
||||||
model_config = ModelConfig(model=model_id)
|
model_config = ModelConfig(model=model_id)
|
||||||
tokenizer = init_tokenizer_from_configs(model_config)
|
input_preprocessor = InputPreprocessor(model_config)
|
||||||
input_preprocessor = InputPreprocessor(model_config, tokenizer)
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="does not support multimodal inputs"):
|
with pytest.raises(ValueError, match="does not support multimodal inputs"):
|
||||||
input_preprocessor.preprocess(prompt)
|
input_preprocessor.preprocess(prompt)
|
||||||
@ -129,8 +127,8 @@ def test_preprocessor_text_no_mm_inputs(model_id, prompt):
|
|||||||
)
|
)
|
||||||
def test_preprocessor_always_mm_code_path(model_id, prompt):
|
def test_preprocessor_always_mm_code_path(model_id, prompt):
|
||||||
model_config = ModelConfig(model=model_id)
|
model_config = ModelConfig(model=model_id)
|
||||||
tokenizer = init_tokenizer_from_configs(model_config)
|
input_preprocessor = InputPreprocessor(model_config)
|
||||||
input_preprocessor = InputPreprocessor(model_config, tokenizer)
|
tokenizer = input_preprocessor.tokenizer
|
||||||
|
|
||||||
# HF processor adds sep token
|
# HF processor adds sep token
|
||||||
sep_token_id = tokenizer.vocab[tokenizer.sep_token]
|
sep_token_id = tokenizer.vocab[tokenizer.sep_token]
|
||||||
|
|||||||
@ -65,9 +65,7 @@ def _mk_processor(
|
|||||||
device_config=DeviceConfig(device="cpu"),
|
device_config=DeviceConfig(device="cpu"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Pass tokenizer=None; InputPreprocessor handles None when
|
return Processor(vllm_config)
|
||||||
# skip_tokenizer_init is True.
|
|
||||||
return Processor(vllm_config, tokenizer=None) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
|
|
||||||
def test_multi_modal_uuids_length_mismatch_raises(monkeypatch):
|
def test_multi_modal_uuids_length_mismatch_raises(monkeypatch):
|
||||||
|
|||||||
@ -74,7 +74,6 @@ from vllm.transformers_utils.tokenizer import (
|
|||||||
AnyTokenizer,
|
AnyTokenizer,
|
||||||
MistralTokenizer,
|
MistralTokenizer,
|
||||||
get_cached_tokenizer,
|
get_cached_tokenizer,
|
||||||
init_tokenizer_from_configs,
|
|
||||||
)
|
)
|
||||||
from vllm.usage.usage_lib import UsageContext
|
from vllm.usage.usage_lib import UsageContext
|
||||||
from vllm.utils import Counter, Device, as_iter, is_list_of
|
from vllm.utils import Counter, Device, as_iter, is_list_of
|
||||||
@ -367,11 +366,8 @@ class LLM:
|
|||||||
def _get_processor(self) -> Processor:
|
def _get_processor(self) -> Processor:
|
||||||
if not hasattr(self, "_processor"):
|
if not hasattr(self, "_processor"):
|
||||||
vllm_config = self.llm_engine.vllm_config
|
vllm_config = self.llm_engine.vllm_config
|
||||||
if self.model_config.skip_tokenizer_init:
|
self._processor = Processor(vllm_config)
|
||||||
tokenizer = None
|
|
||||||
else:
|
|
||||||
tokenizer = init_tokenizer_from_configs(self.model_config)
|
|
||||||
self._processor = Processor(vllm_config, tokenizer)
|
|
||||||
return self._processor
|
return self._processor
|
||||||
|
|
||||||
def get_default_sampling_params(self) -> SamplingParams:
|
def get_default_sampling_params(self) -> SamplingParams:
|
||||||
|
|||||||
@ -16,7 +16,6 @@ from starlette.datastructures import Headers
|
|||||||
from typing_extensions import TypeIs
|
from typing_extensions import TypeIs
|
||||||
|
|
||||||
from vllm.entrypoints.utils import _validate_truncation_size
|
from vllm.entrypoints.utils import _validate_truncation_size
|
||||||
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
|
|
||||||
from vllm.v1.engine import EngineCoreRequest
|
from vllm.v1.engine import EngineCoreRequest
|
||||||
from vllm.v1.engine.processor import Processor
|
from vllm.v1.engine.processor import Processor
|
||||||
|
|
||||||
@ -272,11 +271,8 @@ class OpenAIServing:
|
|||||||
async def _get_processor(self) -> Processor:
|
async def _get_processor(self) -> Processor:
|
||||||
if not hasattr(self, "_processor"):
|
if not hasattr(self, "_processor"):
|
||||||
vllm_config = await self.engine_client.get_vllm_config()
|
vllm_config = await self.engine_client.get_vllm_config()
|
||||||
if self.model_config.skip_tokenizer_init:
|
self._processor = Processor(vllm_config)
|
||||||
tokenizer = None
|
|
||||||
else:
|
|
||||||
tokenizer = init_tokenizer_from_configs(self.model_config)
|
|
||||||
self._processor = Processor(vllm_config, tokenizer)
|
|
||||||
return self._processor
|
return self._processor
|
||||||
|
|
||||||
def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer:
|
def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer:
|
||||||
|
|||||||
@ -17,7 +17,8 @@ from vllm.multimodal.inputs import (
|
|||||||
MultiModalUUIDDict,
|
MultiModalUUIDDict,
|
||||||
)
|
)
|
||||||
from vllm.multimodal.processing import BaseMultiModalProcessor
|
from vllm.multimodal.processing import BaseMultiModalProcessor
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
|
||||||
|
from vllm.utils.jsontree import json_iter_leaves
|
||||||
|
|
||||||
from .data import (
|
from .data import (
|
||||||
DecoderOnlyInputs,
|
DecoderOnlyInputs,
|
||||||
@ -44,17 +45,20 @@ class InputPreprocessor:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
tokenizer: Optional[AnyTokenizer],
|
|
||||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||||
mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None,
|
mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.mm_registry = mm_registry
|
self.mm_registry = mm_registry
|
||||||
self.mm_processor_cache = mm_processor_cache
|
self.mm_processor_cache = mm_processor_cache
|
||||||
|
|
||||||
|
if model_config.skip_tokenizer_init:
|
||||||
|
self.tokenizer = None
|
||||||
|
else:
|
||||||
|
self.tokenizer = init_tokenizer_from_configs(model_config)
|
||||||
|
|
||||||
def get_tokenizer(self) -> AnyTokenizer:
|
def get_tokenizer(self) -> AnyTokenizer:
|
||||||
if self.tokenizer is None:
|
if self.tokenizer is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -273,7 +277,10 @@ class InputPreprocessor:
|
|||||||
mm_hashes = mm_input["mm_hashes"]
|
mm_hashes = mm_input["mm_hashes"]
|
||||||
|
|
||||||
# Validate that all mm items have a string as their hash
|
# Validate that all mm items have a string as their hash
|
||||||
if not contains_only_strings(mm_hashes):
|
contains_only_strings = all(
|
||||||
|
isinstance(leaf, str) for leaf in json_iter_leaves(mm_hashes)
|
||||||
|
)
|
||||||
|
if not contains_only_strings:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"mm_hashes must contain only strings, got: {mm_hashes}. "
|
f"mm_hashes must contain only strings, got: {mm_hashes}. "
|
||||||
"This is likely due to an incorrect custom implementation of "
|
"This is likely due to an incorrect custom implementation of "
|
||||||
@ -693,15 +700,3 @@ class InputPreprocessor:
|
|||||||
def clear_cache(self) -> None:
|
def clear_cache(self) -> None:
|
||||||
if self.mm_processor_cache is not None:
|
if self.mm_processor_cache is not None:
|
||||||
self.mm_processor_cache.clear_cache()
|
self.mm_processor_cache.clear_cache()
|
||||||
|
|
||||||
|
|
||||||
# Helper function to validate that a nested dictionary contains
|
|
||||||
# only strings or list of strings as the leaf values.
|
|
||||||
def contains_only_strings(obj: object):
|
|
||||||
if isinstance(obj, str):
|
|
||||||
return True
|
|
||||||
if isinstance(obj, list):
|
|
||||||
return all(isinstance(x, str) for x in obj)
|
|
||||||
if isinstance(obj, dict):
|
|
||||||
return all(contains_only_strings(v) for v in obj.values())
|
|
||||||
return False
|
|
||||||
|
|||||||
@ -28,7 +28,7 @@ from vllm.sampling_params import SamplingParams
|
|||||||
from vllm.tasks import SupportedTask
|
from vllm.tasks import SupportedTask
|
||||||
from vllm.tracing import init_tracer
|
from vllm.tracing import init_tracer
|
||||||
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
|
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
from vllm.usage.usage_lib import UsageContext
|
from vllm.usage.usage_lib import UsageContext
|
||||||
from vllm.utils import Device, as_list, cancel_task_threadsafe, cdiv, deprecate_kwargs
|
from vllm.utils import Device, as_list, cancel_task_threadsafe, cdiv, deprecate_kwargs
|
||||||
from vllm.v1.engine import EngineCoreRequest
|
from vllm.v1.engine import EngineCoreRequest
|
||||||
@ -104,20 +104,8 @@ class AsyncLLM(EngineClient):
|
|||||||
"logger list; enabling logging without default stat loggers"
|
"logger list; enabling logging without default stat loggers"
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.model_config.skip_tokenizer_init:
|
|
||||||
self.tokenizer = None
|
|
||||||
else:
|
|
||||||
# Tokenizer (+ ensure liveness if running in another process).
|
|
||||||
self.tokenizer = init_tokenizer_from_configs(
|
|
||||||
model_config=vllm_config.model_config
|
|
||||||
)
|
|
||||||
|
|
||||||
# Processor (converts Inputs --> EngineCoreRequests).
|
# Processor (converts Inputs --> EngineCoreRequests).
|
||||||
self.processor = Processor(
|
self.processor = Processor(vllm_config, mm_registry=mm_registry)
|
||||||
vllm_config=vllm_config,
|
|
||||||
tokenizer=self.tokenizer,
|
|
||||||
mm_registry=mm_registry,
|
|
||||||
)
|
|
||||||
|
|
||||||
# OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
|
# OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
|
||||||
self.output_processor = OutputProcessor(
|
self.output_processor = OutputProcessor(
|
||||||
@ -257,6 +245,10 @@ class AsyncLLM(EngineClient):
|
|||||||
|
|
||||||
cancel_task_threadsafe(getattr(self, "output_handler", None))
|
cancel_task_threadsafe(getattr(self, "output_handler", None))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tokenizer(self) -> Optional[AnyTokenizer]:
|
||||||
|
return self.processor.tokenizer
|
||||||
|
|
||||||
async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
||||||
return await self.engine_core.get_supported_tasks_async()
|
return await self.engine_core.get_supported_tasks_async()
|
||||||
|
|
||||||
|
|||||||
@ -23,7 +23,7 @@ from vllm.pooling_params import PoolingParams
|
|||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.tasks import SupportedTask
|
from vllm.tasks import SupportedTask
|
||||||
from vllm.tracing import init_tracer
|
from vllm.tracing import init_tracer
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
from vllm.usage.usage_lib import UsageContext
|
from vllm.usage.usage_lib import UsageContext
|
||||||
from vllm.utils import Device
|
from vllm.utils import Device
|
||||||
from vllm.v1.engine import EngineCoreRequest
|
from vllm.v1.engine import EngineCoreRequest
|
||||||
@ -95,18 +95,8 @@ class LLMEngine:
|
|||||||
self.dp_group = None
|
self.dp_group = None
|
||||||
self.should_execute_dummy_batch = False
|
self.should_execute_dummy_batch = False
|
||||||
|
|
||||||
if self.model_config.skip_tokenizer_init:
|
|
||||||
self.tokenizer = None
|
|
||||||
else:
|
|
||||||
# Tokenizer (+ ensure liveness if running in another process).
|
|
||||||
self.tokenizer = init_tokenizer_from_configs(
|
|
||||||
model_config=vllm_config.model_config
|
|
||||||
)
|
|
||||||
|
|
||||||
# Processor (convert Inputs --> EngineCoreRequests)
|
# Processor (convert Inputs --> EngineCoreRequests)
|
||||||
self.processor = Processor(
|
self.processor = Processor(vllm_config, mm_registry=mm_registry)
|
||||||
vllm_config=vllm_config, tokenizer=self.tokenizer, mm_registry=mm_registry
|
|
||||||
)
|
|
||||||
|
|
||||||
# OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
|
# OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
|
||||||
self.output_processor = OutputProcessor(
|
self.output_processor = OutputProcessor(
|
||||||
@ -214,6 +204,14 @@ class LLMEngine:
|
|||||||
def validate_outputs(cls, outputs, output_type):
|
def validate_outputs(cls, outputs, output_type):
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tokenizer(self) -> Optional[AnyTokenizer]:
|
||||||
|
return self.processor.tokenizer
|
||||||
|
|
||||||
|
@tokenizer.setter
|
||||||
|
def tokenizer(self, tokenizer: Optional[AnyTokenizer]) -> None:
|
||||||
|
self.processor.tokenizer = tokenizer
|
||||||
|
|
||||||
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
||||||
return self.engine_core.get_supported_tasks()
|
return self.engine_core.get_supported_tasks()
|
||||||
|
|
||||||
|
|||||||
@ -37,15 +37,13 @@ class Processor:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
tokenizer: AnyTokenizer,
|
|
||||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||||
):
|
) -> None:
|
||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
self.model_config = vllm_config.model_config
|
self.model_config = vllm_config.model_config
|
||||||
self.cache_config = vllm_config.cache_config
|
self.cache_config = vllm_config.cache_config
|
||||||
self.lora_config = vllm_config.lora_config
|
self.lora_config = vllm_config.lora_config
|
||||||
self.structured_outputs_config = vllm_config.structured_outputs_config
|
self.structured_outputs_config = vllm_config.structured_outputs_config
|
||||||
self.tokenizer = tokenizer
|
|
||||||
|
|
||||||
self.generation_config_fields = self.model_config.try_get_generation_config()
|
self.generation_config_fields = self.model_config.try_get_generation_config()
|
||||||
|
|
||||||
@ -54,11 +52,18 @@ class Processor:
|
|||||||
|
|
||||||
self.input_preprocessor = InputPreprocessor(
|
self.input_preprocessor = InputPreprocessor(
|
||||||
self.model_config,
|
self.model_config,
|
||||||
self.tokenizer,
|
|
||||||
mm_registry,
|
mm_registry,
|
||||||
mm_processor_cache=self.mm_processor_cache,
|
mm_processor_cache=self.mm_processor_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tokenizer(self) -> Optional[AnyTokenizer]:
|
||||||
|
return self.input_preprocessor.tokenizer
|
||||||
|
|
||||||
|
@tokenizer.setter
|
||||||
|
def tokenizer(self, tokenizer: Optional[AnyTokenizer]) -> None:
|
||||||
|
self.input_preprocessor.tokenizer = tokenizer
|
||||||
|
|
||||||
def _validate_logprobs(
|
def _validate_logprobs(
|
||||||
self,
|
self,
|
||||||
params: SamplingParams,
|
params: SamplingParams,
|
||||||
@ -511,10 +516,8 @@ class Processor:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"The {prompt_type} prompt cannot be empty")
|
raise ValueError(f"The {prompt_type} prompt cannot be empty")
|
||||||
|
|
||||||
if self.model_config.skip_tokenizer_init:
|
tokenizer = self.tokenizer
|
||||||
tokenizer = None
|
if tokenizer is not None:
|
||||||
else:
|
|
||||||
tokenizer = self.tokenizer
|
|
||||||
max_input_id = max(prompt_ids or [], default=0)
|
max_input_id = max(prompt_ids or [], default=0)
|
||||||
|
|
||||||
# NOTE: tokenizer.max_token_id is the tokenizer’s vocab size while
|
# NOTE: tokenizer.max_token_id is the tokenizer’s vocab size while
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user