Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 03:15:20 +08:00)
[Core] Allow specifying custom Executor (#6557)
parent 2e26564259 · commit 7bd82002ae
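This change lets the distributed executor backend be given as an ExecutorBase subclass (ExecutorAsyncBase for the async engine) instead of only the "ray"/"mp" strings, and likewise lets the tokenizer pool type be a BaseTokenizerGroup subclass. A minimal usage sketch, mirroring the new test added below; the executor class name, prompt, and print statement are illustrative only:

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.executor.gpu_executor import GPUExecutor
from vllm.sampling_params import SamplingParams


class LoggingGPUExecutor(GPUExecutor):
    """Illustrative custom executor: logs every model execution."""

    def execute_model(self, *args, **kwargs):
        print("executing model")
        return super().execute_model(*args, **kwargs)


# Pass the executor class itself instead of "ray" or "mp".
engine_args = EngineArgs(model="facebook/opt-125m",
                         distributed_executor_backend=LoggingGPUExecutor)
engine = LLMEngine.from_engine_args(engine_args)
engine.add_request("0", "Hello", SamplingParams(max_tokens=1))
engine.step()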
@@ -564,6 +564,10 @@ def get_tokenizer_pool_config(tokenizer_group_type):
         return TokenizerPoolConfig(pool_size=1,
                                    pool_type="ray",
                                    extra_config={})
+    if isinstance(tokenizer_group_type, type):
+        return TokenizerPoolConfig(pool_size=1,
+                                   pool_type=tokenizer_group_type,
+                                   extra_config={})
     raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")
tests/engine/test_custom_executor.py (new file, 91 lines)
@@ -0,0 +1,91 @@
+import asyncio
+import os
+
+import pytest
+
+from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.engine.llm_engine import LLMEngine
+from vllm.executor.gpu_executor import GPUExecutor, GPUExecutorAsync
+from vllm.sampling_params import SamplingParams
+
+
+class Mock:
+    ...
+
+
+class CustomGPUExecutor(GPUExecutor):
+
+    def execute_model(self, *args, **kwargs):
+        # Drop marker to show that this was ran
+        with open(".marker", "w"):
+            ...
+        return super().execute_model(*args, **kwargs)
+
+
+class CustomGPUExecutorAsync(GPUExecutorAsync):
+
+    async def execute_model_async(self, *args, **kwargs):
+        with open(".marker", "w"):
+            ...
+        return await super().execute_model_async(*args, **kwargs)
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+def test_custom_executor_type_checking(model):
+    with pytest.raises(ValueError):
+        engine_args = EngineArgs(model=model,
+                                 distributed_executor_backend=Mock)
+        LLMEngine.from_engine_args(engine_args)
+    with pytest.raises(ValueError):
+        engine_args = AsyncEngineArgs(model=model,
+                                      distributed_executor_backend=Mock)
+        AsyncLLMEngine.from_engine_args(engine_args)
+    with pytest.raises(TypeError):
+        engine_args = AsyncEngineArgs(
+            model=model, distributed_executor_backend=CustomGPUExecutor)
+        AsyncLLMEngine.from_engine_args(engine_args)
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+def test_custom_executor(model, tmpdir):
+    cwd = os.path.abspath(".")
+    os.chdir(tmpdir)
+    try:
+        assert not os.path.exists(".marker")
+
+        engine_args = EngineArgs(
+            model=model, distributed_executor_backend=CustomGPUExecutor)
+        engine = LLMEngine.from_engine_args(engine_args)
+        sampling_params = SamplingParams(max_tokens=1)
+
+        engine.add_request("0", "foo", sampling_params)
+        engine.step()
+
+        assert os.path.exists(".marker")
+    finally:
+        os.chdir(cwd)
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+def test_custom_executor_async(model, tmpdir):
+    cwd = os.path.abspath(".")
+    os.chdir(tmpdir)
+    try:
+        assert not os.path.exists(".marker")
+
+        engine_args = AsyncEngineArgs(
+            model=model, distributed_executor_backend=CustomGPUExecutorAsync)
+        engine = AsyncLLMEngine.from_engine_args(engine_args)
+        sampling_params = SamplingParams(max_tokens=1)
+
+        async def t():
+            stream = await engine.add_request("0", "foo", sampling_params)
+            async for x in stream:
+                ...
+
+        asyncio.run(t())
+
+        assert os.path.exists(".marker")
+    finally:
+        os.chdir(cwd)
@@ -7,17 +7,28 @@ from unittest.mock import patch
 import pytest
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
-from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
+from vllm.transformers_utils.tokenizer_group import (TokenizerGroup,
+                                                     get_tokenizer_group)
 from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
     RayTokenizerGroupPool)
-from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
-    TokenizerGroup)
 
 from ..conftest import get_tokenizer_pool_config
 
 
+class CustomTokenizerGroup(TokenizerGroup):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._i = 0
+
+    def encode(self, *args, **kwargs):
+        self._i += 1
+        return super().encode(*args, **kwargs)
+
+
 @pytest.mark.asyncio
-@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"])
+@pytest.mark.parametrize("tokenizer_group_type",
+                         [None, "ray", CustomTokenizerGroup])
 async def test_tokenizer_group(tokenizer_group_type):
     reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
     tokenizer_group = get_tokenizer_group(
@@ -36,6 +47,8 @@ async def test_tokenizer_group(tokenizer_group_type):
         PreTrainedTokenizerBase)
     assert tokenizer_group.get_lora_tokenizer(
         None) == await tokenizer_group.get_lora_tokenizer_async(None)
+    if tokenizer_group_type is CustomTokenizerGroup:
+        assert tokenizer_group._i > 0
 
 
 @pytest.mark.asyncio
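The tokenizer pool can be customized the same way: TokenizerPoolConfig.pool_type (and the corresponding EngineArgs.tokenizer_pool_type field changed below) now also accepts a BaseTokenizerGroup subclass. A hedged sketch of plugging such a class in through the pool config; the class name is illustrative and the keyword arguments follow the usage in the tokenizer group test:

from vllm.config import TokenizerPoolConfig
from vllm.transformers_utils.tokenizer_group import (TokenizerGroup,
                                                     get_tokenizer_group)


class CountingTokenizerGroup(TokenizerGroup):
    """Illustrative subclass that counts encode() calls."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.calls = 0

    def encode(self, *args, **kwargs):
        self.calls += 1
        return super().encode(*args, **kwargs)


# pool_type may now be a class, not just the "ray" string.
pool_config = TokenizerPoolConfig(pool_size=1,
                                  pool_type=CountingTokenizerGroup,
                                  extra_config={})
tokenizer_group = get_tokenizer_group(pool_config,
                                      tokenizer_id="gpt2",
                                      enable_lora=False,
                                      max_num_seqs=1,
                                      max_input_length=None)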
@@ -1,7 +1,7 @@
 import enum
 import json
 from dataclasses import dataclass, field, fields
-from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Type, Union
 
 import torch
 from transformers import PretrainedConfig
@@ -18,7 +18,10 @@ from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
 
+    from vllm.executor.executor_base import ExecutorBase
     from vllm.model_executor.model_loader.loader import BaseModelLoader
+    from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
+        BaseTokenizerGroup)
 
 logger = init_logger(__name__)
 
@@ -527,11 +530,12 @@ class TokenizerPoolConfig:
         pool type.
     """
     pool_size: int
-    pool_type: str
+    pool_type: Union[str, Type["BaseTokenizerGroup"]]
    extra_config: dict
 
     def __post_init__(self):
-        if self.pool_type not in ("ray", ):
+        if self.pool_type not in ("ray", ) and not isinstance(
+                self.pool_type, type):
             raise ValueError(f"Unknown pool type: {self.pool_type}")
         if not isinstance(self.extra_config, dict):
             raise ValueError("extra_config must be a dictionary.")
@@ -661,7 +665,8 @@ class ParallelConfig:
         tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
         ray_workers_use_nsight: bool = False,
         placement_group: Optional["PlacementGroup"] = None,
-        distributed_executor_backend: Optional[str] = None,
+        distributed_executor_backend: Optional[Union[
+            str, Type["ExecutorBase"]]] = None,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
         self.tensor_parallel_size = tensor_parallel_size
@@ -676,7 +681,7 @@ class ParallelConfig:
         if worker_use_ray:
             if self.distributed_executor_backend is None:
                 self.distributed_executor_backend = "ray"
-            elif self.distributed_executor_backend != "ray":
+            elif not self.use_ray:
                 raise ValueError(f"worker-use-ray can't be used with "
                                  f"distributed executor backend "
                                  f"'{self.distributed_executor_backend}'.")
@@ -711,12 +716,25 @@ class ParallelConfig:
         self._verify_args()
         self.rank = 0
 
+    @property
+    def use_ray(self) -> bool:
+        return self.distributed_executor_backend == "ray" or (
+            isinstance(self.distributed_executor_backend, type)
+            and self.distributed_executor_backend.uses_ray)
+
     def _verify_args(self) -> None:
-        if self.distributed_executor_backend not in ("ray", "mp", None):
+        # Lazy import to avoid circular import
+        from vllm.executor.executor_base import ExecutorBase
+
+        if self.distributed_executor_backend not in (
+                "ray", "mp", None) and not (isinstance(
+                    self.distributed_executor_backend, type) and issubclass(
+                        self.distributed_executor_backend, ExecutorBase)):
             raise ValueError(
-                "Unrecognized distributed executor backend. Supported values "
-                "are 'ray' or 'mp'.")
-        if self.distributed_executor_backend == "ray":
+                "Unrecognized distributed executor backend "
+                f"{self.distributed_executor_backend}. Supported "
+                "values are 'ray', 'mp' or custom ExecutorBase subclass.")
+        if self.use_ray:
             from vllm.executor import ray_utils
             ray_utils.assert_ray_available()
         if is_hip():
@@ -724,8 +742,7 @@ class ParallelConfig:
             logger.info(
                 "Disabled the custom all-reduce kernel because it is not "
                 "supported on AMD GPUs.")
-        if self.ray_workers_use_nsight and (
-                not self.distributed_executor_backend == "ray"):
+        if self.ray_workers_use_nsight and not self.use_ray:
             raise ValueError("Unable to use nsight profiling unless workers "
                              "run with Ray.")
@@ -2,16 +2,21 @@ import argparse
 import dataclasses
 import json
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
 
 from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
                          EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
                          MultiModalConfig, ObservabilityConfig, ParallelConfig,
                          PromptAdapterConfig, SchedulerConfig,
                          SpeculativeConfig, TokenizerPoolConfig)
+from vllm.executor.executor_base import ExecutorBase
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.utils import FlexibleArgumentParser
 
+if TYPE_CHECKING:
+    from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
+        BaseTokenizerGroup)
+
 
 def nullable_str(val: str):
     if not val or val == "None":
@@ -36,7 +41,11 @@ class EngineArgs:
     seed: int = 0
     max_model_len: Optional[int] = None
     worker_use_ray: bool = False
-    distributed_executor_backend: Optional[str] = None
+    # Note: Specifying a custom executor backend by passing a class
+    # is intended for expert use only. The API may change without
+    # notice.
+    distributed_executor_backend: Optional[Union[str,
+                                                 Type[ExecutorBase]]] = None
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
     max_parallel_loading_workers: Optional[int] = None
@@ -62,7 +71,10 @@ class EngineArgs:
     max_seq_len_to_capture: int = 8192
     disable_custom_all_reduce: bool = False
     tokenizer_pool_size: int = 0
-    tokenizer_pool_type: str = "ray"
+    # Note: Specifying a tokenizer pool by passing a class
+    # is intended for expert use only. The API may change without
+    # notice.
+    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
     tokenizer_pool_extra_config: Optional[dict] = None
     enable_lora: bool = False
     max_loras: int = 1
@@ -7,12 +7,13 @@ from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
 from transformers import PreTrainedTokenizer
 
 import vllm.envs as envs
-from vllm.config import DecodingConfig, ModelConfig
+from vllm.config import DecodingConfig, EngineConfig, ModelConfig
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_timeout import asyncio_timeout
 from vllm.engine.llm_engine import LLMEngine
 from vllm.engine.metrics import StatLoggerBase
+from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.ray_utils import initialize_ray_cluster, ray
 from vllm.inputs import LLMInputs, PromptInputs
 from vllm.logger import init_logger
@@ -385,25 +386,19 @@ class AsyncLLMEngine:
         self._request_tracker: RequestTracker
 
     @classmethod
-    def from_engine_args(
-        cls,
-        engine_args: AsyncEngineArgs,
-        start_engine_loop: bool = True,
-        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
-        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
-    ) -> "AsyncLLMEngine":
-        """Creates an async LLM engine from the engine arguments."""
-        # Create the engine configs.
-        engine_config = engine_args.create_engine_config()
-
-        if engine_args.engine_use_ray:
-            from vllm.executor import ray_utils
-            ray_utils.assert_ray_available()
-
+    def _get_executor_cls(
+            cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
         distributed_executor_backend = (
             engine_config.parallel_config.distributed_executor_backend)
-
-        if engine_config.device_config.device_type == "neuron":
+        if isinstance(distributed_executor_backend, type):
+            if not issubclass(distributed_executor_backend, ExecutorAsyncBase):
+                raise TypeError(
+                    "distributed_executor_backend must be a subclass of "
+                    f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
+            if distributed_executor_backend.uses_ray:  # type: ignore
+                initialize_ray_cluster(engine_config.parallel_config)
+            executor_class = distributed_executor_backend
+        elif engine_config.device_config.device_type == "neuron":
             from vllm.executor.neuron_executor import NeuronExecutorAsync
             executor_class = NeuronExecutorAsync
         elif engine_config.device_config.device_type == "tpu":
@@ -442,9 +437,29 @@ class AsyncLLMEngine:
         else:
             from vllm.executor.gpu_executor import GPUExecutorAsync
             executor_class = GPUExecutorAsync
+        return executor_class
+
+    @classmethod
+    def from_engine_args(
+        cls,
+        engine_args: AsyncEngineArgs,
+        start_engine_loop: bool = True,
+        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
+        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
+    ) -> "AsyncLLMEngine":
+        """Creates an async LLM engine from the engine arguments."""
+        # Create the engine configs.
+        engine_config = engine_args.create_engine_config()
+
+        if engine_args.engine_use_ray:
+            from vllm.executor import ray_utils
+            ray_utils.assert_ray_available()
+
+        executor_class = cls._get_executor_cls(engine_config)
+
         # Create the async LLM engine.
         engine = cls(
-            distributed_executor_backend == "ray",
+            executor_class.uses_ray,
             engine_args.engine_use_ray,
             **engine_config.to_dict(),
             executor_class=executor_class,
@@ -7,9 +7,9 @@ from typing import Set, Type, TypeVar, Union
 from transformers import PreTrainedTokenizer
 
 import vllm.envs as envs
-from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
-                         LoRAConfig, ModelConfig, MultiModalConfig,
-                         ObservabilityConfig, ParallelConfig,
+from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
+                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
+                         MultiModalConfig, ObservabilityConfig, ParallelConfig,
                          PromptAdapterConfig, SchedulerConfig,
                          SpeculativeConfig)
 from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
@@ -376,19 +376,20 @@ class LLMEngine:
         self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
 
     @classmethod
-    def from_engine_args(
-        cls,
-        engine_args: EngineArgs,
-        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
-        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
-    ) -> "LLMEngine":
-        """Creates an LLM engine from the engine arguments."""
-        # Create the engine configs.
-        engine_config = engine_args.create_engine_config()
+    def _get_executor_cls(cls,
+                          engine_config: EngineConfig) -> Type[ExecutorBase]:
         distributed_executor_backend = (
             engine_config.parallel_config.distributed_executor_backend)
         # Initialize the cluster and specify the executor class.
-        if engine_config.device_config.device_type == "neuron":
+        if isinstance(distributed_executor_backend, type):
+            if not issubclass(distributed_executor_backend, ExecutorBase):
+                raise TypeError(
+                    "distributed_executor_backend must be a subclass of "
+                    f"ExecutorBase. Got {distributed_executor_backend}.")
+            if distributed_executor_backend.uses_ray:  # type: ignore
+                initialize_ray_cluster(engine_config.parallel_config)
+            executor_class = distributed_executor_backend
+        elif engine_config.device_config.device_type == "neuron":
             from vllm.executor.neuron_executor import NeuronExecutor
             executor_class = NeuronExecutor
         elif engine_config.device_config.device_type == "tpu":
@@ -422,6 +423,19 @@ class LLMEngine:
         else:
             from vllm.executor.gpu_executor import GPUExecutor
             executor_class = GPUExecutor
+        return executor_class
+
+    @classmethod
+    def from_engine_args(
+        cls,
+        engine_args: EngineArgs,
+        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
+        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
+    ) -> "LLMEngine":
+        """Creates an LLM engine from the engine arguments."""
+        # Create the engine configs.
+        engine_config = engine_args.create_engine_config()
+        executor_class = cls._get_executor_cls(engine_config)
         # Create the LLM engine.
         engine = cls(
             **engine_config.to_dict(),
@@ -17,6 +17,8 @@ logger = init_logger(__name__)
 
 class CPUExecutor(ExecutorBase):
 
+    uses_ray: bool = False
+
     def _init_executor(self) -> None:
         assert self.device_config.device_type == "cpu"
         assert self.lora_config is None, "cpu backend doesn't support LoRA"
@@ -18,6 +18,8 @@ class ExecutorBase(ABC):
     that can execute the model on multiple devices.
     """
 
+    uses_ray: bool  # whether the executor uses Ray for orchestration.
+
     def __init__(
         self,
         model_config: ModelConfig,
@@ -23,6 +23,8 @@ def create_worker(worker_module_name, worker_class_name, **kwargs):
 
 class GPUExecutor(ExecutorBase):
 
+    uses_ray: bool = False
+
     def _init_executor(self) -> None:
         """Initialize the worker and load the model.
         """
@@ -25,6 +25,8 @@ logger = init_logger(__name__)
 class MultiprocessingGPUExecutor(DistributedGPUExecutor):
     """Python multiprocessing-based multi-GPU executor"""
 
+    uses_ray: bool = False
+
     def _init_executor(self) -> None:
         # Create the parallel GPU workers.
         world_size = self.parallel_config.world_size
@@ -11,6 +11,8 @@ logger = init_logger(__name__)
 
 class NeuronExecutor(ExecutorBase):
 
+    uses_ray: bool = False
+
     def _init_executor(self) -> None:
         assert (self.lora_config is
                 None), "LoRA is not supported for Neuron backend."
@@ -18,6 +18,8 @@ logger = init_logger(__name__)
 
 class OpenVINOExecutor(ExecutorBase):
 
+    uses_ray: bool = False
+
     def _init_executor(self) -> None:
         assert self.device_config.device_type == "openvino"
         assert self.lora_config is None, "OpenVINO backend doesn't support LoRA"
@@ -26,6 +26,8 @@ logger = init_logger(__name__)
 
 class RayGPUExecutor(DistributedGPUExecutor):
 
+    uses_ray: bool = True
+
     def _init_executor(self) -> None:
         # If the env var is set, it uses the Ray's compiled DAG API
         # which optimizes the control plane overhead.
@@ -47,7 +49,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
                 "VLLM_USE_RAY_SPMD_WORKER=1 requires "
                 "VLLM_USE_RAY_COMPILED_DAG=1")
 
-        assert self.parallel_config.distributed_executor_backend == "ray"
+        assert self.uses_ray
         placement_group = self.parallel_config.placement_group
 
         # Disable Ray usage stats collection.
@@ -75,6 +77,20 @@ class RayGPUExecutor(DistributedGPUExecutor):
 
         return ray_remote_kwargs
 
+    def _get_worker_wrapper_args(self) -> Dict[str, Any]:
+        if self.speculative_config is not None:
+            worker_module_name = "vllm.spec_decode.spec_decode_worker"
+            worker_class_name = "create_spec_worker"
+        else:
+            worker_module_name = "vllm.worker.worker"
+            worker_class_name = "Worker"
+
+        return dict(
+            worker_module_name=worker_module_name,
+            worker_class_name=worker_class_name,
+            trust_remote_code=self.model_config.trust_remote_code,
+        )
+
     def _init_workers_ray(self, placement_group: "PlacementGroup",
                           **ray_remote_kwargs):
         if (self.parallel_config.tensor_parallel_size == 1
@@ -97,6 +113,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
 
         # Create the workers.
         driver_ip = get_ip()
+        worker_wrapper_kwargs = self._get_worker_wrapper_args()
         for bundle_id, bundle in enumerate(placement_group.bundle_specs):
             if not bundle.get("GPU", 0):
                 continue
@@ -106,23 +123,12 @@ class RayGPUExecutor(DistributedGPUExecutor):
                 placement_group_bundle_index=bundle_id,
             )
 
-            if self.speculative_config is not None:
-                worker_module_name = "vllm.spec_decode.spec_decode_worker"
-                worker_class_name = "create_spec_worker"
-            else:
-                worker_module_name = "vllm.worker.worker"
-                worker_class_name = "Worker"
-
             worker = ray.remote(
                 num_cpus=0,
                 num_gpus=num_gpus,
                 scheduling_strategy=scheduling_strategy,
                 **ray_remote_kwargs,
-            )(RayWorkerWrapper).remote(
-                worker_module_name=worker_module_name,
-                worker_class_name=worker_class_name,
-                trust_remote_code=self.model_config.trust_remote_code,
-            )
+            )(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
 
             if self.use_ray_spmd_worker:
                 self.workers.append(worker)
@@ -133,10 +139,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
                 # as the resource holder for the driver process.
                 self.driver_dummy_worker = worker
                 self.driver_worker = RayWorkerWrapper(
-                    worker_module_name=worker_module_name,
-                    worker_class_name=worker_class_name,
-                    trust_remote_code=self.model_config.trust_remote_code,
-                )
+                    **worker_wrapper_kwargs)
             else:
                 # Else, added to the list of workers.
                 self.workers.append(worker)
@@ -378,7 +381,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
                 f"required, but found {current_version}")
 
         from ray.dag import InputNode, MultiOutputNode
-        assert self.parallel_config.distributed_executor_backend == "ray"
+        assert self.parallel_config.use_ray
 
         # Right now, compiled DAG requires at least 1 arg. We send
         # a dummy value for now. It will be fixed soon.
@@ -35,6 +35,8 @@ USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
 
 class RayXPUExecutor(DistributedGPUExecutor):
 
+    uses_ray: bool = True
+
     def __init__(
         self,
         model_config: ModelConfig,
@@ -107,6 +109,13 @@ class RayXPUExecutor(DistributedGPUExecutor):
 
         return num_gpu_blocks, num_cpu_blocks
 
+    def _get_worker_wrapper_args(self) -> Dict[str, Any]:
+        return dict(
+            worker_module_name="vllm.worker.xpu_worker",
+            worker_class_name="XPUWorker",
+            trust_remote_code=self.model_config.trust_remote_code,
+        )
+
     def _init_workers_ray(self, placement_group: "PlacementGroup",
                           **ray_remote_kwargs):
         if self.parallel_config.tensor_parallel_size == 1:
@@ -124,6 +133,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
 
         # Create the workers.
         driver_ip = get_ip()
+        worker_wrapper_kwargs = self._get_worker_wrapper_args()
         for bundle_id, bundle in enumerate(placement_group.bundle_specs):
             if not bundle.get("GPU", 0):
                 continue
@@ -137,22 +147,14 @@ class RayXPUExecutor(DistributedGPUExecutor):
                 num_gpus=num_gpus,
                 scheduling_strategy=scheduling_strategy,
                 **ray_remote_kwargs,
-            )(RayWorkerWrapper).remote(
-                worker_module_name="vllm.worker.xpu_worker",
-                worker_class_name="XPUWorker",
-                trust_remote_code=self.model_config.trust_remote_code,
-            )
+            )(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
 
             worker_ip = ray.get(worker.get_node_ip.remote())
             if worker_ip == driver_ip and self.driver_dummy_worker is None:
                 # If the worker is on the same node as the driver, we use it
                 # as the resource holder for the driver process.
                 self.driver_dummy_worker = worker
-                self.driver_worker = RayWorkerWrapper(
-                    worker_module_name="vllm.worker.xpu_worker",
-                    worker_class_name="XPUWorker",
-                    trust_remote_code=self.model_config.trust_remote_code,
-                )
+                self.driver_worker = RayWorkerWrapper(**worker_wrapper_kwargs)
             else:
                 # Else, added to the list of workers.
                 self.workers.append(worker)
@@ -337,7 +339,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
                 f"required, but found {current_version}")
 
         from ray.dag import InputNode, MultiOutputNode
-        assert self.parallel_config.distributed_executor_backend == "ray"
+        assert self.parallel_config.use_ray
 
         # Right now, compiled DAG requires at least 1 arg. We send
         # a dummy value for now. It will be fixed soon.
@@ -14,6 +14,8 @@ logger = init_logger(__name__)
 
 class TPUExecutor(ExecutorBase):
 
+    uses_ray: bool = False
+
     def _init_executor(self) -> None:
         assert not self.scheduler_config.chunked_prefill_enabled, (
             "Chunked prefill is not yet supported for TPU backend")
@@ -18,6 +18,8 @@ logger = init_logger(__name__)
 
 class XPUExecutor(GPUExecutor):
 
+    uses_ray: bool = False
+
     def __init__(
         self,
         model_config: ModelConfig,
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Type
 
 from vllm.config import TokenizerPoolConfig
 from vllm.executor.ray_utils import ray
@@ -16,18 +16,22 @@ else:
 
 def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
                         **init_kwargs) -> BaseTokenizerGroup:
+    tokenizer_cls: Type[BaseTokenizerGroup]
     if tokenizer_pool_config is None:
-        return TokenizerGroup(**init_kwargs)
-    if tokenizer_pool_config.pool_type == "ray":
+        tokenizer_cls = TokenizerGroup
+    elif isinstance(tokenizer_pool_config.pool_type, type) and issubclass(
+            tokenizer_pool_config.pool_type, BaseTokenizerGroup):
+        tokenizer_cls = tokenizer_pool_config.pool_type
+    elif tokenizer_pool_config.pool_type == "ray":
         if RayTokenizerGroupPool is None:
             raise ImportError(
                 "RayTokenizerGroupPool is not available. Please install "
                 "the ray package to use the Ray tokenizer group pool.")
-        return RayTokenizerGroupPool.from_config(tokenizer_pool_config,
-                                                 **init_kwargs)
+        tokenizer_cls = RayTokenizerGroupPool
     else:
         raise ValueError(
            f"Unknown pool type: {tokenizer_pool_config.pool_type}")
+    return tokenizer_cls.from_config(tokenizer_pool_config, **init_kwargs)
 
 
 __all__ = ["get_tokenizer_group", "BaseTokenizerGroup"]
@@ -3,12 +3,19 @@ from typing import List, Optional
 
 from transformers import PreTrainedTokenizer
 
+from vllm.config import TokenizerPoolConfig
 from vllm.lora.request import LoRARequest
 
 
 class BaseTokenizerGroup(ABC):
     """A group of tokenizers that can be used for LoRA adapters."""
 
+    @classmethod
+    @abstractmethod
+    def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
+                    **init_kwargs) -> "BaseTokenizerGroup":
+        pass
+
     @abstractmethod
     def ping(self) -> bool:
         """Check if the tokenizer group is alive."""
@@ -29,8 +29,10 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
     _worker_cls = TokenizerGroup
 
     @classmethod
-    def from_config(cls, tokenizer_pool_config: TokenizerPoolConfig,
+    def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
                     **init_kwargs) -> "RayTokenizerGroupPool":
+        if not tokenizer_pool_config:
+            raise ValueError("tokenizer_pool_config must not be None.")
         ray_actor_options = (tokenizer_pool_config.extra_config or {
             "num_cpus": 0
         })
@@ -2,6 +2,7 @@ from typing import List, Optional
 
 from transformers import PreTrainedTokenizer
 
+from vllm.config import TokenizerPoolConfig
 from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.tokenizer import (get_lora_tokenizer,
                                                get_lora_tokenizer_async,
@@ -24,6 +25,11 @@ class TokenizerGroup(BaseTokenizerGroup):
         self.lora_tokenizers = LRUCache[PreTrainedTokenizer](
             capacity=max_num_seqs) if enable_lora else None
 
+    @classmethod
+    def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
+                    **init_kwargs) -> "TokenizerGroup":
+        return cls(**init_kwargs)
+
     def ping(self) -> bool:
         """Check if the tokenizer group is alive."""
         return True
@@ -2,7 +2,7 @@ import dataclasses
 import importlib
 import os
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
 
 import torch
 
@@ -315,14 +315,23 @@ class WorkerWrapperBase:
     We first instantiate the WorkerWrapper, which remembers the worker module
     and class name. Then, when we call `update_environment_variables`, and the
     real initialization happens in `init_worker`.
+
+    If worker_class_fn is specified, it will be executed to get the worker
+    class.
+    Otherwise, the worker class will be obtained by dynamically importing it
+    using worker_module_name and worker_class_name.
     """
 
-    def __init__(self,
-                 worker_module_name: str,
-                 worker_class_name: str,
-                 trust_remote_code: bool = False) -> None:
+    def __init__(
+            self,
+            worker_module_name: str,
+            worker_class_name: str,
+            trust_remote_code: bool = False,
+            worker_class_fn: Optional[Callable[[],
+                                               Type[WorkerBase]]] = None) -> None:
         self.worker_module_name = worker_module_name
         self.worker_class_name = worker_class_name
+        self.worker_class_fn = worker_class_fn
         self.worker: Optional[WorkerBase] = None
         if trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
@@ -348,8 +357,11 @@ class WorkerWrapperBase:
         # see https://github.com/NVIDIA/nccl/issues/1234
         os.environ['NCCL_CUMEM_ENABLE'] = '0'
 
-        mod = importlib.import_module(self.worker_module_name)
-        worker_class = getattr(mod, self.worker_class_name)
+        if self.worker_class_fn:
+            worker_class = self.worker_class_fn()
+        else:
+            mod = importlib.import_module(self.worker_module_name)
+            worker_class = getattr(mod, self.worker_class_name)
+
         self.worker = worker_class(*args, **kwargs)
         assert self.worker is not None
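WorkerWrapperBase also gains an optional worker_class_fn callable; when provided, init_worker uses it to obtain the worker class instead of importing worker_module_name and looking up worker_class_name. A hedged sketch of the new parameter (the wrapper is normally constructed by the executors above; the helper function here is illustrative):

from typing import Type

from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase


def _load_worker_class() -> Type[WorkerBase]:
    # Deferred import so environment setup can happen first,
    # mirroring the lazy-import behaviour of init_worker().
    from vllm.worker.worker import Worker
    return Worker


wrapper = WorkerWrapperBase(worker_module_name="vllm.worker.worker",
                            worker_class_name="Worker",
                            worker_class_fn=_load_worker_class)
# init_worker() will now call _load_worker_class() instead of
# importing "vllm.worker.worker" and looking up "Worker".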