[V0 Deprecation] Remove vllm.worker and update according imports (#25901)

This commit is contained in:
Aaron Pham 2025-09-29 19:26:11 -04:00 committed by GitHub
parent 2e4fe48c37
commit 6a113d9aed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 276 additions and 327 deletions

View File

@ -10,7 +10,7 @@ from vllm.model_executor.model_loader import tensorizer as tensorizer_mod
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.executor.abstract import UniProcExecutor from vllm.v1.executor.abstract import UniProcExecutor
from vllm.worker.worker_base import WorkerWrapperBase from vllm.v1.worker.worker_base import WorkerWrapperBase
MODEL_REF = "facebook/opt-125m" MODEL_REF = "facebook/opt-125m"

View File

@ -36,7 +36,6 @@ ALLOWED_FILES = {
'benchmarks/cutlass_benchmarks/w8a8_benchmarks.py', 'benchmarks/cutlass_benchmarks/w8a8_benchmarks.py',
'benchmarks/cutlass_benchmarks/sparse_benchmarks.py', 'benchmarks/cutlass_benchmarks/sparse_benchmarks.py',
# cloudpickle # cloudpickle
'vllm/worker/worker_base.py',
'vllm/executor/mp_distributed_executor.py', 'vllm/executor/mp_distributed_executor.py',
'vllm/executor/ray_distributed_executor.py', 'vllm/executor/ray_distributed_executor.py',
'vllm/entrypoints/llm.py', 'vllm/entrypoints/llm.py',

View File

@ -19,7 +19,7 @@ from vllm.sequence import ExecuteModelRequest
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.utils import make_async from vllm.utils import make_async
from vllm.v1.outputs import PoolerOutput, SamplerOutput from vllm.v1.outputs import PoolerOutput, SamplerOutput
from vllm.worker.worker_base import WorkerBase from vllm.v1.worker.worker_base import WorkerBase
logger = init_logger(__name__) logger = init_logger(__name__)

View File

@ -16,7 +16,7 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import get_ip from vllm.utils import get_ip
from vllm.worker.worker_base import WorkerWrapperBase from vllm.v1.worker.worker_base import WorkerWrapperBase
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput

View File

@ -19,7 +19,7 @@ from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.utils import get_and_update_mm_cache from vllm.v1.executor.utils import get_and_update_mm_cache
from vllm.v1.outputs import AsyncModelRunnerOutput from vllm.v1.outputs import AsyncModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase from vllm.v1.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__) logger = init_logger(__name__)

View File

@ -110,17 +110,7 @@ class CudaPlatformBase(Platform):
model_config = vllm_config.model_config model_config = vllm_config.model_config
if parallel_config.worker_cls == "auto": if parallel_config.worker_cls == "auto":
if vllm_config.speculative_config: parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
if not envs.VLLM_USE_V1:
raise NotImplementedError(
"Speculative decoding is not supported on vLLM V0.")
parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
else:
if envs.VLLM_USE_V1:
parallel_config.worker_cls = \
"vllm.v1.worker.gpu_worker.Worker"
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
if cache_config and cache_config.block_size is None: if cache_config and cache_config.block_size is None:

View File

@ -327,17 +327,7 @@ class RocmPlatform(Platform):
cache_config.block_size = 16 cache_config.block_size = 16
if parallel_config.worker_cls == "auto": if parallel_config.worker_cls == "auto":
if vllm_config.speculative_config: parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
if not use_v1:
raise NotImplementedError(
"Speculative decoding is not supported on vLLM V0.")
parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
else:
if use_v1:
parallel_config.worker_cls = \
"vllm.v1.worker.gpu_worker.Worker"
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"
# Aiter rms norm perform best when CUDA Graph capture is enabled. # Aiter rms norm perform best when CUDA Graph capture is enabled.
if (use_v1 and use_aiter_rms_norm and not is_eager_execution if (use_v1 and use_aiter_rms_norm and not is_eager_execution
and "-rms_norm" not in compilation_config.custom_ops): and "-rms_norm" not in compilation_config.custom_ops):

View File

@ -41,7 +41,7 @@ from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.executor.utils import get_and_update_mm_cache from vllm.v1.executor.utils import get_and_update_mm_cache
from vllm.v1.outputs import (AsyncModelRunnerOutput, DraftTokenIds, from vllm.v1.outputs import (AsyncModelRunnerOutput, DraftTokenIds,
ModelRunnerOutput) ModelRunnerOutput)
from vllm.worker.worker_base import WorkerWrapperBase from vllm.v1.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__) logger = init_logger(__name__)

View File

@ -1,23 +1,35 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional from __future__ import annotations
import os
from typing import Any, Callable, Optional, TypeVar, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import VllmConfig from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (enable_trace_function_call_for_thread,
resolve_obj_by_qualname, run_method,
update_environment_variables,
warn_for_unimplemented_methods)
from vllm.v1.kv_cache_interface import KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.worker.worker_base import WorkerBase as WorkerBaseV0 from vllm.v1.outputs import SamplerOutput
logger = init_logger(__name__) logger = init_logger(__name__)
_R = TypeVar("_R")
class WorkerBase(WorkerBaseV0):
""" @warn_for_unimplemented_methods
Abstract class for v1 worker, mainly define some methods for v1. class WorkerBase:
For methods shared by v0 and v1, define them in v0 WorkerBase """Worker interface that allows vLLM to cleanly separate implementations for
different hardware. Also abstracts control plane communication, e.g., to
communicate request metadata to other workers.
""" """
def __init__( def __init__(
@ -27,7 +39,7 @@ class WorkerBase(WorkerBaseV0):
rank: int, rank: int,
distributed_init_method: str, distributed_init_method: str,
is_driver_worker: bool = False, is_driver_worker: bool = False,
): ) -> None:
""" """
Initialize common worker components. Initialize common worker components.
@ -39,8 +51,21 @@ class WorkerBase(WorkerBaseV0):
is_driver_worker: Whether this worker handles driver is_driver_worker: Whether this worker handles driver
responsibilities responsibilities
""" """
# Configuration storage self.vllm_config = vllm_config
super().__init__(vllm_config=vllm_config) self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.observability_config = vllm_config.observability_config
self.kv_transfer_config = vllm_config.kv_transfer_config
self.compilation_config = vllm_config.compilation_config
from vllm.platforms import current_platform
self.current_platform = current_platform
self.parallel_config.rank = rank self.parallel_config.rank = rank
self.local_rank = local_rank self.local_rank = local_rank
@ -63,3 +88,227 @@ class WorkerBase(WorkerBaseV0):
def check_health(self) -> None: def check_health(self) -> None:
"""Basic health check (override for device-specific checks).""" """Basic health check (override for device-specific checks)."""
return return
def init_device(self) -> None:
"""Initialize device state, such as loading the model or other on-device
memory allocations.
"""
raise NotImplementedError
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache with the given size in blocks.
"""
raise NotImplementedError
def get_model(self) -> nn.Module:
raise NotImplementedError
def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
"""Apply a function on the model inside this worker."""
return fn(self.get_model())
def load_model(self) -> None:
"""Load model onto target device."""
raise NotImplementedError
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[list[SamplerOutput]]:
raise NotImplementedError
def start_worker_execution_loop(self) -> None:
"""Execute model loop in parallel worker.
You can stop the loop by executing a driver worker with an empty output.
See `stop_remote_worker_execution_loop` for more details.
"""
with self.current_platform.inference_mode():
while True:
output = self.execute_model(execute_model_req=None)
if output is None:
return None
def determine_num_available_blocks(self) -> tuple[int, int]:
"""Determine the number of available blocks for the GPU KV cache and
swappable CPU KV cache.
The implementation may run profiling or other heuristics to determine
the size of caches.
Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
are blocks that are "active" on the device and can be appended to.
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
appended to.
"""
raise NotImplementedError
def get_cache_block_size_bytes(self) -> int:
"""Return the size of a single cache block, in bytes. Used in
speculative decoding.
"""
raise NotImplementedError
def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError
def pin_lora(self, lora_id: int) -> bool:
raise NotImplementedError
def list_loras(self) -> set[int]:
raise NotImplementedError
@property
def vocab_size(self) -> int:
"""Get vocabulary size from model configuration."""
return self.model_config.get_vocab_size()
def shutdown(self) -> None:
"""Clean up resources held by the worker."""
return
class WorkerWrapperBase:
"""
This class represents one process in an executor/engine. It is responsible
for lazily initializing the worker and handling the worker's lifecycle.
We first instantiate the WorkerWrapper, which remembers the worker module
and class name. Then, when we call `update_environment_variables`, and the
real initialization happens in `init_worker`.
"""
def __init__(
self,
vllm_config: VllmConfig,
rpc_rank: int = 0,
) -> None:
"""
Initialize the worker wrapper with the given vllm_config and rpc_rank.
Note: rpc_rank is the rank of the worker in the executor. In most cases,
it is also the rank of the worker in the distributed group. However,
when multiple executors work together, they can be different.
e.g. in the case of SPMD-style offline inference with TP=2,
users can launch 2 engines/executors, each with only 1 worker.
All workers have rpc_rank=0, but they have different ranks in the TP
group.
"""
self.rpc_rank = rpc_rank
self.worker: Optional[WorkerBase] = None
self.vllm_config: Optional[VllmConfig] = None
# do not store this `vllm_config`, `init_worker` will set the final
# one. TODO: investigate if we can remove this field in
# `WorkerWrapperBase`, `init_cached_hf_modules` should be
# unnecessary now.
if vllm_config.model_config is not None:
# it can be None in tests
trust_remote_code = vllm_config.model_config.trust_remote_code
if trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
def shutdown(self) -> None:
if self.worker is not None:
self.worker.shutdown()
def adjust_rank(self, rank_mapping: dict[int, int]) -> None:
"""
Adjust the rpc_rank based on the given mapping.
It is only used during the initialization of the executor,
to adjust the rpc_rank of workers after we create all workers.
"""
if self.rpc_rank in rank_mapping:
self.rpc_rank = rank_mapping[self.rpc_rank]
def update_environment_variables(
self,
envs_list: list[dict[str, str]],
) -> None:
envs = envs_list[self.rpc_rank]
key = 'CUDA_VISIBLE_DEVICES'
if key in envs and key in os.environ:
# overwriting CUDA_VISIBLE_DEVICES is desired behavior
# suppress the warning in `update_environment_variables`
del os.environ[key]
update_environment_variables(envs)
def init_worker(self, all_kwargs: list[dict[str, Any]]) -> None:
"""
Here we inject some common logic before initializing the worker.
Arguments are passed to the worker class constructor.
"""
kwargs = all_kwargs[self.rpc_rank]
self.vllm_config = kwargs.get("vllm_config")
assert self.vllm_config is not None, (
"vllm_config is required to initialize the worker")
enable_trace_function_call_for_thread(self.vllm_config)
from vllm.plugins import load_general_plugins
load_general_plugins()
if isinstance(self.vllm_config.parallel_config.worker_cls, str):
worker_class = resolve_obj_by_qualname(
self.vllm_config.parallel_config.worker_cls)
else:
raise ValueError(
"passing worker_cls is no longer supported. Please pass keep the class in a separate module and pass the qualified name of the class as a string." # noqa: E501
)
if self.vllm_config.parallel_config.worker_extension_cls:
worker_extension_cls = resolve_obj_by_qualname(
self.vllm_config.parallel_config.worker_extension_cls)
extended_calls = []
if worker_extension_cls not in worker_class.__bases__:
# check any conflicts between worker and worker_extension_cls
for attr in dir(worker_extension_cls):
if attr.startswith("__"):
continue
assert not hasattr(worker_class, attr), (
f"Worker class {worker_class} already has an attribute"
f" {attr}, which conflicts with the worker"
f" extension class {worker_extension_cls}.")
if callable(getattr(worker_extension_cls, attr)):
extended_calls.append(attr)
# dynamically inherit the worker extension class
worker_class.__bases__ = worker_class.__bases__ + (
worker_extension_cls, )
logger.info(
"Injected %s into %s for extended collective_rpc calls %s",
worker_extension_cls, worker_class, extended_calls)
with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during worker initialization
self.worker = worker_class(**kwargs)
assert self.worker is not None
def initialize_from_config(self, kv_cache_configs: list[Any]) -> None:
kv_cache_config = kv_cache_configs[self.rpc_rank]
with set_current_vllm_config(self.vllm_config):
self.worker.initialize_from_config(kv_cache_config) # type: ignore
def init_device(self):
with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during device initialization
self.worker.init_device() # type: ignore
def execute_method(self, method: Union[str, bytes], *args, **kwargs):
try:
# method resolution order:
# if a method is defined in this class, it will be called directly.
# otherwise, since we define `__getattr__` and redirect attribute
# query to `self.worker`, the method will be called on the worker.
return run_method(self, method, args, kwargs)
except Exception as e:
# if the driver worker also execute methods,
# exceptions in the rest worker may cause deadlock in rpc like ray
# see https://github.com/vllm-project/vllm/issues/3455
# print the error and inform the user to solve the error
msg = (f"Error executing method {method!r}. "
"This might cause deadlock in distributed execution.")
logger.exception(msg)
raise e
def __getattr__(self, attr):
return getattr(self.worker, attr)

View File

@ -1,279 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import (Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar,
Union)
import cloudpickle
import torch.nn as nn
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (enable_trace_function_call_for_thread,
resolve_obj_by_qualname, run_method,
update_environment_variables,
warn_for_unimplemented_methods)
from vllm.v1.outputs import SamplerOutput
logger = init_logger(__name__)
_R = TypeVar("_R")
@warn_for_unimplemented_methods
class WorkerBase:
"""Worker interface that allows vLLM to cleanly separate implementations for
different hardware. Also abstracts control plane communication, e.g., to
communicate request metadata to other workers.
"""
def __init__(
self,
vllm_config: VllmConfig,
) -> None:
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.observability_config = vllm_config.observability_config
self.kv_transfer_config = vllm_config.kv_transfer_config
self.compilation_config = vllm_config.compilation_config
from vllm.platforms import current_platform
self.current_platform = current_platform
def init_device(self) -> None:
"""Initialize device state, such as loading the model or other on-device
memory allocations.
"""
raise NotImplementedError
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache with the given size in blocks.
"""
raise NotImplementedError
def get_model(self) -> nn.Module:
raise NotImplementedError
def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
"""Apply a function on the model inside this worker."""
return fn(self.get_model())
def load_model(self) -> None:
"""Load model onto target device."""
raise NotImplementedError
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[List[SamplerOutput]]:
raise NotImplementedError
def start_worker_execution_loop(self) -> None:
"""Execute model loop in parallel worker.
You can stop the loop by executing a driver worker with an empty output.
See `stop_remote_worker_execution_loop` for more details.
"""
with self.current_platform.inference_mode():
while True:
output = self.execute_model(execute_model_req=None)
if output is None:
return None
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available blocks for the GPU KV cache and
swappable CPU KV cache.
The implementation may run profiling or other heuristics to determine
the size of caches.
Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
are blocks that are "active" on the device and can be appended to.
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
appended to.
"""
raise NotImplementedError
def get_cache_block_size_bytes(self) -> int:
"""Return the size of a single cache block, in bytes. Used in
speculative decoding.
"""
raise NotImplementedError
def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError
def pin_lora(self, lora_id: int) -> bool:
raise NotImplementedError
def list_loras(self) -> Set[int]:
raise NotImplementedError
@property
def vocab_size(self) -> int:
"""Get vocabulary size from model configuration."""
return self.model_config.get_vocab_size()
def shutdown(self) -> None:
"""Clean up resources held by the worker."""
return
class WorkerWrapperBase:
"""
This class represents one process in an executor/engine. It is responsible
for lazily initializing the worker and handling the worker's lifecycle.
We first instantiate the WorkerWrapper, which remembers the worker module
and class name. Then, when we call `update_environment_variables`, and the
real initialization happens in `init_worker`.
"""
def __init__(
self,
vllm_config: VllmConfig,
rpc_rank: int = 0,
) -> None:
"""
Initialize the worker wrapper with the given vllm_config and rpc_rank.
Note: rpc_rank is the rank of the worker in the executor. In most cases,
it is also the rank of the worker in the distributed group. However,
when multiple executors work together, they can be different.
e.g. in the case of SPMD-style offline inference with TP=2,
users can launch 2 engines/executors, each with only 1 worker.
All workers have rpc_rank=0, but they have different ranks in the TP
group.
"""
self.rpc_rank = rpc_rank
self.worker: Optional[WorkerBase] = None
self.vllm_config: Optional[VllmConfig] = None
# do not store this `vllm_config`, `init_worker` will set the final
# one. TODO: investigate if we can remove this field in
# `WorkerWrapperBase`, `init_cached_hf_modules` should be
# unnecessary now.
if vllm_config.model_config is not None:
# it can be None in tests
trust_remote_code = vllm_config.model_config.trust_remote_code
if trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
def shutdown(self) -> None:
if self.worker is not None:
self.worker.shutdown()
def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
"""
Adjust the rpc_rank based on the given mapping.
It is only used during the initialization of the executor,
to adjust the rpc_rank of workers after we create all workers.
"""
if self.rpc_rank in rank_mapping:
self.rpc_rank = rank_mapping[self.rpc_rank]
def update_environment_variables(self, envs_list: List[Dict[str,
str]]) -> None:
envs = envs_list[self.rpc_rank]
key = 'CUDA_VISIBLE_DEVICES'
if key in envs and key in os.environ:
# overwriting CUDA_VISIBLE_DEVICES is desired behavior
# suppress the warning in `update_environment_variables`
del os.environ[key]
update_environment_variables(envs)
def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
"""
Here we inject some common logic before initializing the worker.
Arguments are passed to the worker class constructor.
"""
kwargs = all_kwargs[self.rpc_rank]
self.vllm_config = kwargs.get("vllm_config")
assert self.vllm_config is not None, (
"vllm_config is required to initialize the worker")
enable_trace_function_call_for_thread(self.vllm_config)
from vllm.plugins import load_general_plugins
load_general_plugins()
if isinstance(self.vllm_config.parallel_config.worker_cls, str):
worker_class = resolve_obj_by_qualname(
self.vllm_config.parallel_config.worker_cls)
else:
logger.warning(
"passing worker_cls as a class object is strongly deprecated,"
" as the serialization of class objects can be tricky and"
" error-prone. To be safe, please keep the class in a separate"
" module and pass the qualified name of the class as a string."
)
assert isinstance(self.vllm_config.parallel_config.worker_cls,
bytes)
worker_class = cloudpickle.loads(
self.vllm_config.parallel_config.worker_cls)
if self.vllm_config.parallel_config.worker_extension_cls:
worker_extension_cls = resolve_obj_by_qualname(
self.vllm_config.parallel_config.worker_extension_cls)
extended_calls = []
if worker_extension_cls not in worker_class.__bases__:
# check any conflicts between worker and worker_extension_cls
for attr in dir(worker_extension_cls):
if attr.startswith("__"):
continue
assert not hasattr(worker_class, attr), (
f"Worker class {worker_class} already has an attribute"
f" {attr}, which conflicts with the worker"
f" extension class {worker_extension_cls}.")
if callable(getattr(worker_extension_cls, attr)):
extended_calls.append(attr)
# dynamically inherit the worker extension class
worker_class.__bases__ = worker_class.__bases__ + (
worker_extension_cls, )
logger.info(
"Injected %s into %s for extended collective_rpc calls %s",
worker_extension_cls, worker_class, extended_calls)
with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during worker initialization
self.worker = worker_class(**kwargs)
assert self.worker is not None
def initialize_from_config(self, kv_cache_configs: List[Any]) -> None:
kv_cache_config = kv_cache_configs[self.rpc_rank]
with set_current_vllm_config(self.vllm_config):
self.worker.initialize_from_config(kv_cache_config) # type: ignore
def init_device(self):
with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during device initialization
self.worker.init_device() # type: ignore
def execute_method(self, method: Union[str, bytes], *args, **kwargs):
try:
# method resolution order:
# if a method is defined in this class, it will be called directly.
# otherwise, since we define `__getattr__` and redirect attribute
# query to `self.worker`, the method will be called on the worker.
return run_method(self, method, args, kwargs)
except Exception as e:
# if the driver worker also execute methods,
# exceptions in the rest worker may cause deadlock in rpc like ray
# see https://github.com/vllm-project/vllm/issues/3455
# print the error and inform the user to solve the error
msg = (f"Error executing method {method!r}. "
"This might cause deadlock in distributed execution.")
logger.exception(msg)
raise e
def __getattr__(self, attr):
return getattr(self.worker, attr)