[platforms] absorb worker cls difference into platforms folder (#10555)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>

parent 446c7806b2
commit a111d0151f

vllm/config.py (236 lines changed)
@@ -926,56 +926,56 @@ class LoadConfig:
f"{rocm_supported_load_format}")

@dataclass
class ParallelConfig:
"""Configuration for the distributed execution.
"""Configuration for the distributed execution."""

Args:
pipeline_parallel_size: Number of pipeline parallel groups.
tensor_parallel_size: Number of tensor parallel groups.
worker_use_ray: Deprecated, use distributed_executor_backend instead.
max_parallel_loading_workers: Maximum number of multiple batches
when load model sequentially. To avoid RAM OOM when using tensor
parallel and large models.
disable_custom_all_reduce: Disable the custom all-reduce kernel and
fall back to NCCL.
tokenizer_pool_config: Config for the tokenizer pool.
If None, will use synchronous tokenization.
ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
placement_group: ray distributed model workers placement group.
distributed_executor_backend: Backend to use for distributed model
workers, either "ray" or "mp" (multiprocessing). If the product
of pipeline_parallel_size and tensor_parallel_size is less than
or equal to the number of GPUs available, "mp" will be used to
keep processing on a single host. Otherwise, this will default
to "ray" if Ray is installed and fail otherwise. Note that tpu
and hpu only support Ray for distributed inference.
"""
pipeline_parallel_size: int = 1 # Number of pipeline parallel groups.
tensor_parallel_size: int = 1 # Number of tensor parallel groups.

def __init__(
self,
pipeline_parallel_size: int,
tensor_parallel_size: int,
worker_use_ray: Optional[bool] = None,
max_parallel_loading_workers: Optional[int] = None,
disable_custom_all_reduce: bool = False,
tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
ray_workers_use_nsight: bool = False,
placement_group: Optional["PlacementGroup"] = None,
distributed_executor_backend: Optional[Union[
str, Type["ExecutorBase"]]] = None,
) -> None:
self.pipeline_parallel_size = pipeline_parallel_size
self.tensor_parallel_size = tensor_parallel_size
self.distributed_executor_backend = distributed_executor_backend
self.max_parallel_loading_workers = max_parallel_loading_workers
self.disable_custom_all_reduce = disable_custom_all_reduce
self.tokenizer_pool_config = tokenizer_pool_config
self.ray_workers_use_nsight = ray_workers_use_nsight
self.placement_group = placement_group
self.world_size = pipeline_parallel_size * self.tensor_parallel_size
# Deprecated, use distributed_executor_backend instead.
worker_use_ray: Optional[bool] = None

if worker_use_ray:
# Maximum number of multiple batches
# when load model sequentially. To avoid RAM OOM when using tensor
# parallel and large models.
max_parallel_loading_workers: Optional[int] = None

# Disable the custom all-reduce kernel and fall back to NCCL.
disable_custom_all_reduce: bool = False

# Config for the tokenizer pool. If None, will use synchronous tokenization.
tokenizer_pool_config: Optional[TokenizerPoolConfig] = None

# Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
ray_workers_use_nsight: bool = False

# ray distributed model workers placement group.
placement_group: Optional["PlacementGroup"] = None

# Backend to use for distributed model
# workers, either "ray" or "mp" (multiprocessing). If the product
# of pipeline_parallel_size and tensor_parallel_size is less than
# or equal to the number of GPUs available, "mp" will be used to
# keep processing on a single host. Otherwise, this will default
# to "ray" if Ray is installed and fail otherwise. Note that tpu
# and hpu only support Ray for distributed inference.
distributed_executor_backend: Optional[Union[str,
Type["ExecutorBase"]]] = None

# the full name of the worker class to use. If "auto", the worker class
# will be determined based on the platform.
worker_cls: str = "auto"

world_size: int = field(init=False)

rank: int = 0

def __post_init__(self) -> None:
self.world_size = self.pipeline_parallel_size * \
self.tensor_parallel_size

if self.worker_use_ray:
if self.distributed_executor_backend is None:
self.distributed_executor_backend = "ray"
elif not self.use_ray:

@@ -1026,7 +1026,6 @@ class ParallelConfig:
backend)

self._verify_args()
self.rank: int = 0

@property
def use_ray(self) -> bool:

@@ -1059,100 +1058,97 @@ class ParallelConfig:
"run with Ray.")
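ParallelConfig is now a plain dataclass: world_size is derived in __post_init__ and the platform-dependent worker class stays "auto" until a platform's check_and_update_config fills it in. A minimal sketch of the resulting behaviour (assuming the post-commit vllm.config is importable; the custom class path is hypothetical):

```python
from vllm.config import ParallelConfig

# Default construction: a single worker, with the worker class left as
# "auto" for the current platform to resolve in check_and_update_config().
pc = ParallelConfig()
assert pc.world_size == pc.pipeline_parallel_size * pc.tensor_parallel_size == 1
assert pc.worker_cls == "auto"

# An explicit worker class (a hypothetical dotted path) skips the
# platform-based "auto" resolution later on.
pc = ParallelConfig(worker_cls="my_pkg.my_worker.MyWorker")
```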
@dataclass
class SchedulerConfig:
"""Scheduler configuration.
"""Scheduler configuration."""

Args:
task: The task to use the model for.
max_num_batched_tokens: Maximum number of tokens to be processed in
a single iteration.
max_num_seqs: Maximum number of sequences to be processed in a single
iteration.
max_model_len: Maximum length of a sequence (including prompt
and generated text).
num_lookahead_slots: The number of slots to allocate per sequence per
step, beyond the known token ids. This is used in speculative
decoding to store KV activations of tokens which may or may not be
accepted.
delay_factor: Apply a delay (of delay factor multiplied by previous
prompt latency) before scheduling next prompt.
enable_chunked_prefill: If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens.
preemption_mode: Whether to perform preemption by swapping or
recomputation. If not specified, we determine the mode as follows:
We use recomputation by default since it incurs lower overhead than
swapping. However, when the sequence group has multiple sequences
(e.g., beam search), recomputation is not currently supported. In
such a case, we use swapping instead.
send_delta_data: Private API. If used, scheduler sends delta data to
workers instead of an entire data. It should be enabled only
when SPMD worker architecture is enabled. I.e.,
VLLM_USE_RAY_SPMD_WORKER=1
policy: The scheduling policy to use. "fcfs" (default) or "priority".
"""
task: str = "generate" # The task to use the model for.

def __init__(self,
task: _Task,
max_num_batched_tokens: Optional[int],
max_num_seqs: int,
max_model_len: int,
num_lookahead_slots: int = 0,
delay_factor: float = 0.0,
enable_chunked_prefill: bool = False,
is_multimodal_model: bool = False,
preemption_mode: Optional[str] = None,
num_scheduler_steps: int = 1,
multi_step_stream_outputs: bool = False,
send_delta_data: bool = False,
policy: str = "fcfs") -> None:
if max_num_batched_tokens is None:
if enable_chunked_prefill:
if num_scheduler_steps > 1:
# Maximum number of tokens to be processed in a single iteration.
max_num_batched_tokens: int = field(default=None) # type: ignore

# Maximum number of sequences to be processed in a single iteration.
max_num_seqs: int = 128

# Maximum length of a sequence (including prompt and generated text).
max_model_len: int = 8192

# The number of slots to allocate per sequence per
# step, beyond the known token ids. This is used in speculative
# decoding to store KV activations of tokens which may or may not be
# accepted.
num_lookahead_slots: int = 0

# Apply a delay (of delay factor multiplied by previous
# prompt latency) before scheduling next prompt.
delay_factor: float = 0.0

# If True, prefill requests can be chunked based
# on the remaining max_num_batched_tokens.
enable_chunked_prefill: bool = False

is_multimodal_model: bool = False

# Whether to perform preemption by swapping or
# recomputation. If not specified, we determine the mode as follows:
# We use recomputation by default since it incurs lower overhead than
# swapping. However, when the sequence group has multiple sequences
# (e.g., beam search), recomputation is not currently supported. In
# such a case, we use swapping instead.
preemption_mode: Optional[str] = None

num_scheduler_steps: int = 1

multi_step_stream_outputs: bool = False

# Private API. If used, scheduler sends delta data to
# workers instead of an entire data. It should be enabled only
# when SPMD worker architecture is enabled. I.e.,
# VLLM_USE_RAY_SPMD_WORKER=1
send_delta_data: bool = False

# The scheduling policy to use. "fcfs" (default) or "priority".
policy: str = "fcfs"

chunked_prefill_enabled: bool = field(init=False)

def __post_init__(self) -> None:
if self.max_num_batched_tokens is None:
if self.enable_chunked_prefill:
if self.num_scheduler_steps > 1:
# Multi-step Chunked-Prefill doesn't allow prompt-chunking
# for now. Have max_num_batched_tokens set to max_model_len
# so we don't reject sequences on account of a short
# max_num_batched_tokens.
max_num_batched_tokens = max(max_model_len, 2048)
self.max_num_batched_tokens = max(self.max_model_len, 2048)
else:
# It is the values that have the best balance between ITL
# and TTFT on A100. Note it is not optimized for throughput.
max_num_batched_tokens = 512
self.max_num_batched_tokens = 512
else:
# If max_model_len is too short, use 2048 as the default value
# for higher throughput.
max_num_batched_tokens = max(max_model_len, 2048)
self.max_num_batched_tokens = max(self.max_model_len, 2048)

if task == "embedding":
if self.task == "embedding":
# For embedding, choose specific value for higher throughput
max_num_batched_tokens = max(
max_num_batched_tokens,
self.max_num_batched_tokens = max(
self.max_num_batched_tokens,
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS,
)
if is_multimodal_model:
if self.is_multimodal_model:
# The value needs to be at least the number of multimodal tokens
max_num_batched_tokens = max(
max_num_batched_tokens,
self.max_num_batched_tokens = max(
self.max_num_batched_tokens,
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
)

self.max_num_batched_tokens = max_num_batched_tokens

if enable_chunked_prefill:
if self.enable_chunked_prefill:
logger.info(
"Chunked prefill is enabled with max_num_batched_tokens=%d.",
self.max_num_batched_tokens)

self.task: Final = task
self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len
self.num_lookahead_slots = num_lookahead_slots
self.delay_factor = delay_factor
self.chunked_prefill_enabled = enable_chunked_prefill
self.preemption_mode = preemption_mode
self.num_scheduler_steps = num_scheduler_steps
self.multi_step_stream_outputs = multi_step_stream_outputs
self.send_delta_data = send_delta_data
self.policy = policy
self.chunked_prefill_enabled = self.enable_chunked_prefill
self._verify_args()

def _verify_args(self) -> None:
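SchedulerConfig gets the same treatment: the defaulting of max_num_batched_tokens moves from __init__ into __post_init__. A small sketch of what that defaulting does (assuming the post-commit vllm.config is importable):

```python
from vllm.config import SchedulerConfig

# Chunked prefill off: the budget defaults to max(max_model_len, 2048).
sc = SchedulerConfig()
assert sc.max_num_batched_tokens == max(sc.max_model_len, 2048)

# Chunked prefill on (single scheduler step): the 512-token default that the
# comment above describes as the ITL/TTFT balance point on A100.
sc = SchedulerConfig(enable_chunked_prefill=True)
assert sc.max_num_batched_tokens == 512
assert sc.chunked_prefill_enabled
```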
@@ -2293,10 +2289,10 @@ class VllmConfig:

model_config: ModelConfig = field(default=None, init=True) # type: ignore
cache_config: CacheConfig = field(default=None, init=True) # type: ignore
parallel_config: ParallelConfig = field(default=None,
init=True) # type: ignore
scheduler_config: SchedulerConfig = field(default=None,
init=True) # type: ignore
parallel_config: ParallelConfig = field(default_factory=ParallelConfig,
init=True)
scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig,
init=True)
device_config: DeviceConfig = field(default=None,
init=True) # type: ignore
load_config: LoadConfig = field(default=None, init=True) # type: ignore
@@ -191,6 +191,7 @@ class EngineArgs:
override_neuron_config: Optional[Dict[str, Any]] = None
override_pooler_config: Optional[PoolerConfig] = None
compilation_config: Optional[CompilationConfig] = None
worker_cls: str = "auto"

def __post_init__(self):
if not self.tokenizer:

@@ -887,6 +888,12 @@ class EngineArgs:
'compilers, using -O without space is also '
'supported. -O3 is equivalent to -O 3.')

parser.add_argument(
'--worker-cls',
type=str,
default="auto",
help='The worker class to use for distributed execution.')

return parser

@classmethod

@@ -999,7 +1006,9 @@ class EngineArgs:
self.tokenizer_pool_extra_config,
),
ray_workers_use_nsight=self.ray_workers_use_nsight,
distributed_executor_backend=self.distributed_executor_backend)
distributed_executor_backend=self.distributed_executor_backend,
worker_cls=self.worker_cls,
)

max_model_len = model_config.max_model_len
use_long_context = max_model_len > 32768
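The new --worker-cls flag lands in EngineArgs.worker_cls and is forwarded to ParallelConfig in create_engine_config. A brief sketch of the dataclass side (the model name and the custom class path are placeholders):

```python
from vllm.engine.arg_utils import EngineArgs

# Default: leave worker selection to the current platform.
engine_args = EngineArgs(model="facebook/opt-125m")
assert engine_args.worker_cls == "auto"

# Override: hand vllm a fully qualified worker class (hypothetical path);
# create_engine_config passes it through to ParallelConfig unchanged.
engine_args = EngineArgs(model="facebook/opt-125m",
                         worker_cls="my_pkg.my_worker.MyWorker")
```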
@@ -115,13 +115,8 @@ class CPUExecutor(ExecutorBase):
local_rank: int = 0,
rank: int = 0,
):
worker_module_name = "vllm.worker.cpu_worker"
worker_class_name = "CPUWorker"

wrapper = WorkerWrapperBase(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
)
wrapper = WorkerWrapperBase(vllm_config=self.vllm_config)

assert self.distributed_init_method is not None
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger

@@ -8,19 +8,14 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)


def create_worker(worker_module_name: str, worker_class_name: str,
worker_class_fn: Optional[Callable[[], Type[WorkerBase]]],
**kwargs):
wrapper = WorkerWrapperBase(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
worker_class_fn=worker_class_fn,
)
def create_worker(**kwargs):
vllm_config = kwargs.get("vllm_config")
wrapper = WorkerWrapperBase(vllm_config=vllm_config)
wrapper.init_worker(**kwargs)
return wrapper.worker

@@ -57,43 +52,11 @@ class GPUExecutor(ExecutorBase):
or (rank % self.parallel_config.tensor_parallel_size == 0),
)

def _get_worker_module_and_class(
self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
worker_class_fn = None
if self.scheduler_config.is_multi_step:
worker_module_name = "vllm.worker.multi_step_worker"
worker_class_name = "MultiStepWorker"
elif self.speculative_config:
worker_module_name = "vllm.spec_decode.spec_decode_worker"
worker_class_name = "create_spec_worker"
else:
worker_module_name = "vllm.worker.worker"
worker_class_name = "Worker"
return (worker_module_name, worker_class_name, worker_class_fn)

def _get_create_worker_kwargs(
self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None) -> Dict:
worker_kwargs = self._get_worker_kwargs(local_rank, rank,
distributed_init_method)

(worker_module_name, worker_class_name,
worker_class_fn) = self._get_worker_module_and_class()
worker_kwargs.update(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
worker_class_fn=worker_class_fn,
)

return worker_kwargs

def _create_worker(self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None):
return create_worker(**self._get_create_worker_kwargs(
return create_worker(**self._get_worker_kwargs(
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method))
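With the per-executor worker tables gone, every executor now builds its worker the same way: wrap the whole VllmConfig and let the wrapper import whatever parallel_config.worker_cls names. A condensed sketch of that pattern (the kwargs mirror the _get_worker_kwargs fields used above):

```python
from vllm.worker.worker_base import WorkerWrapperBase


def make_worker(vllm_config, local_rank=0, rank=0, distributed_init_method=None):
    # One code path for every device: by this point the platform has already
    # rewritten parallel_config.worker_cls from "auto" to a concrete path.
    wrapper = WorkerWrapperBase(vllm_config=vllm_config)
    wrapper.init_worker(vllm_config=vllm_config,
                        local_rank=local_rank,
                        rank=rank,
                        distributed_init_method=distributed_init_method)
    return wrapper.worker
```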
@@ -48,10 +48,7 @@ class HPUExecutor(ExecutorBase):
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None):
wrapper = WorkerWrapperBase(
worker_module_name="vllm.worker.hpu_worker",
worker_class_name="HPUWorker",
)
wrapper = WorkerWrapperBase(vllm_config=self.vllm_config)
wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
distributed_init_method))
return wrapper.worker
@@ -90,7 +90,7 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
result_handler,
partial(
create_worker,
**self._get_create_worker_kwargs(
**self._get_worker_kwargs(
rank=rank,
local_rank=rank,
distributed_init_method=distributed_init_method,
@@ -7,6 +7,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)

@@ -25,10 +26,10 @@ class NeuronExecutor(ExecutorBase):
self._init_worker()

def _init_worker(self):
from vllm.worker.neuron_worker import NeuronWorker
wrapper = WorkerWrapperBase(vllm_config=self.vllm_config)
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
self.driver_worker = NeuronWorker(
self.driver_worker = wrapper.init_worker(
vllm_config=self.vllm_config,
local_rank=0,
rank=0,
@@ -14,6 +14,7 @@ from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip,
get_open_port, make_async)
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)

@@ -38,15 +39,12 @@ class OpenVINOExecutor(ExecutorBase):
self._init_worker()

def _init_worker(self):
from vllm.worker.openvino_worker import OpenVINOWorker

assert (
self.parallel_config.world_size == 1
), "OpenVINOExecutor only supports single CPU socket currently."
wrapper = WorkerWrapperBase(vllm_config=self.vllm_config)

distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
self.driver_worker = OpenVINOWorker(
self.driver_worker = wrapper.init_worker(
ov_core=self.ov_core,
vllm_config=self.vllm_config,
local_rank=0,
@@ -91,17 +91,6 @@ class RayGPUExecutor(DistributedGPUExecutor):

return ray_remote_kwargs

def _get_worker_wrapper_args(self) -> Dict[str, Any]:
(worker_module_name, worker_class_name,
worker_class_fn) = self._get_worker_module_and_class()

return dict(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
worker_class_fn=worker_class_fn,
trust_remote_code=self.model_config.trust_remote_code,
)

# child class could overwrite this to return actual env vars.
def _get_env_vars_to_be_updated(self):
return self._env_vars_for_all_workers

@@ -135,7 +124,6 @@ class RayGPUExecutor(DistributedGPUExecutor):

# Create the workers.
driver_ip = get_ip()
worker_wrapper_kwargs = self._get_worker_wrapper_args()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("GPU", 0):
continue

@@ -150,7 +138,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config)

if self.use_ray_spmd_worker:
self.workers.append(worker)

@@ -161,7 +149,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
**worker_wrapper_kwargs)
vllm_config=self.vllm_config)
else:
# Else, added to the list of workers.
self.workers.append(worker)
@@ -2,8 +2,7 @@ import asyncio
import os
from collections import defaultdict
from itertools import islice, repeat
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Type)
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import msgspec

@@ -18,7 +17,6 @@ from vllm.sequence import ExecuteModelRequest
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
get_ip, get_open_port, get_vllm_instance_id,
make_async)
from vllm.worker.worker_base import WorkerBase

if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

@@ -81,33 +79,6 @@ class RayHPUExecutor(DistributedGPUExecutor):
def finish_measurements(self):
self._run_workers("finish_measurements")

def _get_worker_module_and_class(
self
) -> Tuple[str, str, Optional[Callable[[],
Type[WorkerBase]]]]: # noqa: F821
worker_class_fn = None
if self.scheduler_config.is_multi_step:
raise NotImplementedError(
"Multi-step execution is not implemented for HPU")
elif self.speculative_config:
raise NotImplementedError(
"Speculative decoding is not implemented for HPU")
else:
worker_module_name = "vllm.worker.hpu_worker"
worker_class_name = "HPUWorker"
return (worker_module_name, worker_class_name, worker_class_fn)

def _get_worker_wrapper_args(self) -> Dict[str, Any]:
(worker_module_name, worker_class_name,
worker_class_fn) = self._get_worker_module_and_class()

return dict(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
worker_class_fn=worker_class_fn,
trust_remote_code=self.model_config.trust_remote_code,
)

def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
# Otherwise, the ray workers are allocated with a full GPU.

@@ -128,7 +99,6 @@ class RayHPUExecutor(DistributedGPUExecutor):

# Create the workers.
driver_ip = get_ip()
worker_wrapper_kwargs = self._get_worker_wrapper_args()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("HPU", 0):
continue

@@ -144,7 +114,7 @@ class RayHPUExecutor(DistributedGPUExecutor):
resources={'HPU': num_gpus},
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config)

if self.use_ray_spmd_worker:
self.workers.append(worker)

@@ -155,7 +125,7 @@ class RayHPUExecutor(DistributedGPUExecutor):
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
**worker_wrapper_kwargs)
vllm_config=self.vllm_config)
else:
# Else, added to the list of workers.
self.workers.append(worker)
@@ -69,14 +69,6 @@ class RayTPUExecutor(TPUExecutor):
placement_group_bundle_index=bundle_id,
)

assert self.speculative_config is None
if self.scheduler_config.is_multi_step:
worker_module_name = "vllm.worker.multi_step_tpu_worker"
worker_class_name = "MultiStepTPUWorker"
else:
worker_module_name = "vllm.worker.tpu_worker"
worker_class_name = "TPUWorker"

# GKE does not fetch environment information from metadata server
# and instead sets these from within the Ray process. Therefore we
# need to override the Ray environment variables manually.

@@ -95,11 +87,7 @@ class RayTPUExecutor(TPUExecutor):
resources={"TPU": 1},
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code,
)
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config)
if override_env:
worker.override_env_vars.remote(override_env)

@@ -109,10 +97,7 @@ class RayTPUExecutor(TPUExecutor):
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code,
)
vllm_config=self.vllm_config)
else:
# Else, added to the list of workers.
self.workers.append(worker)
@@ -1,4 +1,4 @@
from typing import Callable, List, Optional, Tuple, Type, Union
from typing import List, Optional, Union

from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor

@@ -6,7 +6,6 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.utils import make_async
from vllm.worker.worker_base import WorkerBase

logger = init_logger(__name__)

@@ -22,17 +21,6 @@ class XPUExecutor(GPUExecutor):

GPUExecutor._init_executor(self)

def _get_worker_module_and_class(
self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
worker_class_fn = None
if self.speculative_config is not None:
raise NotImplementedError(
"XPU does not support speculative decoding")
else:
worker_module_name = "vllm.worker.xpu_worker"
worker_class_name = "XPUWorker"
return (worker_module_name, worker_class_name, worker_class_fn)

def execute_model(
self, execute_model_req: ExecuteModelRequest
) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
|
||||
@ -84,3 +84,5 @@ class CpuPlatform(Platform):
|
||||
"distributed executor backend."),
|
||||
parallel_config.distributed_executor_backend)
|
||||
parallel_config.distributed_executor_backend = "mp"
|
||||
if parallel_config.worker_cls == "auto":
|
||||
parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
|
||||
|
||||
@ -4,7 +4,7 @@ pynvml. However, it should not initialize cuda context.
|
||||
|
||||
import os
|
||||
from functools import lru_cache, wraps
|
||||
from typing import Callable, List, Tuple, TypeVar
|
||||
from typing import TYPE_CHECKING, Callable, List, Tuple, TypeVar
|
||||
|
||||
import pynvml
|
||||
import torch
|
||||
@ -16,6 +16,11 @@ from vllm.logger import init_logger
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
VllmConfig = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
@ -157,3 +162,17 @@ class CudaPlatform(Platform):
|
||||
" machine has no NVLink equipped.")
|
||||
return False
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||
parallel_config = vllm_config.parallel_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
if parallel_config.worker_cls == "auto":
|
||||
if scheduler_config.is_multi_step:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.worker.multi_step_worker.MultiStepWorker"
|
||||
elif vllm_config.speculative_config:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm.worker.worker.Worker"
|
||||
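check_and_update_config is now the single place where a platform maps the "auto" worker to a concrete class, so an out-of-tree platform can hook in the same way. A hypothetical sketch (the platform name, enum choice and worker path are made up):

```python
from vllm.platforms.interface import Platform, PlatformEnum


class MyAcceleratorPlatform(Platform):
    # Enum value chosen only to make the sketch self-contained.
    _enum = PlatformEnum.UNSPECIFIED

    @classmethod
    def check_and_update_config(cls, vllm_config) -> None:
        parallel_config = vllm_config.parallel_config
        if parallel_config.worker_cls == "auto":
            # The dotted path is resolved lazily by WorkerWrapperBase.init_worker.
            parallel_config.worker_cls = "my_pkg.my_worker.MyWorker"
```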
|
||||
@ -1,7 +1,14 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from .interface import Platform, PlatformEnum, _Backend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
VllmConfig = None
|
||||
|
||||
|
||||
class HpuPlatform(Platform):
|
||||
_enum = PlatformEnum.HPU
|
||||
@ -14,3 +21,19 @@ class HpuPlatform(Platform):
|
||||
@staticmethod
|
||||
def inference_mode():
|
||||
return torch.no_grad()
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
if scheduler_config.is_multi_step:
|
||||
raise NotImplementedError(
|
||||
"Multi-step execution is not implemented for HPU")
|
||||
|
||||
if vllm_config.speculative_config is not None:
|
||||
raise NotImplementedError(
|
||||
"Speculative decoding is not implemented for HPU")
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
if parallel_config.worker_cls == "auto":
|
||||
parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker"
|
||||
|
||||
@ -1,5 +1,12 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .interface import Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
VllmConfig = None
|
||||
|
||||
|
||||
class NeuronPlatform(Platform):
|
||||
_enum = PlatformEnum.NEURON
|
||||
@ -8,3 +15,10 @@ class NeuronPlatform(Platform):
|
||||
@classmethod
|
||||
def get_device_name(cls, device_id: int = 0) -> str:
|
||||
return "neuron"
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||
parallel_config = vllm_config.parallel_config
|
||||
if parallel_config.worker_cls == "auto":
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.worker.neuron_worker.NeuronWorker"
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
@ -5,6 +7,11 @@ from vllm.logger import init_logger
|
||||
|
||||
from .interface import Platform, PlatformEnum, _Backend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
VllmConfig = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -38,3 +45,14 @@ class OpenVinoPlatform(Platform):
|
||||
def is_pin_memory_available(self) -> bool:
|
||||
logger.warning("Pin memory is not supported on OpenViNO.")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||
parallel_config = vllm_config.parallel_config
|
||||
assert (
|
||||
parallel_config.world_size == 1
|
||||
), "OpenVINOExecutor only supports single CPU socket currently."
|
||||
|
||||
if parallel_config.worker_cls == "auto":
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.worker.openvino_worker.OpenVINOWorker"
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
@ -7,6 +8,11 @@ from vllm.logger import init_logger
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
VllmConfig = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
try:
|
||||
@ -58,3 +64,17 @@ class RocmPlatform(Platform):
|
||||
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
||||
device_props = torch.cuda.get_device_properties(device_id)
|
||||
return device_props.total_memory
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||
parallel_config = vllm_config.parallel_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
if parallel_config.worker_cls == "auto":
|
||||
if scheduler_config.is_multi_step:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.worker.multi_step_worker.MultiStepWorker"
|
||||
elif vllm_config.speculative_config:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm.worker.worker.Worker"
|
||||
|
||||
@ -48,3 +48,15 @@ class TpuPlatform(Platform):
|
||||
|
||||
if compilation_config.backend == "":
|
||||
compilation_config.backend = "openxla"
|
||||
|
||||
assert vllm_config.speculative_config is None, \
|
||||
"TPU does not support speculative decoding"
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
if parallel_config.worker_cls == "auto":
|
||||
if scheduler_config.is_multi_step:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker"
|
||||
|
||||
@ -57,6 +57,10 @@ class XPUPlatform(Platform):
|
||||
"mode.")
|
||||
model_config.enforce_eager = True
|
||||
|
||||
if vllm_config.speculative_config is not None:
|
||||
raise NotImplementedError(
|
||||
"XPU does not support speculative decoding")
|
||||
|
||||
# check and update parallel config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
if (parallel_config.distributed_executor_backend is not None
|
||||
@ -66,3 +70,5 @@ class XPUPlatform(Platform):
|
||||
" executor backend.",
|
||||
parallel_config.distributed_executor_backend)
|
||||
parallel_config.distributed_executor_backend = "ray"
|
||||
if parallel_config.worker_cls == "auto":
|
||||
parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
import dataclasses
|
||||
import importlib
|
||||
import os
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -15,7 +14,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
|
||||
from vllm.utils import (enable_trace_function_call_for_thread,
|
||||
update_environment_variables)
|
||||
resolve_obj_by_qualname, update_environment_variables)
|
||||
from vllm.worker.model_runner_base import (BroadcastableModelInput,
|
||||
ModelRunnerBase,
|
||||
ModelRunnerInputBase)
|
||||
@ -411,23 +410,14 @@ class WorkerWrapperBase:
|
||||
We first instantiate the WorkerWrapper, which remembers the worker module
|
||||
and class name. Then, when we call `update_environment_variables`, and the
|
||||
real initialization happens in `init_worker`.
|
||||
|
||||
If worker_class_fn is specified, it will be executed to get the worker
|
||||
class.
|
||||
Otherwise, the worker class will be obtained by dynamically importing it
|
||||
using worker_module_name and worker_class_name.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
worker_module_name: str,
|
||||
worker_class_name: str,
|
||||
trust_remote_code: bool = False,
|
||||
worker_class_fn: Optional[Callable[[],
|
||||
Type[WorkerBase]]] = None) -> None:
|
||||
self.worker_module_name = worker_module_name
|
||||
self.worker_class_name = worker_class_name
|
||||
self.worker_class_fn = worker_class_fn
|
||||
vllm_config: VllmConfig,
|
||||
) -> None:
|
||||
self.vllm_config = vllm_config
|
||||
trust_remote_code = vllm_config.model_config.trust_remote_code
|
||||
self.worker: Optional[WorkerBase] = None
|
||||
if trust_remote_code:
|
||||
# note: lazy import to avoid importing torch before initializing
|
||||
@ -456,12 +446,8 @@ class WorkerWrapperBase:
|
||||
from vllm.plugins import load_general_plugins
|
||||
load_general_plugins()
|
||||
|
||||
if self.worker_class_fn:
|
||||
worker_class = self.worker_class_fn()
|
||||
else:
|
||||
mod = importlib.import_module(self.worker_module_name)
|
||||
worker_class = getattr(mod, self.worker_class_name)
|
||||
|
||||
worker_class = resolve_obj_by_qualname(
|
||||
self.vllm_config.parallel_config.worker_cls)
|
||||
self.worker = worker_class(*args, **kwargs)
|
||||
assert self.worker is not None
|
||||
|
||||
|
||||
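Instead of a (module name, class name, factory) triple, the wrapper now resolves one dotted string with resolve_obj_by_qualname. The helper behaves roughly like the following illustrative re-implementation (not vllm's actual code):

```python
import importlib
from collections import OrderedDict


def resolve_obj_by_qualname_sketch(qualname: str):
    """Resolve 'pkg.module.Obj' to the Obj attribute of pkg.module."""
    module_name, _, obj_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, obj_name)


# Mechanically the same lookup that init_worker() performs on strings such as
# parallel_config.worker_cls == "vllm.worker.worker.Worker".
assert resolve_obj_by_qualname_sketch("collections.OrderedDict") is OrderedDict
```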