[Core] Add MultiprocessingGPUExecutor (#4539)

Co-authored-by: SAHIL SUNEJA <suneja@us.ibm.com>
This commit is contained in:
Nick Hill 2024-05-14 10:38:59 -07:00 committed by GitHub
parent dc72402b57
commit 676a99982f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 225 additions and 39 deletions

View File

@ -34,10 +34,14 @@ steps:
mirror_hardwares: [amd]
commands:
- pytest -v -s distributed/test_pynccl_library.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- label: Distributed Tests (Multiple Groups)
working_dir: "/vllm-workspace/tests"

View File

@ -20,6 +20,7 @@ import torch
MODELS = [
os.environ["TEST_DIST_MODEL"],
]
DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND"
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
@ -36,19 +37,21 @@ def test_models(
dtype: str,
max_tokens: int,
) -> None:
enforce_eager = False
distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND)
backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
if backend_by_env_var == "FLASHINFER":
enforce_eager = True
enforce_eager = backend_by_env_var == "FLASHINFER"
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model
vllm_model = vllm_runner(model,
dtype=dtype,
tensor_parallel_size=2,
enforce_eager=enforce_eager)
vllm_model = vllm_runner(
model,
dtype=dtype,
tensor_parallel_size=2,
enforce_eager=enforce_eager,
distributed_executor_backend=distributed_executor_backend)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model

View File

@ -19,6 +19,7 @@ import torch
MODELS = [
os.environ["TEST_DIST_MODEL"],
]
DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND"
@pytest.mark.skipif(torch.cuda.device_count() < 2,
@ -36,6 +37,8 @@ def test_models(
max_tokens: int,
chunked_prefill_token_size: int,
) -> None:
distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND)
# Add a chunked prefill config.
max_num_seqs = min(chunked_prefill_token_size, 256)
assert chunked_prefill_token_size != -1
@ -53,6 +56,7 @@ def test_models(
max_num_seqs=max_num_seqs,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
distributed_executor_backend=distributed_executor_backend,
)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model

View File

@ -38,8 +38,7 @@ def test_mixtral_lora(mixtral_lora_files, tp_size):
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=tp_size,
worker_use_ray=True)
tensor_parallel_size=tp_size)
expected_lora_output = [
"give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])", # noqa: E501

View File

@ -521,9 +521,7 @@ class ParallelConfig:
Args:
pipeline_parallel_size: Number of pipeline parallel groups.
tensor_parallel_size: Number of tensor parallel groups.
worker_use_ray: Whether to use Ray for model workers. Will be set to
True if either pipeline_parallel_size or tensor_parallel_size is
greater than 1.
worker_use_ray: Deprecated, use distributed_executor_backend instead.
max_parallel_loading_workers: Maximum number of multiple batches
when load model sequentially. To avoid RAM OOM when using tensor
parallel and large models.
@ -533,22 +531,27 @@ class ParallelConfig:
If None, will use synchronous tokenization.
ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
distributed_executor_backend: Backend to use for distributed model
workers, either "ray" or "mp" (multiprocessing). If either
pipeline_parallel_size or tensor_parallel_size is greater than 1,
will default to "ray" if Ray is installed or "mp" otherwise.
"""
def __init__(
self,
pipeline_parallel_size: int,
tensor_parallel_size: int,
worker_use_ray: bool,
worker_use_ray: Optional[bool] = None,
max_parallel_loading_workers: Optional[int] = None,
disable_custom_all_reduce: bool = False,
tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
ray_workers_use_nsight: bool = False,
placement_group: Optional["PlacementGroup"] = None,
distributed_executor_backend: Optional[str] = None,
) -> None:
self.pipeline_parallel_size = pipeline_parallel_size
self.tensor_parallel_size = tensor_parallel_size
self.worker_use_ray = worker_use_ray
self.distributed_executor_backend = distributed_executor_backend
self.max_parallel_loading_workers = max_parallel_loading_workers
self.disable_custom_all_reduce = disable_custom_all_reduce
self.tokenizer_pool_config = tokenizer_pool_config
@ -556,14 +559,29 @@ class ParallelConfig:
self.placement_group = placement_group
self.world_size = pipeline_parallel_size * self.tensor_parallel_size
if self.world_size > 1:
self.worker_use_ray = True
if worker_use_ray:
if self.distributed_executor_backend is None:
self.distributed_executor_backend = "ray"
elif self.distributed_executor_backend != "ray":
raise ValueError(f"worker-use-ray can't be used with "
f"distributed executor backend "
f"'{self.distributed_executor_backend}'.")
if self.distributed_executor_backend is None and self.world_size > 1:
from vllm.executor import ray_utils
ray_found = ray_utils.ray is not None
self.distributed_executor_backend = "ray" if ray_found else "mp"
self._verify_args()
def _verify_args(self) -> None:
if self.pipeline_parallel_size > 1:
raise NotImplementedError(
"Pipeline parallelism is not supported yet.")
if self.distributed_executor_backend not in ("ray", "mp", None):
raise ValueError(
"Unrecognized distributed executor backend. Supported values "
"are 'ray' or 'mp'.")
if not self.disable_custom_all_reduce and self.world_size > 1:
if is_hip():
self.disable_custom_all_reduce = True
@ -575,7 +593,8 @@ class ParallelConfig:
logger.info(
"Disabled the custom all-reduce kernel because it is not "
"supported with pipeline parallelism.")
if self.ray_workers_use_nsight and not self.worker_use_ray:
if self.ray_workers_use_nsight and (
not self.distributed_executor_backend == "ray"):
raise ValueError("Unable to use nsight profiling unless workers "
"run with Ray.")
@ -887,7 +906,8 @@ class SpeculativeConfig:
pipeline_parallel_size=target_parallel_config.
pipeline_parallel_size,
tensor_parallel_size=target_parallel_config.tensor_parallel_size,
worker_use_ray=target_parallel_config.worker_use_ray,
distributed_executor_backend=target_parallel_config.
distributed_executor_backend,
max_parallel_loading_workers=target_parallel_config.
max_parallel_loading_workers,
disable_custom_all_reduce=target_parallel_config.

View File

@ -34,6 +34,7 @@ class EngineArgs:
seed: int = 0
max_model_len: Optional[int] = None
worker_use_ray: bool = False
distributed_executor_backend: Optional[str] = None
pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1
max_parallel_loading_workers: Optional[int] = None
@ -221,10 +222,17 @@ class EngineArgs:
' Can be overridden per request via guided_decoding_backend'
' parameter.')
# Parallel arguments
parser.add_argument('--worker-use-ray',
action='store_true',
help='Use Ray for distributed serving, will be '
'automatically set when using more than 1 GPU.')
parser.add_argument(
'--distributed-executor-backend',
choices=['ray', 'mp'],
default=EngineArgs.distributed_executor_backend,
help='Backend to use for distributed serving. When more than 1 GPU '
'is used, will be automatically set to "ray" if installed '
'or "mp" (multiprocessing) otherwise.')
parser.add_argument(
'--worker-use-ray',
action='store_true',
help='Deprecated, use --distributed-executor-backend=ray.')
parser.add_argument('--pipeline-parallel-size',
'-pp',
type=int,

View File

@ -348,27 +348,31 @@ class AsyncLLMEngine:
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)
if engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync
elif engine_config.device_config.device_type == "cpu":
assert not engine_config.parallel_config.worker_use_ray, (
"Ray is not supported with the CPU backend.")
assert distributed_executor_backend is None, (
"Distributed execution is not supported with the CPU backend.")
from vllm.executor.cpu_executor import CPUExecutorAsync
executor_class = CPUExecutorAsync
elif engine_config.parallel_config.worker_use_ray:
elif distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
executor_class = RayGPUExecutorAsync
elif distributed_executor_backend == "mp":
from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutorAsync)
executor_class = MultiprocessingGPUExecutorAsync
else:
assert engine_config.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.")
from vllm.executor.gpu_executor import GPUExecutorAsync
executor_class = GPUExecutorAsync
# Create the async LLM engine.
engine = cls(
engine_config.parallel_config.worker_use_ray,
distributed_executor_backend == "ray",
engine_args.engine_use_ray,
**engine_config.to_dict(),
executor_class=executor_class,

View File

@ -274,6 +274,8 @@ class LLMEngine:
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)
# Initialize the cluster and specify the executor class.
if engine_config.device_config.device_type == "neuron":
@ -282,13 +284,15 @@ class LLMEngine:
elif engine_config.device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutor
executor_class = CPUExecutor
elif engine_config.parallel_config.worker_use_ray:
elif distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutor
executor_class = RayGPUExecutor
elif distributed_executor_backend == "mp":
from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutor)
executor_class = MultiprocessingGPUExecutor
else:
assert engine_config.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.")
from vllm.executor.gpu_executor import GPUExecutor
executor_class = GPUExecutor

View File

@ -0,0 +1,140 @@
import asyncio
import os
from functools import partial
from typing import Any, Dict, Optional, Tuple
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor)
from vllm.logger import init_logger
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
get_vllm_instance_id, make_async)
logger = init_logger(__name__)
class MultiprocessingGPUExecutor(DistributedGPUExecutor):
"""Python multiprocessing-based multi-GPU executor"""
def _init_executor(self) -> None:
assert (
not self.speculative_config
), "Speculative decoding not yet supported for MultiProcGPU backend."
# Create the parallel GPU workers.
world_size = self.parallel_config.tensor_parallel_size
# Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
if "CUDA_VISIBLE_DEVICES" not in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = (",".join(
map(str, range(world_size))))
# Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()
from torch.cuda import device_count
assert world_size <= device_count(), (
"please set tensor_parallel_size to less than max local gpu count")
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
if world_size == 1:
self.workers = []
else:
result_handler = ResultHandler()
self.workers = [
ProcessWorkerWrapper(
result_handler,
partial(
self._create_worker,
rank=rank,
local_rank=rank,
distributed_init_method=distributed_init_method,
)) for rank in range(1, world_size)
]
self.worker_monitor = WorkerMonitor(self.workers, result_handler)
result_handler.start()
self.worker_monitor.start()
self.driver_worker = self._create_worker(
distributed_init_method=distributed_init_method)
self._run_workers("init_device")
self._run_workers("load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)
def shutdown(self):
if (worker_monitor := getattr(self, "worker_monitor",
None)) is not None:
worker_monitor.close()
def _run_workers(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
if max_concurrent_workers:
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")
# Start the workers first.
worker_outputs = [
worker.execute_method(method, *args, **kwargs)
for worker in self.workers
]
if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs
# Start the driver worker after all the ray workers.
driver_worker_method = getattr(self.driver_worker, method)
driver_worker_output = driver_worker_method(*driver_args,
**driver_kwargs)
# Get the results of the workers.
return [driver_worker_output
] + [output.get() for output in worker_outputs]
def check_health(self) -> None:
"""Raises an error if engine is unhealthy."""
if not self.worker_monitor.is_alive():
raise RuntimeError("Worker processes are not running")
class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
DistributedGPUExecutorAsync):
async def _run_workers_async(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs
driver_executor = make_async(getattr(self.driver_worker, method))
# Run all the workers asynchronously.
coros = [driver_executor(*driver_args, **driver_kwargs)] + [
worker.execute_method_async(method, *args, **kwargs)
for worker in self.workers
]
return await asyncio.gather(*coros)

View File

@ -31,7 +31,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
assert (not self.speculative_config
), "Speculative decoding not yet supported for RayGPU backend."
assert self.parallel_config.worker_use_ray
assert self.parallel_config.distributed_executor_backend == "ray"
placement_group = self.parallel_config.placement_group
# Disable Ray usage stats collection.
@ -264,7 +264,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
f"required, but found {current_version}")
from ray.dag import InputNode, MultiOutputNode
assert self.parallel_config.worker_use_ray
assert self.parallel_config.distributed_executor_backend == "ray"
# Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon.

View File

@ -44,7 +44,7 @@ try:
except ImportError as e:
logger.warning(
"Failed to import Ray with %r. For distributed inference, "
"Failed to import Ray with %r. For multi-node inference, "
"please install Ray with `pip install ray`.", e)
ray = None # type: ignore
RayWorkerWrapper = None # type: ignore
@ -67,7 +67,7 @@ def initialize_ray_cluster(
"""
if ray is None:
raise ImportError(
"Ray is not installed. Please install Ray to use distributed "
"Ray is not installed. Please install Ray to use multi-node "
"serving.")
# Connect to a ray cluster.