vllm/vllm/v1/executor/abstract.py
Cyrus Leung 4d4d6bad19
[Chore] Separate out vllm.utils.importlib (#27022)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-10-17 00:48:59 +00:00

143 lines
5.4 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from concurrent.futures import Future
from typing import Any
import torch
import torch.distributed as dist
from vllm.config import VllmConfig
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.uniproc_executor import ( # noqa
ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0,
)
from vllm.executor.uniproc_executor import UniProcExecutor as UniProcExecutorV0 # noqa
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
# Signature of the callback accepted by Executor.register_failure_callback:
# a callable that takes no arguments and returns nothing.
FailureCallback = Callable[[], None]
class Executor(ExecutorBase):
    """Base class for v1 executors.

    Only the methods that are specific to v1 are defined here; anything
    shared between v0 and v1 belongs in ExecutorBase.
    """

    @staticmethod
    def get_class(vllm_config: VllmConfig) -> type["Executor"]:
        """Pick the executor class requested by ``vllm_config``.

        The choice is driven by
        ``vllm_config.parallel_config.distributed_executor_backend``:

        * a class: returned as-is after checking it subclasses ExecutorBase;
        * ``"ray"``: RayDistributedExecutor (imported lazily);
        * ``"mp"``: MultiprocExecutor (imported lazily);
        * ``"uni"``: UniProcExecutor;
        * ``"external_launcher"``: ExecutorWithExternalLauncher;
        * any other string: treated as a qualified name, resolved with
          ``resolve_obj_by_qualname`` and checked against ExecutorBase.

        Raises:
            TypeError: the supplied or resolved class does not subclass
                ExecutorBase.
            ValueError: the backend is neither a class nor a string.
        """
        # distributed_executor_backend must be set in VllmConfig.__post_init__
        distributed_executor_backend = (
            vllm_config.parallel_config.distributed_executor_backend
        )

        # A class passed in directly only needs validating.
        if isinstance(distributed_executor_backend, type):
            if issubclass(distributed_executor_backend, ExecutorBase):
                return distributed_executor_backend
            raise TypeError(
                "distributed_executor_backend must be a subclass of "
                f"ExecutorBase. Got {distributed_executor_backend}."
            )

        # Built-in backends; the ray / mp modules are imported only when
        # that backend is actually selected.
        if distributed_executor_backend == "ray":
            from vllm.v1.executor.ray_distributed_executor import (  # noqa
                RayDistributedExecutor,
            )

            return RayDistributedExecutor

        if distributed_executor_backend == "mp":
            from vllm.v1.executor.multiproc_executor import MultiprocExecutor

            return MultiprocExecutor

        if distributed_executor_backend == "uni":
            return UniProcExecutor

        if distributed_executor_backend == "external_launcher":
            # TODO: make v1 scheduling deterministic
            # to support external launcher
            return ExecutorWithExternalLauncher

        # Anything that is not a string at this point cannot be resolved.
        if not isinstance(distributed_executor_backend, str):
            raise ValueError(
                f"Unknown distributed executor backend: {distributed_executor_backend}"
            )

        # Remaining strings are qualified names of user-provided classes.
        executor_class = resolve_obj_by_qualname(distributed_executor_backend)
        if issubclass(executor_class, ExecutorBase):
            return executor_class
        raise TypeError(
            "distributed_executor_backend must be a subclass of "
            f"ExecutorBase. Got {executor_class}."
        )

    def initialize_from_config(self, kv_cache_configs: list[KVCacheConfig]) -> None:
        """Set up the KV caches and start the workers' model execution loop.

        Issues two collective RPCs in order: ``initialize_from_config`` with
        ``kv_cache_configs`` as its single argument, then
        ``compile_or_warm_up_model`` with no arguments.
        """
        init_args = (kv_cache_configs,)
        self.collective_rpc("initialize_from_config", args=init_args)
        self.collective_rpc("compile_or_warm_up_model")

    def register_failure_callback(self, callback: FailureCallback):
        """Register ``callback`` to run if the executor fails permanently.

        This base implementation ignores ``callback`` and does nothing.
        """

    def determine_available_memory(self) -> list[int]:
        """Return the ``determine_available_memory`` RPC result (in bytes)."""
        available_bytes = self.collective_rpc("determine_available_memory")
        return available_bytes

    def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
        """Return the result of the ``get_kv_cache_spec`` collective RPC."""
        kv_cache_specs = self.collective_rpc("get_kv_cache_spec")
        return kv_cache_specs

    def collective_rpc(
        self,
        method: str | Callable,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict | None = None,
        non_block: bool = False,
    ) -> list[Any]:
        """Dispatch ``method`` and return the collected results as a list.

        The RPC-style methods of this class delegate to this hook. This
        version always raises NotImplementedError.
        """
        raise NotImplementedError

    def execute_model(
        self,
        scheduler_output: SchedulerOutput,
        non_block: bool = False,
    ) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
        """Run the ``execute_model`` RPC and return its first result.

        ``scheduler_output`` is the only positional RPC argument and
        ``non_block`` is forwarded unchanged. Per the return annotation the
        value is a ModelRunnerOutput or a Future of one.
        """
        return self.collective_rpc(
            "execute_model", args=(scheduler_output,), non_block=non_block
        )[0]

    def execute_dummy_batch(self) -> None:
        """Issue the ``execute_dummy_batch`` RPC; its result is discarded."""
        self.collective_rpc("execute_dummy_batch")

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return the first result of the ``take_draft_token_ids`` RPC."""
        return self.collective_rpc("take_draft_token_ids")[0]

    @property
    def max_concurrent_batches(self) -> int:
        """Number of concurrent batches; always 1 in this base class."""
        return 1

    def profile(self, is_start: bool = True):
        """Issue the ``profile`` RPC with ``is_start`` as its only argument."""
        profile_args = (is_start,)
        self.collective_rpc("profile", args=profile_args)
class UniProcExecutor(UniProcExecutorV0, Executor):
    """v1 ``"uni"`` executor.

    Adds nothing of its own: it combines the v0 UniProcExecutor with the v1
    Executor interface through multiple inheritance.
    """
class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
    """v1 ``"external_launcher"`` executor.

    Combines the v0 ExecutorWithExternalLauncher with the v1 Executor
    interface and overrides the available-memory query.
    """

    def determine_available_memory(self) -> list[int]:  # in bytes
        """Return the available memory, minimised over the world group.

        The value reported by the parent implementation is MIN-reduced over
        the world group's CPU process group, and the reduced value is
        returned as a one-element list.
        """
        # same as determine_num_available_blocks in v0,
        # we need to get the min across all ranks.
        local_memory = super().determine_available_memory()

        # Imported lazily, inside the method, as in the original code.
        from vllm.distributed.parallel_state import get_world_group

        world_cpu_group = get_world_group().cpu_group
        reduced = torch.tensor([local_memory], device="cpu", dtype=torch.int64)
        # all_reduce updates ``reduced`` in place.
        dist.all_reduce(reduced, op=dist.ReduceOp.MIN, group=world_cpu_group)
        return [reduced.item()]