Simplify TokenizerGroup (#16790)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 14288d1332
commit 0a05ed57e6
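This commit collapses the BaseTokenizerGroup / TokenizerGroup / RayTokenizerGroupPool hierarchy into the single concrete TokenizerGroup class and turns TokenizerPoolConfig into a deprecated no-op. A minimal sketch of the API change, assuming vLLM at this commit ("gpt2" is only an illustrative tokenizer id):

    from vllm.transformers_utils.tokenizer_group import TokenizerGroup

    # Before: get_tokenizer_group(tokenizer_pool_config, **init_kwargs)
    # dispatched to TokenizerGroup or RayTokenizerGroupPool.
    # After: the class is constructed directly; the pool indirection is gone.
    tokenizer_group = TokenizerGroup(
        tokenizer_id="gpt2",      # any HF tokenizer name or local path
        enable_lora=False,        # disables the per-LoRA tokenizer LRU cache
        max_num_seqs=1,
        max_input_length=None,
    )
    token_ids = tokenizer_group.encode(prompt="Hello, world!")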
@@ -23,7 +23,7 @@ from tests.models.utils import (TokensTextLogprobs,
 from vllm import LLM, SamplingParams
 from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
-from vllm.config import TaskOption, TokenizerPoolConfig, _get_and_verify_dtype
+from vllm.config import TaskOption, _get_and_verify_dtype
 from vllm.connections import global_http_connection
 from vllm.distributed import (cleanup_dist_env_and_memory,
                               init_distributed_environment,
@@ -1010,20 +1010,6 @@ def vllm_runner():
     return VllmRunner
 
 
-def get_tokenizer_pool_config(tokenizer_group_type):
-    if tokenizer_group_type is None:
-        return None
-    if tokenizer_group_type == "ray":
-        return TokenizerPoolConfig(pool_size=1,
-                                   pool_type="ray",
-                                   extra_config={})
-    if isinstance(tokenizer_group_type, type):
-        return TokenizerPoolConfig(pool_size=1,
-                                   pool_type=tokenizer_group_type,
-                                   extra_config={})
-    raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")
-
-
 @pytest.fixture()
 def temporary_enable_log_propagate():
     import logging
@@ -5,17 +5,14 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
 from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.tokenizer import get_lora_tokenizer
-from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
-
-from ..conftest import get_tokenizer_pool_config
+from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("tokenizer_group_type", [None, "ray"])
 async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type):
     reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files)
-    tokenizer_group = get_tokenizer_group(
-        get_tokenizer_pool_config(tokenizer_group_type),
+    tokenizer_group = TokenizerGroup(
         tokenizer_id="gpt2",
         enable_lora=True,
         max_num_seqs=1,
@@ -60,8 +57,7 @@ def test_get_lora_tokenizer(sql_lora_files, tmp_path):
 @pytest.mark.parametrize("max_num_seqs", [1, 2])
 @pytest.mark.parametrize("max_loras", [1, 2])
 def test_lora_tokenizers(enable_lora, max_num_seqs, max_loras):
-    tokenizer_group = get_tokenizer_group(
-        get_tokenizer_pool_config(None),
+    tokenizer_group = TokenizerGroup(
         tokenizer_id="gpt2",
         enable_lora=enable_lora,
         max_num_seqs=max_num_seqs,
@@ -10,7 +10,7 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
 from vllm.inputs import token_inputs
 from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
 from vllm.transformers_utils.detokenizer import Detokenizer
-from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
+from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
@@ -212,7 +212,7 @@ def test_oov_decode(tokenizer, fast):
 
 @pytest.fixture
 def detokenizer(tokenizer_name: str) -> Detokenizer:
-    init_kwargs = dict(
+    tokenizer_group = TokenizerGroup(
         tokenizer_id=tokenizer_name,
         enable_lora=False,
         max_num_seqs=100,
@@ -222,11 +222,6 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
         revision=None,
     )
 
-    tokenizer_group = get_tokenizer_group(
-        None,
-        **init_kwargs,
-    )
-
     return Detokenizer(tokenizer_group)
 
 
@@ -1,40 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import asyncio
-import os
-import sys
-from typing import Optional
-from unittest.mock import patch
-
 import pytest
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
-from vllm.transformers_utils.tokenizer_group import (TokenizerGroup,
-                                                     get_tokenizer_group)
-from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
-    RayTokenizerGroupPool)
-
-from ..conftest import get_tokenizer_pool_config
-
-
-class CustomTokenizerGroup(TokenizerGroup):
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._i = 0
-
-    def encode(self, *args, **kwargs):
-        self._i += 1
-        return super().encode(*args, **kwargs)
+from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("tokenizer_group_type",
-                         [None, "ray", CustomTokenizerGroup])
-async def test_tokenizer_group(tokenizer_group_type):
+async def test_tokenizer_group():
     reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
-    tokenizer_group = get_tokenizer_group(
-        get_tokenizer_pool_config(tokenizer_group_type),
+    tokenizer_group = TokenizerGroup(
         tokenizer_id="gpt2",
         enable_lora=False,
         max_num_seqs=1,
@@ -49,159 +24,3 @@ async def test_tokenizer_group(tokenizer_group_type):
                           PreTrainedTokenizerBase)
     assert tokenizer_group.get_lora_tokenizer(
         None) == await tokenizer_group.get_lora_tokenizer_async(None)
-    if tokenizer_group_type is CustomTokenizerGroup:
-        assert tokenizer_group._i > 0
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
-async def test_tokenizer_group_pool(tokenizer_group_type):
-    reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
-    tokenizer_group_pool = get_tokenizer_group(
-        get_tokenizer_pool_config(tokenizer_group_type),
-        tokenizer_id="gpt2",
-        enable_lora=False,
-        max_num_seqs=1,
-        max_input_length=None,
-    )
-    # Send multiple requests to the tokenizer group pool
-    # (more than the pool size)
-    # and check that all requests are processed correctly.
-    num_requests = tokenizer_group_pool.pool_size * 5
-    requests = [
-        tokenizer_group_pool.encode_async(prompt=f"prompt {i}",
-                                          lora_request=None)
-        for i in range(num_requests)
-    ]
-    results = await asyncio.gather(*requests)
-    expected_results = [
-        reference_tokenizer.encode(f"prompt {i}") for i in range(num_requests)
-    ]
-    assert results == expected_results
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
-async def test_tokenizer_group_ray_pool_env_var_propagation(
-        tokenizer_group_type):
-    """Test that env vars from caller process are propagated to
-    tokenizer Ray actors."""
-    env_var = "MY_ENV_VAR"
-
-    class EnvVarCheckerTokenizerGroup(TokenizerGroup):
-
-        def ping(self):
-            assert os.environ.get(env_var) == "1"
-            return super().ping()
-
-    class EnvVarCheckerRayTokenizerGroupPool(RayTokenizerGroupPool):
-        _worker_cls = EnvVarCheckerTokenizerGroup
-
-    tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
-    tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config(
-        tokenizer_pool_config,
-        tokenizer_id="gpt2",
-        enable_lora=False,
-        max_num_seqs=1,
-        max_input_length=None)
-    with pytest.raises(AssertionError):
-        tokenizer_pool.ping()
-
-    with patch.dict(os.environ, {env_var: "1"}):
-        tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
-        tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config(
-            tokenizer_pool_config,
-            tokenizer_id="gpt2",
-            enable_lora=False,
-            max_num_seqs=1,
-            max_input_length=None)
-        tokenizer_pool.ping()
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
-async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
-    """Test that Ray tokenizer pool group can recover from failures and
-    if that's not possible, mark itself as unhealthy."""
-
-    class FailingTokenizerGroup(TokenizerGroup):
-
-        def __init__(self,
-                     *args,
-                     fail_at: Optional[list[int]] = None,
-                     **kwargs):
-            super().__init__(*args, **kwargs)
-            self.i = 0
-            self.fail_at = fail_at or []
-
-        def encode(self, *args, **kwargs):
-            self.i += 1
-            if self.i in self.fail_at:
-                sys.exit(1)
-            return super().encode(*args, **kwargs)
-
-    class FailingRayTokenizerGroupPool(RayTokenizerGroupPool):
-        _worker_cls = FailingTokenizerGroup
-
-    # Fail at first iteration
-    fail_at = [1]
-    tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
-    tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
-        tokenizer_pool_config,
-        tokenizer_id="gpt2",
-        enable_lora=False,
-        max_num_seqs=1,
-        max_input_length=None,
-        fail_at=fail_at)
-    tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()
-
-    # Modify fail at to not fail at all (will be re-read when actor is
-    # re-initialized).
-    fail_at[0] = 1000
-
-    # We should recover successfully.
-    await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
-    await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
-
-    # Check that we have a new actor
-    assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors)
-    assert tokenizer_group_pool.tokenizer_actors != tokenizer_actors
-
-    # Fail at first iteration
-    fail_at = [1]
-    tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
-        tokenizer_pool_config,
-        tokenizer_id="gpt2",
-        enable_lora=False,
-        max_num_seqs=1,
-        max_input_length=None,
-        fail_at=fail_at)
-
-    # We should fail after re-initialization.
-    with pytest.raises(RuntimeError):
-        await tokenizer_group_pool.encode_async(prompt="prompt",
-                                                lora_request=None)
-
-    # check_health should raise the same thing
-    with pytest.raises(RuntimeError):
-        tokenizer_group_pool.check_health()
-
-    # Ensure that non-ActorDiedErrors are still propagated correctly and do not
-    # cause a re-initialization.
-    fail_at = []
-    tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
-        tokenizer_pool_config,
-        tokenizer_id="gpt2",
-        enable_lora=False,
-        max_num_seqs=1,
-        max_input_length=2,
-        fail_at=fail_at)
-    tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()
-
-    # Prompt too long error
-    with pytest.raises(ValueError):
-        await tokenizer_group_pool.encode_async(prompt="prompt" * 100,
-                                                lora_request=None)
-    await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
-    # Actors should stay the same.
-    assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors
@@ -47,7 +47,7 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors:
         tokenizer=tokenizer,
         tokenizer_group=init_tokenizer_from_configs(
             vllm_config.model_config, vllm_config.scheduler_config,
-            vllm_config.parallel_config, vllm_config.lora_config),
+            vllm_config.lora_config),
         vllm_config=vllm_config,
         full_tokens=[tokenizer(text).input_ids for text in FULL_STRINGS],
         prompt_tokens=prompt_tokens,
@@ -8,8 +8,7 @@ import torch
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
 from vllm.engine.arg_utils import EngineArgs
-from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
-    BaseTokenizerGroup)
+from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.v1.engine import EngineCoreOutput, FinishReason
 from vllm.v1.outputs import LogprobsLists, LogprobsTensors
 
@@ -296,7 +295,7 @@ def generate_dummy_prompt_logprobs_tensors(
 class DummyOutputProcessorTestVectors:
     """Dummy test vectors for output processor tests"""
     tokenizer: GeneralTokenizerType
-    tokenizer_group: BaseTokenizerGroup
+    tokenizer_group: TokenizerGroup
     vllm_config: EngineArgs
     full_tokens: list[list[int]]  # Prompt + generated tokens
     prompt_tokens: list[list[int]]
@@ -52,8 +52,6 @@ if TYPE_CHECKING:
     from vllm.model_executor.layers.quantization.base_config import (
         QuantizationConfig)
     from vllm.model_executor.model_loader.loader import BaseModelLoader
-    from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
-        BaseTokenizerGroup)
 
     ConfigType = type[DataclassInstance]
 else:
@@ -1407,83 +1405,33 @@ class CacheConfig:
         logger.warning("Possibly too large swap space. %s", msg)
 
 
-PoolType = Literal["ray"]
-
-
 @config
 @dataclass
 class TokenizerPoolConfig:
-    """Configuration for the tokenizer pool."""
+    """This config is deprecated and will be removed in a future release.
+
+    Passing these parameters will have no effect. Please remove them from your
+    configurations.
+    """
 
     pool_size: int = 0
-    """Number of tokenizer workers in the pool to use for asynchronous
-    tokenization. If 0, will use synchronous tokenization."""
-
-    pool_type: Union[PoolType, type["BaseTokenizerGroup"]] = "ray"
-    """Type of tokenizer pool to use for asynchronous tokenization. Ignored if
-    tokenizer_pool_size is 0."""
-
+    """This parameter is deprecated and will be removed in a future release.
+    Passing this parameter will have no effect. Please remove it from your
+    configurations."""
+    pool_type: str = "ray"
+    """This parameter is deprecated and will be removed in a future release.
+    Passing this parameter will have no effect. Please remove it from your
+    configurations."""
     extra_config: dict = field(default_factory=dict)
-    """Additional config for the pool. The way the config will be used depends
-    on the pool type. This should be a JSON string that will be parsed into a
-    dictionary. Ignored if tokenizer_pool_size is 0."""
+    """This parameter is deprecated and will be removed in a future release.
+    Passing this parameter will have no effect. Please remove it from your
+    configurations."""
 
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
         ensure that it is included in the factors list if
         it affects the computation graph.
 
         Provide a hash that uniquely identifies all the configs
         that affect the structure of the computation
         graph from input ids/embeddings to the final hidden states,
         excluding anything before input ids/embeddings and after
         the final hidden states.
         """
         # no factors to consider.
         # this config will not affect the computation graph.
         factors: list[Any] = []
         hash_str = hashlib.md5(str(factors).encode(),
                                usedforsecurity=False).hexdigest()
         return hash_str
 
-    def __post_init__(self):
-        if self.pool_type not in ("ray", ) and not isinstance(
-                self.pool_type, type):
-            raise ValueError(f"Unknown pool type: {self.pool_type}")
-        if not isinstance(self.extra_config, dict):
-            raise ValueError("extra_config must be a dictionary.")
-
-    @classmethod
-    def create_config(
-        cls, tokenizer_pool_size: int,
-        tokenizer_pool_type: Union[PoolType, type["BaseTokenizerGroup"]],
-        tokenizer_pool_extra_config: Optional[Union[str, dict]]
-    ) -> Optional["TokenizerPoolConfig"]:
-        """Create a TokenizerPoolConfig from the given parameters.
-
-        If tokenizer_pool_size is 0, return None.
-
-        Args:
-            tokenizer_pool_size: Number of tokenizer workers in the pool.
-            tokenizer_pool_type: Type of the pool.
-            tokenizer_pool_extra_config: Additional config for the pool.
-                The way the config will be used depends on the
-                pool type. This can be a JSON string (will be parsed).
-        """
-        if tokenizer_pool_size:
-            if isinstance(tokenizer_pool_extra_config, str):
-                tokenizer_pool_extra_config_parsed = json.loads(
-                    tokenizer_pool_extra_config)
-            else:
-                tokenizer_pool_extra_config_parsed = (
-                    tokenizer_pool_extra_config or {})
-            tokenizer_pool_config = cls(tokenizer_pool_size,
-                                        tokenizer_pool_type,
-                                        tokenizer_pool_extra_config_parsed)
-        else:
-            tokenizer_pool_config = None
-        return tokenizer_pool_config
+    def __post_init__(self) -> None:
+        logger.warning_once(
+            "TokenizerPoolConfig is deprecated and will be removed in a "
+            "future release. Passing this parameter will have no effect. "
+            "Please remove it from your configurations.")
 
 
 class LoadFormat(str, enum.Enum):
@@ -1624,8 +1572,8 @@ class ParallelConfig:
     """Disable the custom all-reduce kernel and fall back to NCCL."""
 
     tokenizer_pool_config: Optional[TokenizerPoolConfig] = None
-    """Config for the tokenizer pool. If None, will use synchronous
-    tokenization."""
+    """This parameter is deprecated and will be removed in a future release.
+    Please remove it from your configs"""
 
     ray_workers_use_nsight: bool = False
     """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
@@ -2544,7 +2492,6 @@ class SpeculativeConfig:
             max_parallel_loading_workers,
             disable_custom_all_reduce=target_parallel_config.
             disable_custom_all_reduce,
-            tokenizer_pool_config=target_parallel_config.tokenizer_pool_config,
             ray_workers_use_nsight=target_parallel_config.
             ray_workers_use_nsight,
             placement_group=target_parallel_config.placement_group,
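With all three fields reduced to no-ops, constructing the config only emits a warning. A short sketch of the behaviour introduced above, assuming this commit:

    from vllm.config import TokenizerPoolConfig

    # __post_init__ now only warns; pool_size, pool_type and extra_config
    # are no longer read anywhere in the codebase.
    cfg = TokenizerPoolConfig(pool_size=1)
    # logs: "TokenizerPoolConfig is deprecated and will be removed in a
    # future release. Passing this parameter will have no effect. ..."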
@@ -7,9 +7,8 @@ import json
 import re
 import threading
 from dataclasses import MISSING, dataclass, fields
-from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Literal,
-                    Optional, Tuple, Type, TypeVar, Union, cast, get_args,
-                    get_origin)
+from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Type,
+                    TypeVar, Union, cast, get_args, get_origin)
 
 import torch
 from typing_extensions import TypeIs
@@ -23,7 +22,7 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                          KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
                          ModelConfig, ModelImpl, MultiModalConfig,
                          ObservabilityConfig, ParallelConfig, PoolerConfig,
-                         PoolType, PrefixCachingHashAlgo, PromptAdapterConfig,
+                         PrefixCachingHashAlgo, PromptAdapterConfig,
                          SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
                          TaskOption, TokenizerPoolConfig, VllmConfig,
                          get_attr_docs, get_field)
@@ -39,9 +38,6 @@ from vllm.utils import FlexibleArgumentParser, GiB_bytes, is_in_ray_actor
 
 # yapf: enable
 
-if TYPE_CHECKING:
-    from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
-
 logger = init_logger(__name__)
 
 ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]
@@ -185,13 +181,12 @@ class EngineArgs:
     enforce_eager: Optional[bool] = None
     max_seq_len_to_capture: int = 8192
     disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
+    # The following three fields are deprecated and will be removed in a future
+    # release. Setting them will have no effect. Please remove them from your
+    # configurations.
     tokenizer_pool_size: int = TokenizerPoolConfig.pool_size
-    # Note: Specifying a tokenizer pool by passing a class
-    # is intended for expert use only. The API may change without
-    # notice.
-    tokenizer_pool_type: Union[PoolType, Type["BaseTokenizerGroup"]] = \
-        TokenizerPoolConfig.pool_type
-    tokenizer_pool_extra_config: dict[str, Any] = \
+    tokenizer_pool_type: str = TokenizerPoolConfig.pool_type
+    tokenizer_pool_extra_config: dict = \
         get_field(TokenizerPoolConfig, "extra_config")
     limit_mm_per_prompt: dict[str, int] = \
         get_field(MultiModalConfig, "limit_per_prompt")
@@ -1187,11 +1182,6 @@ class EngineArgs:
             enable_expert_parallel=self.enable_expert_parallel,
             max_parallel_loading_workers=self.max_parallel_loading_workers,
             disable_custom_all_reduce=self.disable_custom_all_reduce,
-            tokenizer_pool_config=TokenizerPoolConfig.create_config(
-                self.tokenizer_pool_size,
-                self.tokenizer_pool_type,
-                self.tokenizer_pool_extra_config,
-            ),
             ray_workers_use_nsight=self.ray_workers_use_nsight,
             placement_group=placement_group,
             distributed_executor_backend=self.distributed_executor_backend,
@@ -526,8 +526,6 @@ class _AsyncLLMEngine(LLMEngine):
         )
 
     async def check_health_async(self) -> None:
-        if self.tokenizer:
-            self.tokenizer.check_health()
         self.model_executor.check_health()
 
 
@@ -55,7 +55,7 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import (
-    BaseTokenizerGroup, init_tokenizer_from_configs)
+    TokenizerGroup, init_tokenizer_from_configs)
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                   usage_message)
 from vllm.utils import (Counter, Device, deprecate_kwargs,
@@ -66,7 +66,6 @@ from vllm.worker.model_runner_base import InputProcessingError
 logger = init_logger(__name__)
 _LOCAL_LOGGING_INTERVAL_SEC = 5
 
-_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
 _O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
 _R = TypeVar("_R", default=Any)
 
@@ -205,7 +204,7 @@ class LLMEngine:
 
         return outputs_
 
-    tokenizer: Optional[BaseTokenizerGroup]
+    tokenizer: Optional[TokenizerGroup]
 
     def __init__(
         self,
@@ -321,11 +320,6 @@
             self.parallel_config.disable_custom_all_reduce,
         })
 
-        if self.tokenizer:
-            # Ping the tokenizer to ensure liveness if it runs in a
-            # different process.
-            self.tokenizer.ping()
-
         self.cached_scheduler_outputs = [
             SchedulerOutputState()
             for _ in range(self.parallel_config.pipeline_parallel_size)
@@ -537,21 +531,12 @@
         if model_executor := getattr(self, "model_executor", None):
             model_executor.shutdown()
 
-    def get_tokenizer_group(
-        self,
-        group_type: Type[_G] = BaseTokenizerGroup,
-    ) -> _G:
-        tokenizer_group = self.tokenizer
-
-        if tokenizer_group is None:
+    def get_tokenizer_group(self) -> TokenizerGroup:
+        if self.tokenizer is None:
             raise ValueError("Unable to get tokenizer because "
                              "skip_tokenizer_init is True")
-        if not isinstance(tokenizer_group, group_type):
-            raise TypeError("Invalid type of tokenizer group. "
-                            f"Expected type: {group_type}, but "
-                            f"found type: {type(tokenizer_group)}")
-
-        return tokenizer_group
+
+        return self.tokenizer
 
     def get_tokenizer(
         self,
@@ -559,11 +544,10 @@
     ) -> AnyTokenizer:
         return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
 
-    def _init_tokenizer(self) -> BaseTokenizerGroup:
+    def _init_tokenizer(self) -> TokenizerGroup:
         return init_tokenizer_from_configs(
             model_config=self.model_config,
             scheduler_config=self.scheduler_config,
-            parallel_config=self.parallel_config,
             lora_config=self.lora_config)
 
     def _verify_args(self) -> None:
@@ -1952,8 +1936,6 @@
         return self.model_executor.is_sleeping
 
     def check_health(self) -> None:
-        if self.tokenizer:
-            self.tokenizer.check_health()
         self.model_executor.check_health()
 
     def is_tracing_enabled(self) -> bool:
@@ -101,7 +101,6 @@ class MQLLMEngineClient(EngineClient):
         self.tokenizer = init_tokenizer_from_configs(
             model_config=self.model_config,
             scheduler_config=engine_config.scheduler_config,
-            parallel_config=engine_config.parallel_config,
             lora_config=engine_config.lora_config)
         self.input_preprocessor = InputPreprocessor(self.model_config,
                                                     self.tokenizer)
@@ -40,7 +40,6 @@ from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                   RequestOutputKind, SamplingParams)
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                                get_cached_tokenizer)
-from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
                         is_list_of)
@@ -253,10 +252,10 @@ class LLM:
         self.default_sampling_params: Union[dict[str, Any], None] = None
 
     def get_tokenizer(self) -> AnyTokenizer:
-        return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer
+        return self.llm_engine.get_tokenizer_group().tokenizer
 
     def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
-        tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
+        tokenizer_group = self.llm_engine.get_tokenizer_group()
 
         # While CachedTokenizer is dynamic, have no choice but
         # compare class name. Misjudgment will arise from
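The public LLM surface is unchanged; only the internal call drops its group_type argument, since there is now exactly one tokenizer group class. A sketch, assuming a machine that can load the model:

    from vllm import LLM

    llm = LLM(model="gpt2")
    tokenizer = llm.get_tokenizer()  # the underlying HF tokenizer, as before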
@@ -13,7 +13,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                                     MultiModalInputs)
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
+from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 
 from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs,
                    PromptType, SingletonInputs, SingletonPrompt, token_inputs)
@@ -27,7 +27,7 @@ class InputPreprocessor:
     def __init__(
         self,
         model_config: ModelConfig,
-        tokenizer: Optional[BaseTokenizerGroup],
+        tokenizer: Optional[TokenizerGroup],
         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
     ) -> None:
         super().__init__()
@@ -36,7 +36,7 @@ class InputPreprocessor:
         self.tokenizer = tokenizer
         self.mm_registry = mm_registry
 
-    def get_tokenizer_group(self) -> BaseTokenizerGroup:
+    def get_tokenizer_group(self) -> TokenizerGroup:
         if self.tokenizer is None:
             raise ValueError("You cannot pass text prompts when "
                              "`skip_tokenizer_init` is True")
@@ -8,13 +8,13 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams,
 from .detokenizer_utils import (convert_prompt_ids_to_tokens,
                                 detokenize_incrementally)
 from .tokenizer import AnyTokenizer
-from .tokenizer_group import BaseTokenizerGroup
+from .tokenizer_group import TokenizerGroup
 
 
 class Detokenizer:
     """Provides methods to decode the output of a model into text."""
 
-    def __init__(self, tokenizer_group: BaseTokenizerGroup):
+    def __init__(self, tokenizer_group: TokenizerGroup):
         self.tokenizer_group = tokenizer_group
 
     def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
@@ -2,7 +2,7 @@
 
 from typing import List, Optional
 
-from vllm.config import TokenizerPoolConfig
+from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig
 from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, encode_tokens,
                                                get_lora_tokenizer,
@@ -10,10 +10,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, encode_tokens,
                                                get_tokenizer)
 from vllm.utils import LRUCache
 
-from .base_tokenizer_group import BaseTokenizerGroup
-
 
-class TokenizerGroup(BaseTokenizerGroup):
+class TokenizerGroup:
     """A group of tokenizers that can be used for LoRA adapters."""
 
     def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
@@ -27,15 +25,6 @@ class TokenizerGroup(BaseTokenizerGroup):
         self.lora_tokenizers = LRUCache[int, AnyTokenizer](
             capacity=max(max_loras, max_num_seqs) if enable_lora else 0)
 
-    @classmethod
-    def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
-                    **init_kwargs) -> "TokenizerGroup":
-        return cls(**init_kwargs)
-
-    def ping(self) -> bool:
-        """Check if the tokenizer group is alive."""
-        return True
-
     def get_max_input_len(self,
                           lora_request: Optional[LoRARequest] = None
                           ) -> Optional[int]:
@@ -104,3 +93,18 @@
             return tokenizer
         else:
             return self.lora_tokenizers[lora_request.lora_int_id]
+
+
+def init_tokenizer_from_configs(model_config: ModelConfig,
+                                scheduler_config: SchedulerConfig,
+                                lora_config: Optional[LoRAConfig]):
+    return TokenizerGroup(
+        tokenizer_id=model_config.tokenizer,
+        enable_lora=bool(lora_config),
+        max_num_seqs=scheduler_config.max_num_seqs,
+        max_loras=lora_config.max_loras if lora_config else 0,
+        max_input_length=None,
+        tokenizer_mode=model_config.tokenizer_mode,
+        trust_remote_code=model_config.trust_remote_code,
+        revision=model_config.tokenizer_revision,
+        truncation_side=model_config.truncation_side)
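For reference, the engine call sites shown in this diff now build the group in a single step. A sketch assuming a populated VllmConfig named vllm_config:

    from vllm.transformers_utils.tokenizer_group import (
        TokenizerGroup, init_tokenizer_from_configs)

    tokenizer: TokenizerGroup = init_tokenizer_from_configs(
        model_config=vllm_config.model_config,
        scheduler_config=vllm_config.scheduler_config,
        lora_config=vllm_config.lora_config)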
@@ -1,56 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-
-from typing import Optional, Type
-
-from vllm.config import (LoRAConfig, ModelConfig, ParallelConfig,
-                         SchedulerConfig, TokenizerPoolConfig)
-from vllm.executor.ray_utils import ray
-
-from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
-from .tokenizer_group import TokenizerGroup
-
-if ray:
-    from .ray_tokenizer_group import RayTokenizerGroupPool
-else:
-    RayTokenizerGroupPool = None  # type: ignore
-
-
-def init_tokenizer_from_configs(model_config: ModelConfig,
-                                scheduler_config: SchedulerConfig,
-                                parallel_config: ParallelConfig,
-                                lora_config: Optional[LoRAConfig]):
-    init_kwargs = dict(tokenizer_id=model_config.tokenizer,
-                       enable_lora=bool(lora_config),
-                       max_num_seqs=scheduler_config.max_num_seqs,
-                       max_loras=lora_config.max_loras if lora_config else 0,
-                       max_input_length=None,
-                       tokenizer_mode=model_config.tokenizer_mode,
-                       trust_remote_code=model_config.trust_remote_code,
-                       revision=model_config.tokenizer_revision,
-                       truncation_side=model_config.truncation_side)
-
-    return get_tokenizer_group(parallel_config.tokenizer_pool_config,
-                               **init_kwargs)
-
-
-def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
-                        **init_kwargs) -> BaseTokenizerGroup:
-    tokenizer_cls: Type[BaseTokenizerGroup]
-    if tokenizer_pool_config is None:
-        tokenizer_cls = TokenizerGroup
-    elif isinstance(tokenizer_pool_config.pool_type, type) and issubclass(
-            tokenizer_pool_config.pool_type, BaseTokenizerGroup):
-        tokenizer_cls = tokenizer_pool_config.pool_type
-    elif tokenizer_pool_config.pool_type == "ray":
-        if RayTokenizerGroupPool is None:
-            raise ImportError(
-                "RayTokenizerGroupPool is not available. Please install "
-                "the ray package to use the Ray tokenizer group pool.")
-        tokenizer_cls = RayTokenizerGroupPool
-    else:
-        raise ValueError(
-            f"Unknown pool type: {tokenizer_pool_config.pool_type}")
-    return tokenizer_cls.from_config(tokenizer_pool_config, **init_kwargs)
-
-
-__all__ = ["AnyTokenizer", "get_tokenizer_group", "BaseTokenizerGroup"]
@@ -1,68 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-
-from abc import ABC, abstractmethod
-from typing import List, Optional
-
-from vllm.config import TokenizerPoolConfig
-from vllm.lora.request import LoRARequest
-from vllm.transformers_utils.tokenizer import AnyTokenizer
-
-
-class BaseTokenizerGroup(ABC):
-    """A group of tokenizers that can be used for LoRA adapters."""
-
-    @classmethod
-    @abstractmethod
-    def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
-                    **init_kwargs) -> "BaseTokenizerGroup":
-        pass
-
-    @abstractmethod
-    def ping(self) -> bool:
-        """Check if the tokenizer group is alive."""
-        pass
-
-    @abstractmethod
-    def get_max_input_len(
-        self,
-        lora_request: Optional[LoRARequest] = None,
-    ) -> Optional[int]:
-        """Get the maximum input length for the LoRA request."""
-        pass
-
-    @abstractmethod
-    def encode(self,
-               prompt: str,
-               lora_request: Optional[LoRARequest] = None,
-               add_special_tokens: Optional[bool] = None) -> List[int]:
-        """Encode a prompt using the tokenizer group."""
-        pass
-
-    @abstractmethod
-    async def encode_async(
-            self,
-            prompt: str,
-            lora_request: Optional[LoRARequest] = None,
-            add_special_tokens: Optional[bool] = None) -> List[int]:
-        """Encode a prompt using the tokenizer group."""
-        pass
-
-    @abstractmethod
-    def get_lora_tokenizer(
-        self,
-        lora_request: Optional[LoRARequest] = None,
-    ) -> AnyTokenizer:
-        """Get a tokenizer for a LoRA request."""
-        pass
-
-    @abstractmethod
-    async def get_lora_tokenizer_async(
-        self,
-        lora_request: Optional[LoRARequest] = None,
-    ) -> AnyTokenizer:
-        """Get a tokenizer for a LoRA request."""
-        pass
-
-    def check_health(self):
-        """Raise exception if the tokenizer group is unhealthy."""
-        return
@@ -1,244 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-
-import asyncio
-import os
-from typing import List, Optional
-
-try:
-    from ray.exceptions import ActorDiedError  # type: ignore
-except ImportError:
-    # For older versions of Ray
-    from ray.exceptions import RayActorError as ActorDiedError  # type: ignore
-from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
-
-from vllm.config import TokenizerPoolConfig
-from vllm.executor.ray_utils import ray
-from vllm.logger import init_logger
-from vllm.lora.request import LoRARequest
-from vllm.transformers_utils.tokenizer import AnyTokenizer
-
-from .base_tokenizer_group import BaseTokenizerGroup
-from .tokenizer_group import TokenizerGroup
-
-logger = init_logger(__name__)
-
-
-class RayTokenizerGroupPool(BaseTokenizerGroup):
-    """A Ray-based pool of TokenizerGroups for async tokenization."""
-
-    # Class to use for workers making up the pool.
-    _worker_cls = TokenizerGroup
-
-    @classmethod
-    def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
-                    **init_kwargs) -> "RayTokenizerGroupPool":
-        if not tokenizer_pool_config:
-            raise ValueError("tokenizer_pool_config must not be None.")
-        ray_actor_options = (tokenizer_pool_config.extra_config or {
-            "num_cpus": 0
-        })
-        ray_actor_options.setdefault(
-            "scheduling_strategy",
-            NodeAffinitySchedulingStrategy(
-                node_id=ray.get_runtime_context().get_node_id(), soft=True))
-
-        # Carry over the env vars to the actors.
-        # This is necessary for API keys and such.
-        ray_actor_options.setdefault("runtime_env", {})
-        _carry_over_env_vars_to_runtime_env(ray_actor_options["runtime_env"])
-
-        init_kwargs["num_actors"] = tokenizer_pool_config.pool_size
-        init_kwargs["ray_actor_options"] = ray_actor_options
-
-        return cls(**init_kwargs)
-
-    def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
-                 max_input_length: Optional[int], num_actors: int,
-                 ray_actor_options: dict, **tokenizer_config):
-        # Store a local copy of the TokenizerGroup for quick access
-        # to underlying HF tokenizers.
-        self._tokenizer_config = {
-            "tokenizer_id": tokenizer_id,
-            "enable_lora": enable_lora,
-            "max_num_seqs": max_num_seqs,
-            "max_input_length": max_input_length,
-            **tokenizer_config
-        }
-        self._local_tokenizer_group = self._worker_cls(
-            **self._tokenizer_config, )
-
-        self._ray_tokenizer_group_cls = ray.remote(
-            self._worker_cls).options(**ray_actor_options)  # type: ignore
-        self.tokenizer_actors = [self._init_actor() for _ in range(num_actors)]
-        self._idle_actors: Optional[asyncio.Queue] = None
-
-        # If set, actor is unhealthy. Will reraise on the next
-        # check_health call.
-        self._exception: Optional[ActorDiedError] = None
-
-    def _init_actor(self) -> ray.ObjectRef:
-        return self._ray_tokenizer_group_cls.remote(**self._tokenizer_config)
-
-    @property
-    def pool_size(self) -> int:
-        return len(self.tokenizer_actors)
-
-    def ping(self):
-        return ray.get([
-            actor.ping.remote()  # type: ignore
-            for actor in self.tokenizer_actors
-        ])
-
-    def _ensure_queue_initialized(self):
-        if self._idle_actors is None:
-            self._idle_actors = asyncio.Queue()
-            for actor in self.tokenizer_actors:
-                self._idle_actors.put_nowait(actor)
-
-    def _finalize_encode(self, actor: ray.ObjectRef,
-                         original_actor: ray.ObjectRef, actor_is_alive: bool):
-        assert self._idle_actors is not None
-        # Cleanup the dead actor.
-        if not actor_is_alive or original_actor is not actor:
-            self.tokenizer_actors.remove(original_actor)
-        if actor_is_alive:
-            # Put the actor back in the queue.
-            # This is done in a finally block to ensure that the actor is
-            # always put back in the queue, even if an exception/cancellation
-            # is raised.
-            self._idle_actors.put_nowait(actor)
-            # Add back the new actor.
-            if original_actor is not actor:
-                self.tokenizer_actors.append(actor)
-
-    def encode(self,
-               prompt: str,
-               lora_request: Optional[LoRARequest] = None,
-               add_special_tokens: Optional[bool] = None) -> List[int]:
-        """Encode a prompt using the tokenizer group.
-
-        We pick an idle actor and use it to encode the prompt.
-        The actor is then put back in the queue for future use.
-        This is blocking.
-        """
-        self.check_health()
-        self._ensure_queue_initialized()
-        assert self._idle_actors is not None
-
-        if self._idle_actors.empty():
-            raise RuntimeError("No idle actors available.")
-        actor = self._idle_actors.get_nowait()
-        actor_is_alive = True
-        original_actor = actor
-        try:
-            ret = ray.get(
-                actor.encode.remote(prompt=prompt,
-                                    lora_request=lora_request,
-                                    add_special_tokens=add_special_tokens))
-        except ActorDiedError as e:
-            # If the actor is dead, we first try to reinitialize it.
-            logger.warning("%s died with ActorDiedError, reinitializing.",
-                           actor,
-                           exc_info=e)
-            actor = self._init_actor()
-            try:
-                ret = ray.get(
-                    actor.encode.remote(prompt=prompt,
-                                        lora_request=lora_request,
-                                        add_special_tokens=add_special_tokens))
-            except ActorDiedError as e:
-                logger.error(
-                    "%s died for second time in a row, marking "
-                    "RayTokenizerGroupPool as unhealthy.", actor)
-                actor_is_alive = False
-                if not self._exception:
-                    self._exception = e
-                self.check_health()
-        finally:
-            self._finalize_encode(actor, original_actor, actor_is_alive)
-        return ret
-
-    async def encode_async(
-            self,
-            prompt: str,
-            lora_request: Optional[LoRARequest] = None,
-            add_special_tokens: Optional[bool] = None) -> List[int]:
-        """Encode a prompt using the tokenizer group.
-
-        We pick an idle actor and use it to encode the prompt.
-        If there are no idle actors, we wait until one becomes
-        available.
-        The actor is then put back in the queue for future use.
-        This is non-blocking.
-        """
-        self.check_health()
-        self._ensure_queue_initialized()
-        assert self._idle_actors is not None
-
-        actor = await self._idle_actors.get()
-        actor_is_alive = True
-        original_actor = actor
-        try:
-            ret = await actor.encode.remote(
-                prompt=prompt,
-                lora_request=lora_request,
-                add_special_tokens=add_special_tokens)
-        except ActorDiedError as e:
-            # If the actor is dead, we first try to reinitialize it.
-            logger.warning("%s died with ActorDiedError, reinitializing.",
-                           actor,
-                           exc_info=e)
-            actor = self._init_actor()
-            try:
-                ret = await actor.encode.remote(
-                    prompt=prompt,
-                    lora_request=lora_request,
-                    add_special_tokens=add_special_tokens)
-            except ActorDiedError as e:
-                logger.error(
-                    "%s died for second time in a row, marking "
-                    "RayTokenizerGroupPool as unhealthy.", actor)
-                actor_is_alive = False
-                if not self._exception:
-                    self._exception = e
-                self.check_health()
-        finally:
-            self._finalize_encode(actor, original_actor, actor_is_alive)
-        return ret
-
-    def get_max_input_len(self,
-                          lora_request: Optional[LoRARequest] = None
-                          ) -> Optional[int]:
-        """Get the maximum input length for the LoRA request."""
-        return self._local_tokenizer_group.get_max_input_len(lora_request)
-
-    def get_lora_tokenizer(
-            self,
-            lora_request: Optional[LoRARequest] = None,
-    ) -> AnyTokenizer:
-        return self._local_tokenizer_group.get_lora_tokenizer(lora_request)
-
-    async def get_lora_tokenizer_async(
-            self,
-            lora_request: Optional[LoRARequest] = None,
-    ) -> AnyTokenizer:
-        return await self._local_tokenizer_group.get_lora_tokenizer_async(
-            lora_request)
-
-    def check_health(self):
-        if self._exception:
-            raise RuntimeError(
-                "TokenizerGroupPool is unhealthy.") from self._exception
-
-
-def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None:
-    """Copy over all current process environment variables to the runtime_env.
-
-    The variables in runtime_env will take precedence over the current process
-    environment variables.
-
-    runtime_env will be modified in place."""
-    env_vars = os.environ.copy()
-    runtime_env.setdefault("env_vars", {})
-    env_vars.update(runtime_env["env_vars"])
-    runtime_env["env_vars"] = env_vars
@@ -81,9 +81,7 @@ class AsyncLLM(EngineClient):
         self.tokenizer = init_tokenizer_from_configs(
             model_config=vllm_config.model_config,
             scheduler_config=vllm_config.scheduler_config,
-            parallel_config=vllm_config.parallel_config,
             lora_config=vllm_config.lora_config)
-        self.tokenizer.ping()
 
         # Processor (converts Inputs --> EngineCoreRequests).
         self.processor = Processor(
@@ -20,7 +20,7 @@ from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import (
-    BaseTokenizerGroup, init_tokenizer_from_configs)
+    TokenizerGroup, init_tokenizer_from_configs)
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Device
 from vllm.v1.engine.core_client import EngineCoreClient
@@ -32,7 +32,6 @@ from vllm.v1.utils import report_usage_stats
 
 logger = init_logger(__name__)
 
-_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
 _R = TypeVar("_R", default=Any)
 
 
@@ -74,9 +73,7 @@ class LLMEngine:
         self.tokenizer = init_tokenizer_from_configs(
             model_config=vllm_config.model_config,
             scheduler_config=vllm_config.scheduler_config,
-            parallel_config=vllm_config.parallel_config,
             lora_config=vllm_config.lora_config)
-        self.tokenizer.ping()
 
         # Processor (convert Inputs --> EngineCoreRequests)
         self.processor = Processor(vllm_config=vllm_config,
@@ -258,21 +255,12 @@
     def is_sleeping(self) -> bool:
         return self.engine_core.is_sleeping()
 
-    def get_tokenizer_group(
-        self,
-        group_type: type[_G] = BaseTokenizerGroup,
-    ) -> _G:
-        tokenizer_group = self.tokenizer
-
-        if tokenizer_group is None:
+    def get_tokenizer_group(self) -> TokenizerGroup:
+        if self.tokenizer is None:
             raise ValueError("Unable to get tokenizer because "
                              "skip_tokenizer_init is True")
-        if not isinstance(tokenizer_group, group_type):
-            raise TypeError("Invalid type of tokenizer group. "
-                            f"Expected type: {group_type}, but "
-                            f"found type: {type(tokenizer_group)}")
-
-        return tokenizer_group
+
+        return self.tokenizer
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         """Load a new LoRA adapter into the engine for future requests."""
@@ -8,7 +8,7 @@ from typing import Optional, Union
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import RequestOutputKind
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
+from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
 from vllm.v1.engine.detokenizer import IncrementalDetokenizer
 from vllm.v1.engine.logprobs import LogprobsProcessor
@@ -225,7 +225,7 @@ class OutputProcessor:
 
     def __init__(
         self,
-        tokenizer: BaseTokenizerGroup,
+        tokenizer: TokenizerGroup,
         log_stats: bool,
     ):
         self.log_stats = log_stats
@@ -17,7 +17,7 @@ from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
-from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
+from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
 from vllm.v1.structured_output.backend_guidance import (
@@ -31,7 +31,7 @@ class Processor:
     def __init__(
         self,
        vllm_config: VllmConfig,
-        tokenizer: BaseTokenizerGroup,
+        tokenizer: TokenizerGroup,
         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
     ):
 
@@ -61,9 +61,7 @@ class GuidanceBackend(StructuredOutputBackend):
         tokenizer_group = init_tokenizer_from_configs(
             model_config=vllm_config.model_config,
             scheduler_config=vllm_config.scheduler_config,
-            parallel_config=vllm_config.parallel_config,
             lora_config=vllm_config.lora_config)  # type: ignore[arg-type]
-        tokenizer_group.ping()
         self.vllm_config = vllm_config
         self.vocab_size = vllm_config.model_config.get_vocab_size()
 
@@ -35,9 +35,7 @@ class XgrammarBackend(StructuredOutputBackend):
         tokenizer_group = init_tokenizer_from_configs(
             model_config=vllm_config.model_config,
             scheduler_config=vllm_config.scheduler_config,
-            parallel_config=vllm_config.parallel_config,
             lora_config=vllm_config.lora_config)  # type: ignore[arg-type]
-        tokenizer_group.ping()
 
         self.disable_any_whitespace = False
         backend_options = GuidedDecodingParams(