[Core] Introduce SPMD worker execution using Ray accelerated DAG (#6032)

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
Co-authored-by: Stephanie Wang <swang@cs.berkeley.edu>
Authored by Rui Qiao on 2024-07-17 22:27:09 -07:00; committed by GitHub
parent d25877dd9b
commit 61e592747c
8 changed files with 216 additions and 119 deletions


@@ -84,6 +84,8 @@ steps:
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py
@@ -108,6 +110,7 @@ steps:
# We want to test that models which use 2 GPUs work with 4 GPUs, which is why we duplicate them here.
# See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context.
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
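For reference, the two new CI rows above run the existing distributed-correctness tests with the SPMD worker and compiled-DAG flags turned on. Outside of CI the same setup can be reproduced roughly as below; this is an illustrative sketch (not part of the change), assuming a single node with at least two GPUs and reusing the facebook/opt-125m model from the test rows.

import os

# Both flags must currently be set together; the executor asserts this.
# Set them before importing vllm so the env registry sees them.
os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",
    tensor_parallel_size=2,
    distributed_executor_backend="ray",
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)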


@@ -6,6 +6,7 @@ from typing import Set, Type, TypeVar, Union
from transformers import PreTrainedTokenizer
import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
LoRAConfig, ModelConfig, MultiModalConfig,
ObservabilityConfig, ParallelConfig,
@@ -414,6 +415,9 @@ class LLMEngine:
elif distributed_executor_backend == "mp":
from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutor)
assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
"multiprocessing distributed executor backend does not "
"support VLLM_USE_RAY_SPMD_WORKER=1")
executor_class = MultiprocessingGPUExecutor
else:
from vllm.executor.gpu_executor import GPUExecutor
@@ -426,6 +430,7 @@ class LLMEngine:
usage_context=usage_context,
stat_loggers=stat_loggers,
)
return engine
def __reduce__(self):


@@ -34,6 +34,7 @@ if TYPE_CHECKING:
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
VLLM_USE_RAY_SPMD_WORKER: bool = False
VLLM_USE_RAY_COMPILED_DAG: bool = False
VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
@@ -261,6 +262,13 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS":
lambda: bool(os.getenv("VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS", False)),
# If the env var is set, then all workers will execute as separate
# processes from the engine, and we use the same mechanism to trigger
# execution on all workers.
# Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it.
"VLLM_USE_RAY_SPMD_WORKER":
lambda: bool(os.getenv("VLLM_USE_RAY_SPMD_WORKER", 0)),
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
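A note on the parsing above: os.getenv returns a string whenever the variable is set, so bool() treats any non-empty value, including "0", as enabling the flag; only leaving the variable unset keeps it off. A small self-contained sketch of the registry style used here (simplified, not the actual vllm.envs module):

import os

environment_variables = {
    # Same parsing shape as the VLLM_USE_RAY_SPMD_WORKER entry above.
    "VLLM_USE_RAY_SPMD_WORKER":
    lambda: bool(os.getenv("VLLM_USE_RAY_SPMD_WORKER", 0)),
}

os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "0"
print(environment_variables["VLLM_USE_RAY_SPMD_WORKER"]())  # True: bool("0")

del os.environ["VLLM_USE_RAY_SPMD_WORKER"]
print(environment_variables["VLLM_USE_RAY_SPMD_WORKER"]())  # False: bool(0)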


@@ -64,8 +64,8 @@ class DistributedGPUExecutor(GPUExecutor):
num_cpu_blocks=num_cpu_blocks)
def execute_model(
self, execute_model_req: ExecuteModelRequest
) -> Optional[List[SamplerOutput]]:
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop",
@@ -73,7 +73,9 @@ class DistributedGPUExecutor(GPUExecutor):
**self.extra_execute_model_run_workers_kwargs)
# Only the driver worker returns the sampling results.
return self._driver_execute_model(execute_model_req)
driver_outputs = self._driver_execute_model(execute_model_req)
assert driver_outputs is not None
return driver_outputs
def stop_remote_worker_execution_loop(self) -> None:
if self.parallel_worker_tasks is None:


@@ -1,6 +1,5 @@
import asyncio
import os
import pickle
from collections import defaultdict
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
@@ -23,12 +22,30 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
class RayGPUExecutor(DistributedGPUExecutor):
def _init_executor(self) -> None:
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
# Currently, this requires USE_RAY_SPMD_WORKER=True.
self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG
# If the env var is set, then we do not distinguish between the
# "driver worker" vs other workers. Also, the rank 0 worker will
# be executed in a remote Ray worker. Currently this requires
# USE_RAY_COMPILED_DAG=True.
self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
if self.use_ray_compiled_dag:
assert self.use_ray_spmd_worker, (
"VLLM_USE_RAY_COMPILED_DAG=1 requires "
"VLLM_USE_RAY_SPMD_WORKER=1")
if self.use_ray_spmd_worker:
# TODO: Support SPMD worker for non-DAG Ray executor.
assert self.use_ray_compiled_dag, (
"VLLM_USE_RAY_SPMD_WORKER=1 requires "
"VLLM_USE_RAY_COMPILED_DAG=1")
assert self.parallel_config.distributed_executor_backend == "ray"
placement_group = self.parallel_config.placement_group
@@ -40,11 +57,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
# Create the parallel GPU workers.
self._init_workers_ray(placement_group)
self.forward_dag = None
if USE_RAY_COMPILED_DAG:
self.forward_dag = self._compiled_ray_dag()
self.extra_execute_model_run_workers_kwargs[
"use_ray_compiled_dag"] = True
self.forward_dag: Optional["ray.dag.CompiledDAG"] = None
def _configure_ray_workers_use_nsight(self,
ray_remote_kwargs) -> Dict[str, Any]:
@@ -110,21 +123,24 @@ class RayGPUExecutor(DistributedGPUExecutor):
trust_remote_code=self.model_config.trust_remote_code,
)
worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code,
)
else:
# Else, added to the list of workers.
if self.use_ray_spmd_worker:
self.workers.append(worker)
else:
worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code,
)
else:
# Else, added to the list of workers.
self.workers.append(worker)
if self.driver_dummy_worker is None:
if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
raise ValueError(
"Ray does not allocate any GPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a "
@@ -254,9 +270,23 @@ class RayGPUExecutor(DistributedGPUExecutor):
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
assert not self.use_ray_spmd_worker, (
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
return self.driver_worker.execute_method("execute_model",
execute_model_req)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if not self.use_ray_spmd_worker:
return super().execute_model(execute_model_req)
if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
outputs = ray.get(self.forward_dag.execute(execute_model_req))
return outputs[0]
def _run_workers(
self,
method: str,
@@ -266,7 +296,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False,
max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False,
**kwargs,
) -> Any:
"""Runs the given method on all workers. Can be used in the following
@@ -281,6 +310,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
- all_args/all_kwargs: args/kwargs for each worker are specified
individually
"""
if self.use_ray_spmd_worker:
assert not async_run_tensor_parallel_workers_only, (
"async_run_tensor_parallel_workers_only is not supported for "
"spmd mode.")
if max_concurrent_workers:
raise NotImplementedError(
@@ -289,71 +322,69 @@ class RayGPUExecutor(DistributedGPUExecutor):
count = len(self.workers) if not \
async_run_tensor_parallel_workers_only \
else len(self.non_driver_workers)
# If using SPMD worker, all workers are the same, so we should execute
# the args on all workers. Otherwise, we skip the first worker's args
# because those args will go to the driver worker.
first_worker_args_index: int = 0 if self.use_ray_spmd_worker else 1
all_worker_args = repeat(args, count) if all_args is None \
else islice(all_args, 1, None)
else islice(all_args, first_worker_args_index, None)
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
else islice(all_kwargs, 1, None)
else islice(all_kwargs, first_worker_args_index, None)
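The first_worker_args_index shift is the subtle part of this hunk: in the non-SPMD layout, index 0 of all_args/all_kwargs is reserved for the driver worker, so the remote workers' slice starts at 1; with SPMD workers there is no driver slot and the slice starts at 0. A small standalone illustration with made-up arguments:

from itertools import islice

all_args = [("rank-0",), ("rank-1",), ("rank-2",)]

# Non-SPMD: the first entry goes to the driver worker.
print(list(islice(all_args, 1, None)))  # [('rank-1',), ('rank-2',)]

# SPMD: every worker is a remote worker, nothing is skipped.
print(list(islice(all_args, 0, None)))  # all three entries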
if use_ray_compiled_dag:
# Right now, compiled DAG can only accept a single
# input. TODO(sang): Fix it.
assert self.forward_dag is not None
output_channels = self.forward_dag.execute(1)
ray_worker_outputs = []
else:
# Start the ray workers first.
ray_workers = self.workers
if async_run_tensor_parallel_workers_only:
ray_workers = self.non_driver_workers
ray_worker_outputs = [
worker.execute_method.remote(method, *worker_args,
**worker_kwargs)
for (worker, worker_args, worker_kwargs
) in zip(ray_workers, all_worker_args, all_worker_kwargs)
]
# Start the ray workers first.
ray_workers = self.workers
if async_run_tensor_parallel_workers_only:
ray_workers = self.non_driver_workers
ray_worker_outputs = [
worker.execute_method.remote(method, *worker_args, **worker_kwargs)
for (worker, worker_args, worker_kwargs
) in zip(ray_workers, all_worker_args, all_worker_kwargs)
]
if async_run_tensor_parallel_workers_only:
# Just return futures
return ray_worker_outputs
driver_args = args if all_args is None else all_args[0]
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
driver_worker_output = []
# In SPMD mode, the driver worker is the same as any other worker,
# so we only explicitly execute on the driver worker if using a
# non-SPMD worker class.
if not self.use_ray_spmd_worker:
driver_args = args if all_args is None else all_args[0]
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
# Start the driver worker after all the ray workers.
if not use_dummy_driver:
driver_worker_output = [
self.driver_worker.execute_method(method, *driver_args,
**driver_kwargs)
]
else:
assert self.driver_dummy_worker is not None
driver_worker_output = [
ray.get(
self.driver_dummy_worker.execute_method.remote(
method, *driver_args, **driver_kwargs))
]
# Start the driver worker after all the ray workers.
if not use_dummy_driver:
driver_worker_output = self.driver_worker.execute_method(
method, *driver_args, **driver_kwargs)
else:
assert self.driver_dummy_worker is not None
driver_worker_output = ray.get(
self.driver_dummy_worker.execute_method.remote(
method, *driver_args, **driver_kwargs))
# Get the results of the ray workers.
if self.workers:
if use_ray_compiled_dag:
try:
ray_worker_outputs = [
pickle.loads(chan.begin_read())
for chan in output_channels
]
finally:
# Has to call end_read in order to reuse the DAG.
for chan in output_channels:
chan.end_read()
else:
ray_worker_outputs = ray.get(ray_worker_outputs)
ray_worker_outputs = ray.get(ray_worker_outputs)
return [driver_worker_output] + ray_worker_outputs
return driver_worker_output + ray_worker_outputs
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
ray.get(parallel_worker_tasks)
def _compiled_ray_dag(self):
def _compiled_ray_dag(self, enable_asyncio: bool):
import pkg_resources
required_version = "2.9"
current_version = pkg_resources.get_distribution("ray").version
from packaging import version
required_version = version.parse("2.32")
current_version = version.parse(
pkg_resources.get_distribution("ray").version)
if current_version < required_version:
raise ValueError(f"Ray version {required_version} or greater is "
f"required, but found {current_version}")
@@ -365,23 +396,47 @@ class RayGPUExecutor(DistributedGPUExecutor):
# a dummy value for now. It will be fixed soon.
with InputNode() as input_data:
forward_dag = MultiOutputNode([
worker.execute_model_compiled_dag_remote.
bind( # type: ignore[attr-defined]
worker.execute_model_spmd.bind( # type: ignore[attr-defined]
input_data) for worker in self.workers
])
return forward_dag.experimental_compile()
return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
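For readers new to Ray's accelerated DAG API, the graph built here is a single InputNode fanned out to every worker actor and gathered by a MultiOutputNode, compiled once and reused for every step. A toy, self-contained sketch of the same pattern (EchoWorker is a stand-in, not a vLLM class):

import ray
from ray.dag import InputNode, MultiOutputNode

@ray.remote
class EchoWorker:
    def execute_model_spmd(self, req):
        return f"processed {req!r}"

ray.init()
workers = [EchoWorker.remote() for _ in range(2)]

with InputNode() as input_data:
    dag = MultiOutputNode(
        [w.execute_model_spmd.bind(input_data) for w in workers])
compiled = dag.experimental_compile(enable_asyncio=False)

# Mirrors RayGPUExecutor.execute_model: one execute() per step, one output
# per worker, and only the first worker's result is consumed.
print(ray.get(compiled.execute("step-1"))[0])
compiled.teardown()  # mirrors RayGPUExecutor.__del__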
def __del__(self):
if self.forward_dag is not None:
self.forward_dag.teardown()
import ray
for worker in self.workers:
ray.kill(worker)
class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.driver_exec_method = make_async(self.driver_worker.execute_method)
self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
if not self.use_ray_compiled_dag:
self.driver_exec_method = make_async(
self.driver_worker.execute_method)
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if not self.use_ray_spmd_worker:
return await super().execute_model_async(execute_model_req)
if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)
dag_future = await self.forward_dag.execute_async(execute_model_req)
outputs = await dag_future
return outputs[0]
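The async variant compiles the same graph with enable_asyncio=True and splits the step into two awaits: execute_async returns once the input has been submitted, and awaiting the returned future yields the per-worker outputs. A compact sketch of that calling pattern, again with a toy actor rather than vLLM workers:

import asyncio
import ray
from ray.dag import InputNode, MultiOutputNode

@ray.remote
class EchoWorker:
    def execute_model_spmd(self, req):
        return req.upper()

async def main():
    workers = [EchoWorker.remote() for _ in range(2)]
    with InputNode() as inp:
        dag = MultiOutputNode(
            [w.execute_model_spmd.bind(inp) for w in workers])
    compiled = dag.experimental_compile(enable_asyncio=True)

    fut = await compiled.execute_async("hello")  # submit the input
    outputs = await fut                          # collect per-worker results
    print(outputs[0])
    compiled.teardown()

ray.init()
asyncio.run(main())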
async def _driver_execute_model_async(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
assert not self.use_ray_spmd_worker, (
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
if self.pp_locks is None:
# This locks each pipeline parallel stage so multiple virtual
# engines can't execute on the same stage at the same time
@@ -415,8 +470,17 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
return results[-1]
async def _start_worker_execution_loop(self):
assert not self.use_ray_spmd_worker, (
"worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1")
coros = [
worker.execute_method.remote("start_worker_execution_loop")
for worker in self.non_driver_workers
]
return await asyncio.gather(*coros)
def __del__(self):
if self.forward_dag is not None:
self.forward_dag.teardown()
import ray
for worker in self.workers:
ray.kill(worker)


@@ -1,8 +1,8 @@
import pickle
from typing import List, Optional, Tuple
from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest
from vllm.utils import get_ip, is_hip, is_xpu
from vllm.worker.worker_base import WorkerWrapperBase
@@ -31,16 +31,18 @@ try:
gpu_ids = ray.get_gpu_ids()
return node_id, gpu_ids
def execute_model_compiled_dag_remote(self, ignored):
"""Used only when compiled DAG is enabled."""
def execute_model_spmd(self, execute_model_req: ExecuteModelRequest):
"""Used only when SPMD worker and compiled DAG are both
enabled."""
# TODO(swang): This is needed right now because Ray aDAG executes
# on a background thread, so we need to reset torch's current
# device.
import torch
if not self.compiled_dag_cuda_device_set:
torch.cuda.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True
output = self.worker.execute_model()
output = pickle.dumps(output)
return output
return self.worker._execute_model_spmd(execute_model_req)
ray_import_err = None
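The device reset in execute_model_spmd is needed because the compiled DAG invokes actor methods from a background thread, and torch's current CUDA device is tracked per thread, so a fresh thread starts back on device 0. A minimal illustration of the behavior (assumes at least two visible GPUs):

import threading
import torch

def show_current_device():
    # Each thread starts on device 0 regardless of what other threads set.
    print(threading.current_thread().name, torch.cuda.current_device())

torch.cuda.set_device(1)  # main thread now targets GPU 1
show_current_device()     # MainThread 1

t = threading.Thread(target=show_current_device)
t.start()
t.join()                  # Thread-1 0 -> the new thread must set_device again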


@@ -1,11 +1,11 @@
import asyncio
import os
import pickle
from collections import defaultdict
from itertools import islice, repeat
from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set,
Tuple, Union)
import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
@@ -30,7 +30,7 @@ logger = init_logger(__name__)
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
class RayXPUExecutor(DistributedGPUExecutor):
@@ -72,10 +72,9 @@ class RayXPUExecutor(DistributedGPUExecutor):
# Create the parallel GPU workers.
self._init_workers_ray(placement_group)
# Profile the memory usage and initialize the cache.
self.forward_dag = None
if USE_RAY_COMPILED_DAG:
self.forward_dag = self._compiled_ray_dag()
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
# This is non-None when the execute model loop is running
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
@@ -270,7 +269,6 @@ class RayXPUExecutor(DistributedGPUExecutor):
all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False,
max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False,
**kwargs,
) -> Any:
"""Runs the given method on all workers. Can be used in the following
@@ -293,26 +291,20 @@
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
else islice(all_kwargs, 1, None)
if use_ray_compiled_dag:
# Right now, compiled DAG can only accept a single
# input. TODO(sang): Fix it.
assert self.forward_dag is not None
output_channels = self.forward_dag.execute(1)
else:
# Start the ray workers first.
ray_worker_outputs = [
worker.execute_method.remote(method, *worker_args,
**worker_kwargs)
for (worker, worker_args, worker_kwargs
) in zip(self.workers, all_worker_args, all_worker_kwargs)
]
# Start the ray workers first.
ray_worker_outputs = [
worker.execute_method.remote(method, *worker_args, **worker_kwargs)
for (worker, worker_args, worker_kwargs
) in zip(self.workers, all_worker_args, all_worker_kwargs)
]
if async_run_remote_workers_only:
# Just return futures
return ray_worker_outputs
driver_worker_output = []
driver_args = args if all_args is None else all_args[0]
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
# Start the driver worker after all the ray workers.
if not use_dummy_driver:
driver_worker_output = self.driver_worker.execute_method(
@@ -324,36 +316,28 @@
method, *driver_args, **driver_kwargs))
# Get the results of the ray workers.
if self.workers:
if use_ray_compiled_dag:
try:
ray_worker_outputs = [
pickle.loads(chan.begin_read())
for chan in output_channels
]
finally:
# Has to call end_read in order to reuse the DAG.
for chan in output_channels:
chan.end_read()
else:
ray_worker_outputs = ray.get(ray_worker_outputs)
ray_worker_outputs = ray.get(ray_worker_outputs)
return [driver_worker_output] + ray_worker_outputs
return driver_worker_output + ray_worker_outputs
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
ray.get(parallel_worker_tasks)
def _compiled_ray_dag(self):
def _compiled_ray_dag(self, enable_asyncio: bool):
import pkg_resources
required_version = "2.9"
current_version = pkg_resources.get_distribution("ray").version
from packaging import version
required_version = version.parse("2.32")
current_version = version.parse(
pkg_resources.get_distribution("ray").version)
if current_version < required_version:
raise ValueError(f"Ray version {required_version} or greater is "
f"required, but found {current_version}")
from ray.dag import InputNode, MultiOutputNode
assert self.parallel_config.worker_use_ray
assert self.parallel_config.distributed_executor_backend == "ray"
# Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon.
@@ -363,7 +347,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
bind( # type: ignore[attr-defined]
input_data) for worker in self.workers
])
return forward_dag.experimental_compile()
return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
def check_health(self) -> None:
"""Raises an error if engine is unhealthy."""


@@ -281,6 +281,33 @@ class LocalOrDistributedWorkerBase(WorkerBase):
# list to conform to interface.
return output
def _execute_model_spmd(
self, execute_model_req: ExecuteModelRequest
) -> Optional[List[SamplerOutput]]:
"""
Execute model in Single Program Multiple Data (SPMD) fashion.
All workers take the same request, prepare the input and
execute the model.
"""
assert execute_model_req is not None, (
"_execute_model_spmd() requires each worker to take in an "
"ExecuteModelRequest")
worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req)
model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list))
self.execute_worker(worker_input)
# If there is no input, we don't need to execute the model.
if worker_input.num_seq_groups == 0:
return []
return self.model_runner.execute_model(
model_input, self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None)
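Taken together with the RayWorkerWrapper.execute_model_spmd hook above, the per-step SPMD path on every rank is: the compiled DAG calls the wrapper, which pins the CUDA device for the DAG's background thread and forwards to this method, where each worker derives its own WorkerInput and model input locally from the shared ExecuteModelRequest. A condensed, hypothetical trace of that chain (names follow the diff; types and error handling trimmed):

import torch

def spmd_step(wrapper, execute_model_req):
    # RayWorkerWrapper.execute_model_spmd: re-pin the CUDA device once,
    # since the compiled DAG runs this on a background thread.
    if not wrapper.compiled_dag_cuda_device_set:
        torch.cuda.set_device(wrapper.worker.device)
        wrapper.compiled_dag_cuda_device_set = True

    worker = wrapper.worker
    # LocalOrDistributedWorkerBase._execute_model_spmd: every rank prepares
    # its own inputs from the same request object.
    worker_input = worker.prepare_worker_input(
        execute_model_req=execute_model_req)
    model_input = worker.model_runner.prepare_model_input(
        execute_model_req.seq_group_metadata_list)
    worker.execute_worker(worker_input)  # e.g. KV-cache swap/copy ops
    if worker_input.num_seq_groups == 0:
        return []
    kv_cache = (worker.kv_cache[worker_input.virtual_engine]
                if worker.kv_cache is not None else None)
    return worker.model_runner.execute_model(model_input, kv_cache)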
class WorkerWrapperBase:
"""
@@ -296,7 +323,7 @@ class WorkerWrapperBase:
trust_remote_code: bool = False) -> None:
self.worker_module_name = worker_module_name
self.worker_class_name = worker_class_name
self.worker = None
self.worker: Optional[WorkerBase] = None
if trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
@@ -323,7 +350,9 @@ class WorkerWrapperBase:
mod = importlib.import_module(self.worker_module_name)
worker_class = getattr(mod, self.worker_class_name)
self.worker = worker_class(*args, **kwargs)
assert self.worker is not None
def execute_method(self, method, *args, **kwargs):
try: