Ray placement group support (#397)

Antoni Baum 2023-07-19 22:49:31 -07:00 committed by GitHub
parent 8c4b2592fb
commit 9925c17940
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 187 additions and 116 deletions

requirements.txt

@@ -1,6 +1,6 @@
 ninja # For faster builds.
 psutil
-ray
+ray >= 2.5.1
 sentencepiece # Required for LLaMA tokenizer.
 numpy
 torch >= 2.0.0

vllm/engine/async_llm_engine.py

@@ -226,14 +226,14 @@ class AsyncLLMEngine:
         engine_configs = engine_args.create_engine_configs()
         parallel_config = engine_configs[2]
         # Initialize the cluster.
-        distributed_init_method, devices = initialize_cluster(
+        distributed_init_method, placement_group = initialize_cluster(
             parallel_config, engine_args.engine_use_ray)
         # Create the async LLM engine.
         engine = cls(engine_args.worker_use_ray,
                      engine_args.engine_use_ray,
                      *engine_configs,
                      distributed_init_method,
-                     devices,
+                     placement_group,
                      log_requests=not engine_args.disable_log_requests,
                      log_stats=not engine_args.disable_log_stats)
         return engine
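
Note: initialize_cluster now hands back an optional Ray placement group instead of a per-stage device list, and the async entry point simply threads it into the engine constructor. A minimal usage sketch, assuming the public AsyncEngineArgs / from_engine_args entry point; the model name and argument values are illustrative, and the placement-group plumbing stays internal:

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

# Single-GPU case: initialize_cluster returns (tcp://localhost:<port>, None).
engine_args = AsyncEngineArgs(model="facebook/opt-125m",
                              tensor_parallel_size=1,
                              worker_use_ray=False)
# from_engine_args calls initialize_cluster(...) and forwards both the
# distributed init method and the placement group to the constructor.
engine = AsyncLLMEngine.from_engine_args(engine_args)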

vllm/engine/llm_engine.py

@@ -1,11 +1,12 @@
 import time
-from typing import Any, List, Optional
+from functools import partial
+from typing import Any, List, Optional, TYPE_CHECKING

 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
 from vllm.core.scheduler import Scheduler
 from vllm.engine.arg_utils import EngineArgs
-from vllm.engine.ray_utils import DeviceID, initialize_cluster, ray
+from vllm.engine.ray_utils import initialize_cluster, ray, RayWorker
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
@@ -13,7 +14,13 @@ from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                get_tokenizer)
 from vllm.utils import Counter
-from vllm.worker.worker import Worker
+
+if ray:
+    from ray.air.util.torch_dist import init_torch_dist_process_group
+    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+
+if TYPE_CHECKING:
+    from ray.util.placement_group import PlacementGroup

 logger = init_logger(__name__)
@@ -54,7 +61,7 @@ class LLMEngine:
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
         distributed_init_method: str,
-        stage_devices: List[List[DeviceID]],
+        placement_group: Optional["PlacementGroup"],
         log_stats: bool,
     ) -> None:
         logger.info(
@@ -85,31 +92,73 @@ class LLMEngine:
         self.seq_counter = Counter()

         # Create the parallel GPU workers.
-        self.workers: List[Worker] = []
-        assert len(stage_devices) == 1, "Only support one stage for now."
-        for rank, node_resource, _ in stage_devices[0]:
-            worker_cls = Worker
-            if self.parallel_config.worker_use_ray:
-                worker_cls = ray.remote(
-                    num_cpus=0,
-                    num_gpus=1,
-                    resources={node_resource: 1e-3},
-                )(worker_cls).remote
-            worker = worker_cls(
-                model_config,
-                parallel_config,
-                scheduler_config,
-                rank,
-                distributed_init_method,
-            )
-            self.workers.append(worker)
+        if self.parallel_config.worker_use_ray:
+            self._init_workers_ray(placement_group)
+        else:
+            self._init_workers(distributed_init_method)

         # Profile the memory usage and initialize the cache.
         self._init_cache()

         # Create the scheduler.
         self.scheduler = Scheduler(scheduler_config, cache_config, log_stats)

+    def _init_workers(self, distributed_init_method: str):
+        # Lazy import the Worker to avoid importing torch.cuda/xformers
+        # before CUDA_VISIBLE_DEVICES is set in the Worker
+        from vllm.worker.worker import Worker  # pylint: disable=import-outside-toplevel
+
+        assert self.parallel_config.world_size == 1, (
+            "Ray is required if parallel_config.world_size > 1.")
+
+        self.workers: List[Worker] = []
+        worker = Worker(
+            self.model_config,
+            self.parallel_config,
+            self.scheduler_config,
+            0,
+            distributed_init_method,
+        )
+        self.workers.append(worker)
+        self._run_workers(
+            "init_model",
+            get_all_outputs=True,
+        )
+
+    def _init_workers_ray(self, placement_group: "PlacementGroup"):
+        # Lazy import the Worker to avoid importing torch.cuda/xformers
+        # before CUDA_VISIBLE_DEVICES is set in the Worker
+        from vllm.worker.worker import Worker  # pylint: disable=import-outside-toplevel
+
+        self.workers: List[Worker] = []
+        for bundle in placement_group.bundle_specs:
+            if not bundle.get("GPU", 0):
+                continue
+            worker = ray.remote(
+                num_cpus=0,
+                num_gpus=1,
+                scheduling_strategy=PlacementGroupSchedulingStrategy(
+                    placement_group=placement_group,
+                    placement_group_capture_child_tasks=True),
+            )(RayWorker).remote()
+            self.workers.append(worker)
+
+        # Initialize torch distributed process group for the workers.
+        init_torch_dist_process_group(self.workers, backend="nccl")
+        self._run_workers("init_worker",
+                          get_all_outputs=True,
+                          worker_init_fn=lambda: Worker(
+                              self.model_config,
+                              self.parallel_config,
+                              self.scheduler_config,
+                              None,
+                              None,
+                          ))
+        self._run_workers(
+            "init_model",
+            get_all_outputs=True,
+        )
+
     def _verify_args(self) -> None:
         self.model_config.verify_with_parallel_config(self.parallel_config)
         self.cache_config.verify_with_parallel_config(self.parallel_config)
@@ -152,11 +201,12 @@ class LLMEngine:
         engine_configs = engine_args.create_engine_configs()
         parallel_config = engine_configs[2]
         # Initialize the cluster.
-        distributed_init_method, devices = initialize_cluster(parallel_config)
+        distributed_init_method, placement_group = initialize_cluster(
+            parallel_config)
         # Create the LLM engine.
         engine = cls(*engine_configs,
                      distributed_init_method,
-                     devices,
+                     placement_group,
                      log_stats=not engine_args.disable_log_stats)
         return engine
@@ -326,9 +376,10 @@ class LLMEngine:
         """Runs the given method on all workers."""
         all_outputs = []
         for worker in self.workers:
-            executor = getattr(worker, method)
             if self.parallel_config.worker_use_ray:
-                executor = executor.remote
+                executor = partial(worker.execute_method.remote, method)
+            else:
+                executor = getattr(worker, method)

             output = executor(*args, **kwargs)
             all_outputs.append(output)
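
The new _init_workers_ray path creates one RayWorker actor per GPU bundle and pins each actor to the placement group with PlacementGroupSchedulingStrategy, so Ray chooses the node and sets CUDA_VISIBLE_DEVICES before the heavyweight Worker is built. A standalone sketch of that scheduling pattern, assuming a local Ray cluster with two GPUs; EchoWorker is a stand-in actor, not vLLM's RayWorker:

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

ray.init()

# One 1-GPU bundle per worker, mirroring what initialize_cluster requests.
pg = placement_group([{"GPU": 1}] * 2)
ray.get(pg.ready(), timeout=1800)

@ray.remote
class EchoWorker:
    # Same dispatch trick as RayWorker.execute_method: call a method by name.
    def execute_method(self, method, *args, **kwargs):
        return getattr(self, method)(*args, **kwargs)

    def ping(self):
        return "pong"

workers = [
    EchoWorker.options(
        num_cpus=0,
        num_gpus=1,
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg,
            placement_group_capture_child_tasks=True),
    ).remote()
    for bundle in pg.bundle_specs if bundle.get("GPU", 0)
]
print(ray.get([w.execute_method.remote("ping") for w in workers]))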

vllm/engine/ray_utils.py

@@ -1,15 +1,35 @@
 import socket
-from typing import List, Optional, Tuple
-
-try:
-    import ray
-except ImportError:
-    ray = None
+from typing import Optional, Tuple, TYPE_CHECKING

 from vllm.config import ParallelConfig

-# rank, node resource (node IP), device id
-DeviceID = Tuple[int, Optional[str], int]
+try:
+    import ray
+    from ray.air.util.torch_dist import TorchDistributedWorker
+
+    class RayWorker(TorchDistributedWorker):
+        """Ray wrapper for vllm.worker.Worker, allowing Worker to be
+        lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""
+
+        def __init__(self) -> None:
+            self.worker = None
+
+        def init_worker(self, worker_init_fn):
+            self.worker = worker_init_fn()
+
+        def __getattr__(self, name):
+            return getattr(self.worker, name)
+
+        def execute_method(self, method, *args, **kwargs):
+            executor = getattr(self, method)
+            return executor(*args, **kwargs)
+
+except ImportError:
+    ray = None
+    TorchDistributedWorker = None
+
+if TYPE_CHECKING:
+    from ray.util.placement_group import PlacementGroup


 def get_open_port():
@@ -22,7 +42,7 @@ def initialize_cluster(
     parallel_config: ParallelConfig,
     engine_use_ray: bool = False,
     ray_address: Optional[str] = None,
-) -> Tuple[str, List[List[DeviceID]]]:
+) -> Tuple[str, Optional["PlacementGroup"]]:
     """Initialize the distributed cluster probably with Ray.

     Args:
@@ -52,63 +72,36 @@ def initialize_cluster(
         # We need to setup the distributed init method to make sure
         # the distributed megatron code (e.g., get world size) works correctly.
         distributed_init_method = f"tcp://localhost:{port}"
-        all_stage_devices = [[(0, None, 0)]]
-        return distributed_init_method, all_stage_devices
-
-    # Assume we have a uniform cluster that each node has the same number of
-    # GPUs for now.
-    valid_node_resources = []
-    num_devices_per_node = None
-    for node in ray.nodes():
-        if (not node["Alive"]) or node["Resources"]["GPU"] <= 0:
-            continue
-        if num_devices_per_node is None:
-            num_devices_per_node = node["Resources"]["GPU"]
-        else:
-            assert num_devices_per_node == node["Resources"]["GPU"], (
-                "The number of GPUs per node is not uniform.")
-        for key in node["Resources"]:
-            if key.startswith("node:"):
-                valid_node_resources.append(key)
-
-    # Verify the parallel config.
-    num_nodes = len(valid_node_resources)
-    if parallel_config.world_size > num_nodes * num_devices_per_node:
-        raise ValueError(
-            "The number of required GPUs exceeds the total number of "
-            "available GPUs.")
-    if parallel_config.tensor_parallel_size >= num_devices_per_node:
-        if parallel_config.tensor_parallel_size % num_devices_per_node != 0:
-            raise ValueError(
-                "The number of tensor parallelism is not divisible by the "
-                "number of GPUs per node.")
-    else:
-        if num_devices_per_node % parallel_config.tensor_parallel_size != 0:
-            raise ValueError(
-                "The number of GPUs per node is not divisible by the number "
-                "of tensor parallelism.")
-
-    # Assign GPUs to pipeline stages.
-    rank = 0
-    current_node_id = 0
-    current_device_id = 0
-    distributed_init_method = None
-    all_stage_devices = []
-
-    for _ in range(parallel_config.pipeline_parallel_size):
-        stage_devices = []
-        for _ in range(parallel_config.tensor_parallel_size):
-            node_resource = valid_node_resources[current_node_id]
-            stage_devices.append((rank, node_resource, current_device_id))
-            if distributed_init_method is None:
-                ip = node_resource.split("node:")[-1]
-                port = get_open_port()
-                distributed_init_method = f"tcp://{ip}:{port}"
-            rank += 1
-            current_device_id += 1
-            if current_device_id >= num_devices_per_node:
-                current_node_id += 1
-                current_device_id = 0
-        all_stage_devices.append(stage_devices)
-
-    return distributed_init_method, all_stage_devices
+        return distributed_init_method, None
+
+    current_placement_group = ray.util.get_current_placement_group()
+    if current_placement_group:
+        # We are in a placement group
+        bundles = current_placement_group.bundle_specs
+        # Verify that we can use the placement group.
+        gpu_bundles = 0
+        for bundle in bundles:
+            assert bundle.get("GPU", 0) > 1, (
+                "Placement group bundles cannot have more than 1 GPU")
+            if bundle.get("GPU", 0):
+                gpu_bundles += 1
+        if parallel_config.world_size > gpu_bundles:
+            raise ValueError(
+                "The number of required GPUs exceeds the total number of "
+                "available GPUs in the placement group.")
+    else:
+        num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
+        if parallel_config.world_size > num_gpus_in_cluster:
+            raise ValueError(
+                "The number of required GPUs exceeds the total number of "
+                "available GPUs in the cluster.")
+        # Create a new placement group
+        current_placement_group = ray.util.placement_group([{
+            "GPU": 1
+        }] * parallel_config.world_size)
+        # Wait until PG is ready - this will block until all
+        # requested resources are available, and will timeout
+        # if they cannot be provisioned.
+        ray.get(current_placement_group.ready(), timeout=1800)

+    return None, current_placement_group
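
initialize_cluster now either adopts the placement group the driver is already running inside (after checking it has enough single-GPU bundles) or creates a fresh one with one 1-GPU bundle per worker and blocks until Ray can provision it. A hedged usage sketch, assuming a Ray cluster with at least two GPUs and the ParallelConfig argument order (pipeline_parallel_size, tensor_parallel_size, worker_use_ray):

from vllm.config import ParallelConfig
from vllm.engine.ray_utils import initialize_cluster

# Illustrative config: tensor parallelism of 2 on one pipeline stage,
# with the workers driven by Ray.
parallel_config = ParallelConfig(pipeline_parallel_size=1,
                                 tensor_parallel_size=2,
                                 worker_use_ray=True)
distributed_init_method, pg = initialize_cluster(parallel_config)

# When Ray manages the workers, the torch.distributed init method is resolved
# later by init_torch_dist_process_group, so only the placement group is
# returned here.
assert distributed_init_method is None
print(pg.bundle_specs)  # expected to look like [{'GPU': 1.0}, {'GPU': 1.0}]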

vllm/worker/worker.py

@@ -1,7 +1,9 @@
 """A GPU worker class."""
-from typing import Dict, List, Tuple
+import os
+from typing import Dict, List, Tuple, Optional

 import torch
+import torch.distributed

 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
@@ -27,8 +29,8 @@ class Worker:
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
-        rank: int,
-        distributed_init_method: str,
+        rank: Optional[int] = None,
+        distributed_init_method: Optional[str] = None,
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
@@ -36,19 +38,6 @@ class Worker:
         self.rank = rank
         self.distributed_init_method = distributed_init_method

-        # Initialize the distributed environment.
-        _init_distributed_environment(parallel_config, rank,
-                                      distributed_init_method)
-
-        # Initialize the model.
-        set_random_seed(self.model_config.seed)
-        self.model = get_model(model_config)
-        initialize_all_reduce_launcher(
-            self.scheduler_config.max_num_batched_tokens,
-            self.model_config.get_hidden_size(),
-            self.model_config.dtype,
-        )
-
         # Uninitialized cache engine. Will be initialized by
         # self.init_cache_engine().
         self.cache_config = None
@@ -57,6 +46,31 @@ class Worker:
         self.cache_events = None
         self.gpu_cache = None

+    def init_model(self):
+        # This env var set by Ray causes exceptions with graph building.
+        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
+        # Env vars will be set by Ray.
+        self.rank = self.rank if self.rank is not None else int(
+            os.getenv("RANK", "-1"))
+        local_rank = int(os.getenv("LOCAL_RANK", "0"))
+        self.device = torch.device(f"cuda:{local_rank}")
+        if self.rank < 0:
+            raise ValueError("Invalid or unspecified rank.")
+        torch.cuda.set_device(self.device)
+
+        # Initialize the distributed environment.
+        _init_distributed_environment(self.parallel_config, self.rank,
+                                      self.distributed_init_method)
+
+        # Initialize the model.
+        set_random_seed(self.model_config.seed)
+        self.model = get_model(self.model_config)
+        initialize_all_reduce_launcher(
+            self.scheduler_config.max_num_batched_tokens,
+            self.model_config.get_hidden_size(),
+            self.model_config.dtype,
+        )
+
     @torch.inference_mode()
     def profile_num_available_blocks(
         self,
@@ -294,15 +308,28 @@ class Worker:
 def _init_distributed_environment(
     parallel_config: ParallelConfig,
     rank: int,
-    distributed_init_method: str,
+    distributed_init_method: Optional[str] = None,
 ) -> None:
     """Initialize the distributed environment."""
-    torch.distributed.init_process_group(
-        backend="nccl",
-        world_size=parallel_config.world_size,
-        rank=rank,
-        init_method=distributed_init_method,
-    )
+    if torch.distributed.is_initialized():
+        torch_world_size = torch.distributed.get_world_size()
+        if torch_world_size != parallel_config.world_size:
+            raise RuntimeError(
+                "torch.distributed is already initialized but the torch world "
+                "size does not match parallel_config.world_size "
+                f"({torch_world_size} vs. {parallel_config.world_size}).")
+    elif not distributed_init_method:
+        raise ValueError(
+            "distributed_init_method must be set if torch.distributed "
+            "is not already initialized")
+    else:
+        torch.distributed.init_process_group(
+            backend="nccl",
+            world_size=parallel_config.world_size,
+            rank=rank,
+            init_method=distributed_init_method,
+        )
+
     # A small all_reduce for warmup.
     torch.distributed.all_reduce(torch.zeros(1).cuda())
     initialize_model_parallel(parallel_config.tensor_parallel_size,
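
Worker construction is now CUDA-free: device selection and distributed setup move into init_model, which reads the RANK and LOCAL_RANK environment variables that Ray's TorchDistributedWorker machinery exports, and _init_distributed_environment tolerates an already-initialized process group. A minimal sketch of that resolution step in isolation; resolve_rank_and_device is a hypothetical helper, not part of this diff:

import os
import torch

def resolve_rank_and_device(explicit_rank=None):
    # Ray sets NCCL_ASYNC_ERROR_HANDLING; the commit drops it because it
    # causes exceptions during graph building.
    os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
    # Prefer an explicitly passed rank, otherwise fall back to Ray's env var.
    rank = explicit_rank if explicit_rank is not None else int(
        os.getenv("RANK", "-1"))
    if rank < 0:
        raise ValueError("Invalid or unspecified rank.")
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    return rank, device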