Simplify TokenizerGroup (#16790)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-04-24 12:43:56 +01:00 committed by GitHub
parent 14288d1332
commit 0a05ed57e6
24 changed files with 80 additions and 752 deletions

View File

@ -23,7 +23,7 @@ from tests.models.utils import (TokensTextLogprobs,
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import TaskOption, TokenizerPoolConfig, _get_and_verify_dtype
from vllm.config import TaskOption, _get_and_verify_dtype
from vllm.connections import global_http_connection
from vllm.distributed import (cleanup_dist_env_and_memory,
init_distributed_environment,
@ -1010,20 +1010,6 @@ def vllm_runner():
return VllmRunner
def get_tokenizer_pool_config(tokenizer_group_type):
if tokenizer_group_type is None:
return None
if tokenizer_group_type == "ray":
return TokenizerPoolConfig(pool_size=1,
pool_type="ray",
extra_config={})
if isinstance(tokenizer_group_type, type):
return TokenizerPoolConfig(pool_size=1,
pool_type=tokenizer_group_type,
extra_config={})
raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")
@pytest.fixture()
def temporary_enable_log_propagate():
import logging

View File

@ -5,17 +5,14 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import get_lora_tokenizer
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
from ..conftest import get_tokenizer_pool_config
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"])
async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type):
reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files)
tokenizer_group = get_tokenizer_group(
get_tokenizer_pool_config(tokenizer_group_type),
tokenizer_group = TokenizerGroup(
tokenizer_id="gpt2",
enable_lora=True,
max_num_seqs=1,
@ -60,8 +57,7 @@ def test_get_lora_tokenizer(sql_lora_files, tmp_path):
@pytest.mark.parametrize("max_num_seqs", [1, 2])
@pytest.mark.parametrize("max_loras", [1, 2])
def test_lora_tokenizers(enable_lora, max_num_seqs, max_loras):
tokenizer_group = get_tokenizer_group(
get_tokenizer_pool_config(None),
tokenizer_group = TokenizerGroup(
tokenizer_id="gpt2",
enable_lora=enable_lora,
max_num_seqs=max_num_seqs,
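With get_tokenizer_group() and get_tokenizer_pool_config() removed, the LoRA tests construct the group directly. A minimal sketch of that pattern, assuming a locally cacheable "gpt2" tokenizer and using only the constructor keywords visible in this diff:

from vllm.transformers_utils.tokenizer_group import TokenizerGroup

# Direct construction replaces get_tokenizer_group(pool_config, ...).
tokenizer_group = TokenizerGroup(
    tokenizer_id="gpt2",
    enable_lora=True,
    max_num_seqs=1,
    max_input_length=None,
)
# Without a LoRA request, the group hands back its base HF tokenizer.
base_tokenizer = tokenizer_group.get_lora_tokenizer(None)
print(type(base_tokenizer).__name__)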

View File

@ -10,7 +10,7 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
from vllm.inputs import token_inputs
from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
@ -212,7 +212,7 @@ def test_oov_decode(tokenizer, fast):
@pytest.fixture
def detokenizer(tokenizer_name: str) -> Detokenizer:
init_kwargs = dict(
tokenizer_group = TokenizerGroup(
tokenizer_id=tokenizer_name,
enable_lora=False,
max_num_seqs=100,
@ -222,11 +222,6 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
revision=None,
)
tokenizer_group = get_tokenizer_group(
None,
**init_kwargs,
)
return Detokenizer(tokenizer_group)
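The fixture now passes its keyword arguments straight to TokenizerGroup instead of routing them through get_tokenizer_group with an init_kwargs dict. A hedged sketch of the simplified fixture body (make_detokenizer is a placeholder name; tokenizer_name is whatever the test parametrizes):

from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup

def make_detokenizer(tokenizer_name: str) -> Detokenizer:
    # Build the group directly; no pool config indirection is needed.
    tokenizer_group = TokenizerGroup(
        tokenizer_id=tokenizer_name,
        enable_lora=False,
        max_num_seqs=100,
        max_input_length=None,
    )
    return Detokenizer(tokenizer_group)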

View File

@ -1,40 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import os
import sys
from typing import Optional
from unittest.mock import patch
import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.transformers_utils.tokenizer_group import (TokenizerGroup,
get_tokenizer_group)
from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
RayTokenizerGroupPool)
from ..conftest import get_tokenizer_pool_config
class CustomTokenizerGroup(TokenizerGroup):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._i = 0
def encode(self, *args, **kwargs):
self._i += 1
return super().encode(*args, **kwargs)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type",
[None, "ray", CustomTokenizerGroup])
async def test_tokenizer_group(tokenizer_group_type):
async def test_tokenizer_group():
reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer_group = get_tokenizer_group(
get_tokenizer_pool_config(tokenizer_group_type),
tokenizer_group = TokenizerGroup(
tokenizer_id="gpt2",
enable_lora=False,
max_num_seqs=1,
@ -49,159 +24,3 @@ async def test_tokenizer_group(tokenizer_group_type):
PreTrainedTokenizerBase)
assert tokenizer_group.get_lora_tokenizer(
None) == await tokenizer_group.get_lora_tokenizer_async(None)
if tokenizer_group_type is CustomTokenizerGroup:
assert tokenizer_group._i > 0
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_pool(tokenizer_group_type):
reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer_group_pool = get_tokenizer_group(
get_tokenizer_pool_config(tokenizer_group_type),
tokenizer_id="gpt2",
enable_lora=False,
max_num_seqs=1,
max_input_length=None,
)
# Send multiple requests to the tokenizer group pool
# (more than the pool size)
# and check that all requests are processed correctly.
num_requests = tokenizer_group_pool.pool_size * 5
requests = [
tokenizer_group_pool.encode_async(prompt=f"prompt {i}",
lora_request=None)
for i in range(num_requests)
]
results = await asyncio.gather(*requests)
expected_results = [
reference_tokenizer.encode(f"prompt {i}") for i in range(num_requests)
]
assert results == expected_results
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_ray_pool_env_var_propagation(
tokenizer_group_type):
"""Test that env vars from caller process are propagated to
tokenizer Ray actors."""
env_var = "MY_ENV_VAR"
class EnvVarCheckerTokenizerGroup(TokenizerGroup):
def ping(self):
assert os.environ.get(env_var) == "1"
return super().ping()
class EnvVarCheckerRayTokenizerGroupPool(RayTokenizerGroupPool):
_worker_cls = EnvVarCheckerTokenizerGroup
tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config(
tokenizer_pool_config,
tokenizer_id="gpt2",
enable_lora=False,
max_num_seqs=1,
max_input_length=None)
with pytest.raises(AssertionError):
tokenizer_pool.ping()
with patch.dict(os.environ, {env_var: "1"}):
tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config(
tokenizer_pool_config,
tokenizer_id="gpt2",
enable_lora=False,
max_num_seqs=1,
max_input_length=None)
tokenizer_pool.ping()
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
"""Test that Ray tokenizer pool group can recover from failures and
if that's not possible, mark itself as unhealthy."""
class FailingTokenizerGroup(TokenizerGroup):
def __init__(self,
*args,
fail_at: Optional[list[int]] = None,
**kwargs):
super().__init__(*args, **kwargs)
self.i = 0
self.fail_at = fail_at or []
def encode(self, *args, **kwargs):
self.i += 1
if self.i in self.fail_at:
sys.exit(1)
return super().encode(*args, **kwargs)
class FailingRayTokenizerGroupPool(RayTokenizerGroupPool):
_worker_cls = FailingTokenizerGroup
# Fail at first iteration
fail_at = [1]
tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
tokenizer_pool_config,
tokenizer_id="gpt2",
enable_lora=False,
max_num_seqs=1,
max_input_length=None,
fail_at=fail_at)
tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()
# Modify fail at to not fail at all (will be re-read when actor is
# re-initialized).
fail_at[0] = 1000
# We should recover successfully.
await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
# Check that we have a new actor
assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors)
assert tokenizer_group_pool.tokenizer_actors != tokenizer_actors
# Fail at first iteration
fail_at = [1]
tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
tokenizer_pool_config,
tokenizer_id="gpt2",
enable_lora=False,
max_num_seqs=1,
max_input_length=None,
fail_at=fail_at)
# We should fail after re-initialization.
with pytest.raises(RuntimeError):
await tokenizer_group_pool.encode_async(prompt="prompt",
lora_request=None)
# check_health should raise the same thing
with pytest.raises(RuntimeError):
tokenizer_group_pool.check_health()
# Ensure that non-ActorDiedErrors are still propagated correctly and do not
# cause a re-initialization.
fail_at = []
tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
tokenizer_pool_config,
tokenizer_id="gpt2",
enable_lora=False,
max_num_seqs=1,
max_input_length=2,
fail_at=fail_at)
tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()
# Prompt too long error
with pytest.raises(ValueError):
await tokenizer_group_pool.encode_async(prompt="prompt" * 100,
lora_request=None)
await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
# Actors should stay the same.
assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors
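With the Ray pool tests deleted, only the in-process TokenizerGroup remains, so the surviving test just checks that the sync and async encode paths agree with the reference Hugging Face tokenizer. A compact sketch of that check (check_round_trip is a placeholder name; assumes "gpt2" can be downloaded):

import asyncio
from transformers import AutoTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup

async def check_round_trip() -> None:
    reference = AutoTokenizer.from_pretrained("gpt2")
    group = TokenizerGroup(tokenizer_id="gpt2", enable_lora=False,
                           max_num_seqs=1, max_input_length=None)
    prompt = "prompt 0"
    # Both paths should match the plain HF encoding.
    assert group.encode(prompt=prompt) == reference.encode(prompt)
    assert await group.encode_async(prompt=prompt) == reference.encode(prompt)

asyncio.run(check_round_trip())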

View File

@ -47,7 +47,7 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors:
tokenizer=tokenizer,
tokenizer_group=init_tokenizer_from_configs(
vllm_config.model_config, vllm_config.scheduler_config,
vllm_config.parallel_config, vllm_config.lora_config),
vllm_config.lora_config),
vllm_config=vllm_config,
full_tokens=[tokenizer(text).input_ids for text in FULL_STRINGS],
prompt_tokens=prompt_tokens,

View File

@ -8,8 +8,7 @@ import torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.engine.arg_utils import EngineArgs
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreOutput, FinishReason
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
@ -296,7 +295,7 @@ def generate_dummy_prompt_logprobs_tensors(
class DummyOutputProcessorTestVectors:
"""Dummy test vectors for output processor tests"""
tokenizer: GeneralTokenizerType
tokenizer_group: BaseTokenizerGroup
tokenizer_group: TokenizerGroup
vllm_config: EngineArgs
full_tokens: list[list[int]] # Prompt + generated tokens
prompt_tokens: list[list[int]]

View File

@ -52,8 +52,6 @@ if TYPE_CHECKING:
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.loader import BaseModelLoader
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
ConfigType = type[DataclassInstance]
else:
@ -1407,83 +1405,33 @@ class CacheConfig:
logger.warning("Possibly too large swap space. %s", msg)
PoolType = Literal["ray"]
@config
@dataclass
class TokenizerPoolConfig:
"""Configuration for the tokenizer pool."""
"""This config is deprecated and will be removed in a future release.
Passing these parameters will have no effect. Please remove them from your
configurations.
"""
pool_size: int = 0
"""Number of tokenizer workers in the pool to use for asynchronous
tokenization. If 0, will use synchronous tokenization."""
pool_type: Union[PoolType, type["BaseTokenizerGroup"]] = "ray"
"""Type of tokenizer pool to use for asynchronous tokenization. Ignored if
tokenizer_pool_size is 0."""
"""This parameter is deprecated and will be removed in a future release.
Passing this parameter will have no effect. Please remove it from your
configurations."""
pool_type: str = "ray"
"""This parameter is deprecated and will be removed in a future release.
Passing this parameter will have no effect. Please remove it from your
configurations."""
extra_config: dict = field(default_factory=dict)
"""Additional config for the pool. The way the config will be used depends
on the pool type. This should be a JSON string that will be parsed into a
dictionary. Ignored if tokenizer_pool_size is 0."""
"""This parameter is deprecated and will be removed in a future release.
Passing this parameter will have no effect. Please remove it from your
configurations."""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()
return hash_str
def __post_init__(self):
if self.pool_type not in ("ray", ) and not isinstance(
self.pool_type, type):
raise ValueError(f"Unknown pool type: {self.pool_type}")
if not isinstance(self.extra_config, dict):
raise ValueError("extra_config must be a dictionary.")
@classmethod
def create_config(
cls, tokenizer_pool_size: int,
tokenizer_pool_type: Union[PoolType, type["BaseTokenizerGroup"]],
tokenizer_pool_extra_config: Optional[Union[str, dict]]
) -> Optional["TokenizerPoolConfig"]:
"""Create a TokenizerPoolConfig from the given parameters.
If tokenizer_pool_size is 0, return None.
Args:
tokenizer_pool_size: Number of tokenizer workers in the pool.
tokenizer_pool_type: Type of the pool.
tokenizer_pool_extra_config: Additional config for the pool.
The way the config will be used depends on the
pool type. This can be a JSON string (will be parsed).
"""
if tokenizer_pool_size:
if isinstance(tokenizer_pool_extra_config, str):
tokenizer_pool_extra_config_parsed = json.loads(
tokenizer_pool_extra_config)
else:
tokenizer_pool_extra_config_parsed = (
tokenizer_pool_extra_config or {})
tokenizer_pool_config = cls(tokenizer_pool_size,
tokenizer_pool_type,
tokenizer_pool_extra_config_parsed)
else:
tokenizer_pool_config = None
return tokenizer_pool_config
def __post_init__(self) -> None:
logger.warning_once(
"TokenizerPoolConfig is deprecated and will be removed in a "
"future release. Passing this parameter will have no effect. "
"Please remove it from your configurations.")
class LoadFormat(str, enum.Enum):
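TokenizerPoolConfig is retained only so that existing configurations do not break: every field is now a no-op and __post_init__ emits a one-time deprecation warning. A sketch of what constructing it looks like after this change (values are illustrative):

from vllm.config import TokenizerPoolConfig

# Logs a deprecation warning; the values no longer affect tokenization,
# which always runs in-process now.
legacy = TokenizerPoolConfig(pool_size=4, pool_type="ray", extra_config={})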
@ -1624,8 +1572,8 @@ class ParallelConfig:
"""Disable the custom all-reduce kernel and fall back to NCCL."""
tokenizer_pool_config: Optional[TokenizerPoolConfig] = None
"""Config for the tokenizer pool. If None, will use synchronous
tokenization."""
"""This parameter is deprecated and will be removed in a future release.
Please remove it from your configs"""
ray_workers_use_nsight: bool = False
"""Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
@ -2544,7 +2492,6 @@ class SpeculativeConfig:
max_parallel_loading_workers,
disable_custom_all_reduce=target_parallel_config.
disable_custom_all_reduce,
tokenizer_pool_config=target_parallel_config.tokenizer_pool_config,
ray_workers_use_nsight=target_parallel_config.
ray_workers_use_nsight,
placement_group=target_parallel_config.placement_group,

View File

@ -7,9 +7,8 @@ import json
import re
import threading
from dataclasses import MISSING, dataclass, fields
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Literal,
Optional, Tuple, Type, TypeVar, Union, cast, get_args,
get_origin)
from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Type,
TypeVar, Union, cast, get_args, get_origin)
import torch
from typing_extensions import TypeIs
@ -23,7 +22,7 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
ModelConfig, ModelImpl, MultiModalConfig,
ObservabilityConfig, ParallelConfig, PoolerConfig,
PoolType, PrefixCachingHashAlgo, PromptAdapterConfig,
PrefixCachingHashAlgo, PromptAdapterConfig,
SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
TaskOption, TokenizerPoolConfig, VllmConfig,
get_attr_docs, get_field)
@ -39,9 +38,6 @@ from vllm.utils import FlexibleArgumentParser, GiB_bytes, is_in_ray_actor
# yapf: enable
if TYPE_CHECKING:
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
logger = init_logger(__name__)
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]
@ -185,13 +181,12 @@ class EngineArgs:
enforce_eager: Optional[bool] = None
max_seq_len_to_capture: int = 8192
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
# The following three fields are deprecated and will be removed in a future
# release. Setting them will have no effect. Please remove them from your
# configurations.
tokenizer_pool_size: int = TokenizerPoolConfig.pool_size
# Note: Specifying a tokenizer pool by passing a class
# is intended for expert use only. The API may change without
# notice.
tokenizer_pool_type: Union[PoolType, Type["BaseTokenizerGroup"]] = \
TokenizerPoolConfig.pool_type
tokenizer_pool_extra_config: dict[str, Any] = \
tokenizer_pool_type: str = TokenizerPoolConfig.pool_type
tokenizer_pool_extra_config: dict = \
get_field(TokenizerPoolConfig, "extra_config")
limit_mm_per_prompt: dict[str, int] = \
get_field(MultiModalConfig, "limit_per_prompt")
@ -1187,11 +1182,6 @@ class EngineArgs:
enable_expert_parallel=self.enable_expert_parallel,
max_parallel_loading_workers=self.max_parallel_loading_workers,
disable_custom_all_reduce=self.disable_custom_all_reduce,
tokenizer_pool_config=TokenizerPoolConfig.create_config(
self.tokenizer_pool_size,
self.tokenizer_pool_type,
self.tokenizer_pool_extra_config,
),
ray_workers_use_nsight=self.ray_workers_use_nsight,
placement_group=placement_group,
distributed_executor_backend=self.distributed_executor_backend,

View File

@ -526,8 +526,6 @@ class _AsyncLLMEngine(LLMEngine):
)
async def check_health_async(self) -> None:
if self.tokenizer:
self.tokenizer.check_health()
self.model_executor.check_health()

View File

@ -55,7 +55,7 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup, init_tokenizer_from_configs)
TokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import (Counter, Device, deprecate_kwargs,
@ -66,7 +66,6 @@ from vllm.worker.model_runner_base import InputProcessingError
logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
_R = TypeVar("_R", default=Any)
@ -205,7 +204,7 @@ class LLMEngine:
return outputs_
tokenizer: Optional[BaseTokenizerGroup]
tokenizer: Optional[TokenizerGroup]
def __init__(
self,
@ -321,11 +320,6 @@ class LLMEngine:
self.parallel_config.disable_custom_all_reduce,
})
if self.tokenizer:
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self.tokenizer.ping()
self.cached_scheduler_outputs = [
SchedulerOutputState()
for _ in range(self.parallel_config.pipeline_parallel_size)
@ -537,21 +531,12 @@ class LLMEngine:
if model_executor := getattr(self, "model_executor", None):
model_executor.shutdown()
def get_tokenizer_group(
self,
group_type: Type[_G] = BaseTokenizerGroup,
) -> _G:
tokenizer_group = self.tokenizer
if tokenizer_group is None:
def get_tokenizer_group(self) -> TokenizerGroup:
if self.tokenizer is None:
raise ValueError("Unable to get tokenizer because "
"skip_tokenizer_init is True")
if not isinstance(tokenizer_group, group_type):
raise TypeError("Invalid type of tokenizer group. "
f"Expected type: {group_type}, but "
f"found type: {type(tokenizer_group)}")
return tokenizer_group
return self.tokenizer
def get_tokenizer(
self,
@ -559,11 +544,10 @@ class LLMEngine:
) -> AnyTokenizer:
return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
def _init_tokenizer(self) -> BaseTokenizerGroup:
def _init_tokenizer(self) -> TokenizerGroup:
return init_tokenizer_from_configs(
model_config=self.model_config,
scheduler_config=self.scheduler_config,
parallel_config=self.parallel_config,
lora_config=self.lora_config)
def _verify_args(self) -> None:
@ -1952,8 +1936,6 @@ class LLMEngine:
return self.model_executor.is_sleeping
def check_health(self) -> None:
if self.tokenizer:
self.tokenizer.check_health()
self.model_executor.check_health()
def is_tracing_enabled(self) -> bool:
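In LLMEngine the tokenizer plumbing loses its generics: get_tokenizer_group() no longer takes a group_type, _init_tokenizer() drops parallel_config, and the health checks no longer ping the tokenizer. A hedged sketch of the resulting access pattern (describe_tokenizer is a placeholder helper; engine stands for an already constructed LLMEngine):

from vllm.transformers_utils.tokenizer_group import TokenizerGroup

def describe_tokenizer(engine) -> str:
    # Raises ValueError if the engine was built with skip_tokenizer_init=True.
    group: TokenizerGroup = engine.get_tokenizer_group()
    # get_tokenizer() is shorthand for get_lora_tokenizer() on the group.
    tokenizer = engine.get_tokenizer(lora_request=None)
    return f"{type(group).__name__} wrapping {type(tokenizer).__name__}"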

View File

@ -101,7 +101,6 @@ class MQLLMEngineClient(EngineClient):
self.tokenizer = init_tokenizer_from_configs(
model_config=self.model_config,
scheduler_config=engine_config.scheduler_config,
parallel_config=engine_config.parallel_config,
lora_config=engine_config.lora_config)
self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer)

View File

@ -40,7 +40,6 @@ from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
RequestOutputKind, SamplingParams)
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
is_list_of)
@ -253,10 +252,10 @@ class LLM:
self.default_sampling_params: Union[dict[str, Any], None] = None
def get_tokenizer(self) -> AnyTokenizer:
return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer
return self.llm_engine.get_tokenizer_group().tokenizer
def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
tokenizer_group = self.llm_engine.get_tokenizer_group()
# While CachedTokenizer is dynamic, have no choice but
# compare class name. Misjudgment will arise from

View File

@ -13,7 +13,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs)
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs,
PromptType, SingletonInputs, SingletonPrompt, token_inputs)
@ -27,7 +27,7 @@ class InputPreprocessor:
def __init__(
self,
model_config: ModelConfig,
tokenizer: Optional[BaseTokenizerGroup],
tokenizer: Optional[TokenizerGroup],
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
) -> None:
super().__init__()
@ -36,7 +36,7 @@ class InputPreprocessor:
self.tokenizer = tokenizer
self.mm_registry = mm_registry
def get_tokenizer_group(self) -> BaseTokenizerGroup:
def get_tokenizer_group(self) -> TokenizerGroup:
if self.tokenizer is None:
raise ValueError("You cannot pass text prompts when "
"`skip_tokenizer_init` is True")

View File

@ -8,13 +8,13 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams,
from .detokenizer_utils import (convert_prompt_ids_to_tokens,
detokenize_incrementally)
from .tokenizer import AnyTokenizer
from .tokenizer_group import BaseTokenizerGroup
from .tokenizer_group import TokenizerGroup
class Detokenizer:
"""Provides methods to decode the output of a model into text."""
def __init__(self, tokenizer_group: BaseTokenizerGroup):
def __init__(self, tokenizer_group: TokenizerGroup):
self.tokenizer_group = tokenizer_group
def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:

View File

@ -2,7 +2,7 @@
from typing import List, Optional
from vllm.config import TokenizerPoolConfig
from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import (AnyTokenizer, encode_tokens,
get_lora_tokenizer,
@ -10,10 +10,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, encode_tokens,
get_tokenizer)
from vllm.utils import LRUCache
from .base_tokenizer_group import BaseTokenizerGroup
class TokenizerGroup(BaseTokenizerGroup):
class TokenizerGroup:
"""A group of tokenizers that can be used for LoRA adapters."""
def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
@ -27,15 +25,6 @@ class TokenizerGroup(BaseTokenizerGroup):
self.lora_tokenizers = LRUCache[int, AnyTokenizer](
capacity=max(max_loras, max_num_seqs) if enable_lora else 0)
@classmethod
def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
**init_kwargs) -> "TokenizerGroup":
return cls(**init_kwargs)
def ping(self) -> bool:
"""Check if the tokenizer group is alive."""
return True
def get_max_input_len(self,
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
@ -104,3 +93,18 @@ class TokenizerGroup(BaseTokenizerGroup):
return tokenizer
else:
return self.lora_tokenizers[lora_request.lora_int_id]
def init_tokenizer_from_configs(model_config: ModelConfig,
scheduler_config: SchedulerConfig,
lora_config: Optional[LoRAConfig]):
return TokenizerGroup(
tokenizer_id=model_config.tokenizer,
enable_lora=bool(lora_config),
max_num_seqs=scheduler_config.max_num_seqs,
max_loras=lora_config.max_loras if lora_config else 0,
max_input_length=None,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.tokenizer_revision,
truncation_side=model_config.truncation_side)
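init_tokenizer_from_configs now lives next to TokenizerGroup and takes only the three configs it actually reads; the parallel_config argument and the follow-up ping() call are gone. A hedged usage sketch (build_tokenizer_group is a placeholder helper; vllm_config stands for any assembled VllmConfig):

from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs

def build_tokenizer_group(vllm_config):
    # parallel_config is no longer accepted and ping() no longer exists.
    return init_tokenizer_from_configs(
        model_config=vllm_config.model_config,
        scheduler_config=vllm_config.scheduler_config,
        lora_config=vllm_config.lora_config,
    )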

View File

@ -1,56 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Type
from vllm.config import (LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, TokenizerPoolConfig)
from vllm.executor.ray_utils import ray
from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
from .tokenizer_group import TokenizerGroup
if ray:
from .ray_tokenizer_group import RayTokenizerGroupPool
else:
RayTokenizerGroupPool = None # type: ignore
def init_tokenizer_from_configs(model_config: ModelConfig,
scheduler_config: SchedulerConfig,
parallel_config: ParallelConfig,
lora_config: Optional[LoRAConfig]):
init_kwargs = dict(tokenizer_id=model_config.tokenizer,
enable_lora=bool(lora_config),
max_num_seqs=scheduler_config.max_num_seqs,
max_loras=lora_config.max_loras if lora_config else 0,
max_input_length=None,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.tokenizer_revision,
truncation_side=model_config.truncation_side)
return get_tokenizer_group(parallel_config.tokenizer_pool_config,
**init_kwargs)
def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
**init_kwargs) -> BaseTokenizerGroup:
tokenizer_cls: Type[BaseTokenizerGroup]
if tokenizer_pool_config is None:
tokenizer_cls = TokenizerGroup
elif isinstance(tokenizer_pool_config.pool_type, type) and issubclass(
tokenizer_pool_config.pool_type, BaseTokenizerGroup):
tokenizer_cls = tokenizer_pool_config.pool_type
elif tokenizer_pool_config.pool_type == "ray":
if RayTokenizerGroupPool is None:
raise ImportError(
"RayTokenizerGroupPool is not available. Please install "
"the ray package to use the Ray tokenizer group pool.")
tokenizer_cls = RayTokenizerGroupPool
else:
raise ValueError(
f"Unknown pool type: {tokenizer_pool_config.pool_type}")
return tokenizer_cls.from_config(tokenizer_pool_config, **init_kwargs)
__all__ = ["AnyTokenizer", "get_tokenizer_group", "BaseTokenizerGroup"]

View File

@ -1,68 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import List, Optional
from vllm.config import TokenizerPoolConfig
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
class BaseTokenizerGroup(ABC):
"""A group of tokenizers that can be used for LoRA adapters."""
@classmethod
@abstractmethod
def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
**init_kwargs) -> "BaseTokenizerGroup":
pass
@abstractmethod
def ping(self) -> bool:
"""Check if the tokenizer group is alive."""
pass
@abstractmethod
def get_max_input_len(
self,
lora_request: Optional[LoRARequest] = None,
) -> Optional[int]:
"""Get the maximum input length for the LoRA request."""
pass
@abstractmethod
def encode(self,
prompt: str,
lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]:
"""Encode a prompt using the tokenizer group."""
pass
@abstractmethod
async def encode_async(
self,
prompt: str,
lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]:
"""Encode a prompt using the tokenizer group."""
pass
@abstractmethod
def get_lora_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
"""Get a tokenizer for a LoRA request."""
pass
@abstractmethod
async def get_lora_tokenizer_async(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
"""Get a tokenizer for a LoRA request."""
pass
def check_health(self):
"""Raise exception if the tokenizer group is unhealthy."""
return

View File

@ -1,244 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import os
from typing import List, Optional
try:
from ray.exceptions import ActorDiedError # type: ignore
except ImportError:
# For older versions of Ray
from ray.exceptions import RayActorError as ActorDiedError # type: ignore
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
from vllm.config import TokenizerPoolConfig
from vllm.executor.ray_utils import ray
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .base_tokenizer_group import BaseTokenizerGroup
from .tokenizer_group import TokenizerGroup
logger = init_logger(__name__)
class RayTokenizerGroupPool(BaseTokenizerGroup):
"""A Ray-based pool of TokenizerGroups for async tokenization."""
# Class to use for workers making up the pool.
_worker_cls = TokenizerGroup
@classmethod
def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
**init_kwargs) -> "RayTokenizerGroupPool":
if not tokenizer_pool_config:
raise ValueError("tokenizer_pool_config must not be None.")
ray_actor_options = (tokenizer_pool_config.extra_config or {
"num_cpus": 0
})
ray_actor_options.setdefault(
"scheduling_strategy",
NodeAffinitySchedulingStrategy(
node_id=ray.get_runtime_context().get_node_id(), soft=True))
# Carry over the env vars to the actors.
# This is necessary for API keys and such.
ray_actor_options.setdefault("runtime_env", {})
_carry_over_env_vars_to_runtime_env(ray_actor_options["runtime_env"])
init_kwargs["num_actors"] = tokenizer_pool_config.pool_size
init_kwargs["ray_actor_options"] = ray_actor_options
return cls(**init_kwargs)
def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
max_input_length: Optional[int], num_actors: int,
ray_actor_options: dict, **tokenizer_config):
# Store a local copy of the TokenizerGroup for quick access
# to underlying HF tokenizers.
self._tokenizer_config = {
"tokenizer_id": tokenizer_id,
"enable_lora": enable_lora,
"max_num_seqs": max_num_seqs,
"max_input_length": max_input_length,
**tokenizer_config
}
self._local_tokenizer_group = self._worker_cls(
**self._tokenizer_config, )
self._ray_tokenizer_group_cls = ray.remote(
self._worker_cls).options(**ray_actor_options) # type: ignore
self.tokenizer_actors = [self._init_actor() for _ in range(num_actors)]
self._idle_actors: Optional[asyncio.Queue] = None
# If set, actor is unhealthy. Will reraise on the next
# check_health call.
self._exception: Optional[ActorDiedError] = None
def _init_actor(self) -> ray.ObjectRef:
return self._ray_tokenizer_group_cls.remote(**self._tokenizer_config)
@property
def pool_size(self) -> int:
return len(self.tokenizer_actors)
def ping(self):
return ray.get([
actor.ping.remote() # type: ignore
for actor in self.tokenizer_actors
])
def _ensure_queue_initialized(self):
if self._idle_actors is None:
self._idle_actors = asyncio.Queue()
for actor in self.tokenizer_actors:
self._idle_actors.put_nowait(actor)
def _finalize_encode(self, actor: ray.ObjectRef,
original_actor: ray.ObjectRef, actor_is_alive: bool):
assert self._idle_actors is not None
# Cleanup the dead actor.
if not actor_is_alive or original_actor is not actor:
self.tokenizer_actors.remove(original_actor)
if actor_is_alive:
# Put the actor back in the queue.
# This is done in a finally block to ensure that the actor is
# always put back in the queue, even if an exception/cancellation
# is raised.
self._idle_actors.put_nowait(actor)
# Add back the new actor.
if original_actor is not actor:
self.tokenizer_actors.append(actor)
def encode(self,
prompt: str,
lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]:
"""Encode a prompt using the tokenizer group.
We pick an idle actor and use it to encode the prompt.
The actor is then put back in the queue for future use.
This is blocking.
"""
self.check_health()
self._ensure_queue_initialized()
assert self._idle_actors is not None
if self._idle_actors.empty():
raise RuntimeError("No idle actors available.")
actor = self._idle_actors.get_nowait()
actor_is_alive = True
original_actor = actor
try:
ret = ray.get(
actor.encode.remote(prompt=prompt,
lora_request=lora_request,
add_special_tokens=add_special_tokens))
except ActorDiedError as e:
# If the actor is dead, we first try to reinitialize it.
logger.warning("%s died with ActorDiedError, reinitializing.",
actor,
exc_info=e)
actor = self._init_actor()
try:
ret = ray.get(
actor.encode.remote(prompt=prompt,
lora_request=lora_request,
add_special_tokens=add_special_tokens))
except ActorDiedError as e:
logger.error(
"%s died for second time in a row, marking "
"RayTokenizerGroupPool as unhealthy.", actor)
actor_is_alive = False
if not self._exception:
self._exception = e
self.check_health()
finally:
self._finalize_encode(actor, original_actor, actor_is_alive)
return ret
async def encode_async(
self,
prompt: str,
lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]:
"""Encode a prompt using the tokenizer group.
We pick an idle actor and use it to encode the prompt.
If there are no idle actors, we wait until one becomes
available.
The actor is then put back in the queue for future use.
This is non-blocking.
"""
self.check_health()
self._ensure_queue_initialized()
assert self._idle_actors is not None
actor = await self._idle_actors.get()
actor_is_alive = True
original_actor = actor
try:
ret = await actor.encode.remote(
prompt=prompt,
lora_request=lora_request,
add_special_tokens=add_special_tokens)
except ActorDiedError as e:
# If the actor is dead, we first try to reinitialize it.
logger.warning("%s died with ActorDiedError, reinitializing.",
actor,
exc_info=e)
actor = self._init_actor()
try:
ret = await actor.encode.remote(
prompt=prompt,
lora_request=lora_request,
add_special_tokens=add_special_tokens)
except ActorDiedError as e:
logger.error(
"%s died for second time in a row, marking "
"RayTokenizerGroupPool as unhealthy.", actor)
actor_is_alive = False
if not self._exception:
self._exception = e
self.check_health()
finally:
self._finalize_encode(actor, original_actor, actor_is_alive)
return ret
def get_max_input_len(self,
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
"""Get the maximum input length for the LoRA request."""
return self._local_tokenizer_group.get_max_input_len(lora_request)
def get_lora_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return self._local_tokenizer_group.get_lora_tokenizer(lora_request)
async def get_lora_tokenizer_async(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return await self._local_tokenizer_group.get_lora_tokenizer_async(
lora_request)
def check_health(self):
if self._exception:
raise RuntimeError(
"TokenizerGroupPool is unhealthy.") from self._exception
def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None:
"""Copy over all current process environment variables to the runtime_env.
The variables in runtime_env will take precedence over the current process
environment variables.
runtime_env will be modified in place."""
env_vars = os.environ.copy()
runtime_env.setdefault("env_vars", {})
env_vars.update(runtime_env["env_vars"])
runtime_env["env_vars"] = env_vars

View File

@ -81,9 +81,7 @@ class AsyncLLM(EngineClient):
self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
parallel_config=vllm_config.parallel_config,
lora_config=vllm_config.lora_config)
self.tokenizer.ping()
# Processor (converts Inputs --> EngineCoreRequests).
self.processor = Processor(

View File

@ -20,7 +20,7 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup, init_tokenizer_from_configs)
TokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device
from vllm.v1.engine.core_client import EngineCoreClient
@ -32,7 +32,6 @@ from vllm.v1.utils import report_usage_stats
logger = init_logger(__name__)
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_R = TypeVar("_R", default=Any)
@ -74,9 +73,7 @@ class LLMEngine:
self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
parallel_config=vllm_config.parallel_config,
lora_config=vllm_config.lora_config)
self.tokenizer.ping()
# Processor (convert Inputs --> EngineCoreRequests)
self.processor = Processor(vllm_config=vllm_config,
@ -258,21 +255,12 @@ class LLMEngine:
def is_sleeping(self) -> bool:
return self.engine_core.is_sleeping()
def get_tokenizer_group(
self,
group_type: type[_G] = BaseTokenizerGroup,
) -> _G:
tokenizer_group = self.tokenizer
if tokenizer_group is None:
def get_tokenizer_group(self) -> TokenizerGroup:
if self.tokenizer is None:
raise ValueError("Unable to get tokenizer because "
"skip_tokenizer_init is True")
if not isinstance(tokenizer_group, group_type):
raise TypeError("Invalid type of tokenizer group. "
f"Expected type: {group_type}, but "
f"found type: {type(tokenizer_group)}")
return tokenizer_group
return self.tokenizer
def add_lora(self, lora_request: LoRARequest) -> bool:
"""Load a new LoRA adapter into the engine for future requests."""

View File

@ -8,7 +8,7 @@ from typing import Optional, Union
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import RequestOutputKind
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
@ -225,7 +225,7 @@ class OutputProcessor:
def __init__(
self,
tokenizer: BaseTokenizerGroup,
tokenizer: TokenizerGroup,
log_stats: bool,
):
self.log_stats = log_stats

View File

@ -17,7 +17,7 @@ from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.structured_output.backend_guidance import (
@ -31,7 +31,7 @@ class Processor:
def __init__(
self,
vllm_config: VllmConfig,
tokenizer: BaseTokenizerGroup,
tokenizer: TokenizerGroup,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):

View File

@ -61,9 +61,7 @@ class GuidanceBackend(StructuredOutputBackend):
tokenizer_group = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
parallel_config=vllm_config.parallel_config,
lora_config=vllm_config.lora_config) # type: ignore[arg-type]
tokenizer_group.ping()
self.vllm_config = vllm_config
self.vocab_size = vllm_config.model_config.get_vocab_size()

View File

@ -35,9 +35,7 @@ class XgrammarBackend(StructuredOutputBackend):
tokenizer_group = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
parallel_config=vllm_config.parallel_config,
lora_config=vllm_config.lora_config) # type: ignore[arg-type]
tokenizer_group.ping()
self.disable_any_whitespace = False
backend_options = GuidedDecodingParams(