[Core] Add MultiprocessingGPUExecutor (#4539)
Co-authored-by: SAHIL SUNEJA <suneja@us.ibm.com>
parent dc72402b57 · commit 676a99982f
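Editor's note: the diff below threads a new "mp" (multiprocessing) option through EngineArgs, ParallelConfig, and the engine factories, alongside the existing Ray path. A minimal usage sketch, assuming the LLM entrypoint forwards extra keyword arguments to EngineArgs as this change implies (model name, prompt, and token count are illustrative only):

    # Sketch only: selects the new multiprocessing backend added in this commit.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="facebook/opt-125m",          # illustrative model
        tensor_parallel_size=2,             # world_size > 1 requires a distributed executor
        distributed_executor_backend="mp",  # new option; "ray" remains the Ray-based path
    )
    outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)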
@@ -34,10 +34,14 @@ steps:
  mirror_hardwares: [amd]
  commands:
  - pytest -v -s distributed/test_pynccl_library.py
  - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s distributed/test_basic_distributed_correctness.py
  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s distributed/test_basic_distributed_correctness.py
  - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s distributed/test_chunked_prefill_distributed.py
  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s distributed/test_chunked_prefill_distributed.py
  - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
  - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
  - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
  - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py

- label: Distributed Tests (Multiple Groups)
  working_dir: "/vllm-workspace/tests"
@@ -20,6 +20,7 @@ import torch
MODELS = [
    os.environ["TEST_DIST_MODEL"],
]
DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND"
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"

@@ -36,19 +37,21 @@ def test_models(
    dtype: str,
    max_tokens: int,
) -> None:
    enforce_eager = False
    distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND)

    backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
    if backend_by_env_var == "FLASHINFER":
        enforce_eager = True
    enforce_eager = backend_by_env_var == "FLASHINFER"

    hf_model = hf_runner(model, dtype=dtype)
    hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
    del hf_model

    vllm_model = vllm_runner(model,
                             dtype=dtype,
                             tensor_parallel_size=2,
                             enforce_eager=enforce_eager)
    vllm_model = vllm_runner(
        model,
        dtype=dtype,
        tensor_parallel_size=2,
        enforce_eager=enforce_eager,
        distributed_executor_backend=distributed_executor_backend)
    vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
    del vllm_model
@@ -19,6 +19,7 @@ import torch
MODELS = [
    os.environ["TEST_DIST_MODEL"],
]
DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND"


@pytest.mark.skipif(torch.cuda.device_count() < 2,
@@ -36,6 +37,8 @@ def test_models(
    max_tokens: int,
    chunked_prefill_token_size: int,
) -> None:
    distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND)

    # Add a chunked prefill config.
    max_num_seqs = min(chunked_prefill_token_size, 256)
    assert chunked_prefill_token_size != -1
@@ -53,6 +56,7 @@ def test_models(
        max_num_seqs=max_num_seqs,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
        distributed_executor_backend=distributed_executor_backend,
    )
    vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
    del vllm_model
@@ -38,8 +38,7 @@ def test_mixtral_lora(mixtral_lora_files, tp_size):
        enable_lora=True,
        max_num_seqs=16,
        max_loras=4,
        tensor_parallel_size=tp_size,
        worker_use_ray=True)
        tensor_parallel_size=tp_size)

    expected_lora_output = [
        "give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])",  # noqa: E501
@@ -521,9 +521,7 @@ class ParallelConfig:
    Args:
        pipeline_parallel_size: Number of pipeline parallel groups.
        tensor_parallel_size: Number of tensor parallel groups.
        worker_use_ray: Whether to use Ray for model workers. Will be set to
            True if either pipeline_parallel_size or tensor_parallel_size is
            greater than 1.
        worker_use_ray: Deprecated, use distributed_executor_backend instead.
        max_parallel_loading_workers: Maximum number of multiple batches
            when load model sequentially. To avoid RAM OOM when using tensor
            parallel and large models.
@@ -533,22 +531,27 @@ class ParallelConfig:
            If None, will use synchronous tokenization.
        ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
            https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
        distributed_executor_backend: Backend to use for distributed model
            workers, either "ray" or "mp" (multiprocessing). If either
            pipeline_parallel_size or tensor_parallel_size is greater than 1,
            will default to "ray" if Ray is installed or "mp" otherwise.
    """

    def __init__(
        self,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        worker_use_ray: bool,
        worker_use_ray: Optional[bool] = None,
        max_parallel_loading_workers: Optional[int] = None,
        disable_custom_all_reduce: bool = False,
        tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
        ray_workers_use_nsight: bool = False,
        placement_group: Optional["PlacementGroup"] = None,
        distributed_executor_backend: Optional[str] = None,
    ) -> None:
        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        self.worker_use_ray = worker_use_ray
        self.distributed_executor_backend = distributed_executor_backend
        self.max_parallel_loading_workers = max_parallel_loading_workers
        self.disable_custom_all_reduce = disable_custom_all_reduce
        self.tokenizer_pool_config = tokenizer_pool_config
@@ -556,14 +559,29 @@ class ParallelConfig:
        self.placement_group = placement_group

        self.world_size = pipeline_parallel_size * self.tensor_parallel_size
        if self.world_size > 1:
            self.worker_use_ray = True
        if worker_use_ray:
            if self.distributed_executor_backend is None:
                self.distributed_executor_backend = "ray"
            elif self.distributed_executor_backend != "ray":
                raise ValueError(f"worker-use-ray can't be used with "
                                 f"distributed executor backend "
                                 f"'{self.distributed_executor_backend}'.")

        if self.distributed_executor_backend is None and self.world_size > 1:
            from vllm.executor import ray_utils
            ray_found = ray_utils.ray is not None
            self.distributed_executor_backend = "ray" if ray_found else "mp"

        self._verify_args()

    def _verify_args(self) -> None:
        if self.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is not supported yet.")
        if self.distributed_executor_backend not in ("ray", "mp", None):
            raise ValueError(
                "Unrecognized distributed executor backend. Supported values "
                "are 'ray' or 'mp'.")
        if not self.disable_custom_all_reduce and self.world_size > 1:
            if is_hip():
                self.disable_custom_all_reduce = True
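Editor's note: in plain terms, the resolution order introduced above is: an explicit distributed_executor_backend wins; the deprecated worker_use_ray flag maps to "ray" and conflicts with any other explicit backend; otherwise, for a multi-GPU world size the backend defaults to "ray" when Ray is importable and "mp" when it is not. A standalone sketch of that rule (the helper name is hypothetical, chosen here only for illustration):

    from typing import Optional

    def resolve_backend(world_size: int,
                        worker_use_ray: Optional[bool] = None,
                        backend: Optional[str] = None,
                        ray_installed: bool = True) -> Optional[str]:
        # Hypothetical helper mirroring ParallelConfig.__init__ above; not part of vLLM.
        if worker_use_ray:
            if backend is None:
                backend = "ray"
            elif backend != "ray":
                raise ValueError("worker-use-ray can't be used with "
                                 f"distributed executor backend '{backend}'.")
        if backend is None and world_size > 1:
            backend = "ray" if ray_installed else "mp"
        return backend

    assert resolve_backend(world_size=2, ray_installed=False) == "mp"
    assert resolve_backend(world_size=2, worker_use_ray=True) == "ray"
    assert resolve_backend(world_size=1) is None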
@@ -575,7 +593,8 @@ class ParallelConfig:
            logger.info(
                "Disabled the custom all-reduce kernel because it is not "
                "supported with pipeline parallelism.")
        if self.ray_workers_use_nsight and not self.worker_use_ray:
        if self.ray_workers_use_nsight and (
                not self.distributed_executor_backend == "ray"):
            raise ValueError("Unable to use nsight profiling unless workers "
                             "run with Ray.")

@@ -887,7 +906,8 @@ class SpeculativeConfig:
            pipeline_parallel_size=target_parallel_config.
            pipeline_parallel_size,
            tensor_parallel_size=target_parallel_config.tensor_parallel_size,
            worker_use_ray=target_parallel_config.worker_use_ray,
            distributed_executor_backend=target_parallel_config.
            distributed_executor_backend,
            max_parallel_loading_workers=target_parallel_config.
            max_parallel_loading_workers,
            disable_custom_all_reduce=target_parallel_config.
@@ -34,6 +34,7 @@ class EngineArgs:
    seed: int = 0
    max_model_len: Optional[int] = None
    worker_use_ray: bool = False
    distributed_executor_backend: Optional[str] = None
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    max_parallel_loading_workers: Optional[int] = None
@@ -221,10 +222,17 @@ class EngineArgs:
            ' Can be overridden per request via guided_decoding_backend'
            ' parameter.')
        # Parallel arguments
        parser.add_argument('--worker-use-ray',
                            action='store_true',
                            help='Use Ray for distributed serving, will be '
                            'automatically set when using more than 1 GPU.')
        parser.add_argument(
            '--distributed-executor-backend',
            choices=['ray', 'mp'],
            default=EngineArgs.distributed_executor_backend,
            help='Backend to use for distributed serving. When more than 1 GPU '
            'is used, will be automatically set to "ray" if installed '
            'or "mp" (multiprocessing) otherwise.')
        parser.add_argument(
            '--worker-use-ray',
            action='store_true',
            help='Deprecated, use --distributed-executor-backend=ray.')
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
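Editor's note: with the flag above in place, the new option can be exercised through the EngineArgs argparse plumbing shown in this hunk. A small sketch, assuming the existing add_cli_args/from_cli_args helpers (model name and values are illustrative):

    # Sketch: parse the new --distributed-executor-backend flag into EngineArgs.
    import argparse
    from vllm.engine.arg_utils import EngineArgs

    parser = EngineArgs.add_cli_args(argparse.ArgumentParser())
    args = parser.parse_args([
        "--model", "facebook/opt-125m",
        "--tensor-parallel-size", "2",
        "--distributed-executor-backend", "mp",
    ])
    engine_args = EngineArgs.from_cli_args(args)
    print(engine_args.distributed_executor_backend)  # expected: "mp"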
@@ -348,27 +348,31 @@ class AsyncLLMEngine:
        """Creates an async LLM engine from the engine arguments."""
        # Create the engine configs.
        engine_config = engine_args.create_engine_config()
        distributed_executor_backend = (
            engine_config.parallel_config.distributed_executor_backend)

        if engine_config.device_config.device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutorAsync
            executor_class = NeuronExecutorAsync
        elif engine_config.device_config.device_type == "cpu":
            assert not engine_config.parallel_config.worker_use_ray, (
                "Ray is not supported with the CPU backend.")
            assert distributed_executor_backend is None, (
                "Distributed execution is not supported with the CPU backend.")
            from vllm.executor.cpu_executor import CPUExecutorAsync
            executor_class = CPUExecutorAsync
        elif engine_config.parallel_config.worker_use_ray:
        elif distributed_executor_backend == "ray":
            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            executor_class = RayGPUExecutorAsync
        elif distributed_executor_backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutorAsync)
            executor_class = MultiprocessingGPUExecutorAsync
        else:
            assert engine_config.parallel_config.world_size == 1, (
                "Ray is required if parallel_config.world_size > 1.")
            from vllm.executor.gpu_executor import GPUExecutorAsync
            executor_class = GPUExecutorAsync
        # Create the async LLM engine.
        engine = cls(
            engine_config.parallel_config.worker_use_ray,
            distributed_executor_backend == "ray",
            engine_args.engine_use_ray,
            **engine_config.to_dict(),
            executor_class=executor_class,
@@ -274,6 +274,8 @@ class LLMEngine:
        """Creates an LLM engine from the engine arguments."""
        # Create the engine configs.
        engine_config = engine_args.create_engine_config()
        distributed_executor_backend = (
            engine_config.parallel_config.distributed_executor_backend)

        # Initialize the cluster and specify the executor class.
        if engine_config.device_config.device_type == "neuron":
@@ -282,13 +284,15 @@ class LLMEngine:
        elif engine_config.device_config.device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutor
            executor_class = CPUExecutor
        elif engine_config.parallel_config.worker_use_ray:
        elif distributed_executor_backend == "ray":
            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutor
            executor_class = RayGPUExecutor
        elif distributed_executor_backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutor)
            executor_class = MultiprocessingGPUExecutor
        else:
            assert engine_config.parallel_config.world_size == 1, (
                "Ray is required if parallel_config.world_size > 1.")
            from vllm.executor.gpu_executor import GPUExecutor
            executor_class = GPUExecutor
vllm/executor/multiproc_gpu_executor.py (new file, 140 lines)
@@ -0,0 +1,140 @@
import asyncio
import os
from functools import partial
from typing import Any, Dict, Optional, Tuple

from vllm.executor.distributed_gpu_executor import (  # yapf: disable
    DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
                                                  ResultHandler, WorkerMonitor)
from vllm.logger import init_logger
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                        get_vllm_instance_id, make_async)

logger = init_logger(__name__)


class MultiprocessingGPUExecutor(DistributedGPUExecutor):
    """Python multiprocessing-based multi-GPU executor"""

    def _init_executor(self) -> None:
        assert (
            not self.speculative_config
        ), "Speculative decoding not yet supported for MultiProcGPU backend."

        # Create the parallel GPU workers.
        world_size = self.parallel_config.tensor_parallel_size

        # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
        if "CUDA_VISIBLE_DEVICES" not in os.environ:
            os.environ["CUDA_VISIBLE_DEVICES"] = (",".join(
                map(str, range(world_size))))

        # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
        os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()

        from torch.cuda import device_count
        assert world_size <= device_count(), (
            "please set tensor_parallel_size to less than max local gpu count")

        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())

        if world_size == 1:
            self.workers = []
        else:
            result_handler = ResultHandler()
            self.workers = [
                ProcessWorkerWrapper(
                    result_handler,
                    partial(
                        self._create_worker,
                        rank=rank,
                        local_rank=rank,
                        distributed_init_method=distributed_init_method,
                    )) for rank in range(1, world_size)
            ]

            self.worker_monitor = WorkerMonitor(self.workers, result_handler)
            result_handler.start()
            self.worker_monitor.start()

        self.driver_worker = self._create_worker(
            distributed_init_method=distributed_init_method)
        self._run_workers("init_device")
        self._run_workers("load_model",
                          max_concurrent_workers=self.parallel_config.
                          max_parallel_loading_workers)

    def shutdown(self):
        if (worker_monitor := getattr(self, "worker_monitor",
                                      None)) is not None:
            worker_monitor.close()

    def _run_workers(
        self,
        method: str,
        *args,
        driver_args: Optional[Tuple[Any, ...]] = None,
        driver_kwargs: Optional[Dict[str, Any]] = None,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers."""

        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        # Start the workers first.
        worker_outputs = [
            worker.execute_method(method, *args, **kwargs)
            for worker in self.workers
        ]

        if driver_args is None:
            driver_args = args
        if driver_kwargs is None:
            driver_kwargs = kwargs

        # Start the driver worker after all the ray workers.
        driver_worker_method = getattr(self.driver_worker, method)
        driver_worker_output = driver_worker_method(*driver_args,
                                                    **driver_kwargs)

        # Get the results of the workers.
        return [driver_worker_output
                ] + [output.get() for output in worker_outputs]

    def check_health(self) -> None:
        """Raises an error if engine is unhealthy."""
        if not self.worker_monitor.is_alive():
            raise RuntimeError("Worker processes are not running")


class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
                                      DistributedGPUExecutorAsync):

    async def _run_workers_async(
        self,
        method: str,
        *args,
        driver_args: Optional[Tuple[Any, ...]] = None,
        driver_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers."""
        if driver_args is None:
            driver_args = args
        if driver_kwargs is None:
            driver_kwargs = kwargs

        driver_executor = make_async(getattr(self.driver_worker, method))

        # Run all the workers asynchronously.
        coros = [driver_executor(*driver_args, **driver_kwargs)] + [
            worker.execute_method_async(method, *args, **kwargs)
            for worker in self.workers
        ]

        return await asyncio.gather(*coros)
@@ -31,7 +31,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
        assert (not self.speculative_config
                ), "Speculative decoding not yet supported for RayGPU backend."

        assert self.parallel_config.worker_use_ray
        assert self.parallel_config.distributed_executor_backend == "ray"
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection.
@@ -264,7 +264,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
            f"required, but found {current_version}")

        from ray.dag import InputNode, MultiOutputNode
        assert self.parallel_config.worker_use_ray
        assert self.parallel_config.distributed_executor_backend == "ray"

        # Right now, compiled DAG requires at least 1 arg. We send
        # a dummy value for now. It will be fixed soon.
@@ -44,7 +44,7 @@ try:

except ImportError as e:
    logger.warning(
        "Failed to import Ray with %r. For distributed inference, "
        "Failed to import Ray with %r. For multi-node inference, "
        "please install Ray with `pip install ray`.", e)
    ray = None  # type: ignore
    RayWorkerWrapper = None  # type: ignore
@@ -67,7 +67,7 @@ def initialize_ray_cluster(
    """
    if ray is None:
        raise ImportError(
            "Ray is not installed. Please install Ray to use distributed "
            "Ray is not installed. Please install Ray to use multi-node "
            "serving.")

    # Connect to a ray cluster.