mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 00:15:51 +08:00

Ray placement group support (#397)

parent 8c4b2592fb
commit 9925c17940
requirements.txt:
@@ -1,6 +1,6 @@
 ninja  # For faster builds.
 psutil
-ray
+ray >= 2.5.1
 sentencepiece  # Required for LLaMA tokenizer.
 numpy
 torch >= 2.0.0
vllm/engine/async_llm_engine.py:
@@ -226,14 +226,14 @@ class AsyncLLMEngine:
         engine_configs = engine_args.create_engine_configs()
         parallel_config = engine_configs[2]
         # Initialize the cluster.
-        distributed_init_method, devices = initialize_cluster(
+        distributed_init_method, placement_group = initialize_cluster(
             parallel_config, engine_args.engine_use_ray)
         # Create the async LLM engine.
         engine = cls(engine_args.worker_use_ray,
                      engine_args.engine_use_ray,
                      *engine_configs,
                      distributed_init_method,
-                     devices,
+                     placement_group,
                      log_requests=not engine_args.disable_log_requests,
                      log_stats=not engine_args.disable_log_stats)
         return engine
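For context, the public entry point is unchanged by this hunk: `initialize_cluster` now hands back a placement group instead of a per-device list, and `from_engine_args` simply forwards it. A minimal usage sketch (the model name and parallel size below are illustrative placeholders, not part of this commit):

    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine

    # Worker placement is derived from a Ray placement group inside
    # initialize_cluster(); callers still only pass engine args.
    engine_args = AsyncEngineArgs(model="facebook/opt-125m",
                                  tensor_parallel_size=2)
    engine = AsyncLLMEngine.from_engine_args(engine_args)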
vllm/engine/llm_engine.py:
@@ -1,11 +1,12 @@
 import time
-from typing import Any, List, Optional
+from functools import partial
+from typing import Any, List, Optional, TYPE_CHECKING
 
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
 from vllm.core.scheduler import Scheduler
 from vllm.engine.arg_utils import EngineArgs
-from vllm.engine.ray_utils import DeviceID, initialize_cluster, ray
+from vllm.engine.ray_utils import initialize_cluster, ray, RayWorker
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
@@ -13,7 +14,13 @@ from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                get_tokenizer)
 from vllm.utils import Counter
-from vllm.worker.worker import Worker
+
+if ray:
+    from ray.air.util.torch_dist import init_torch_dist_process_group
+    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+
+if TYPE_CHECKING:
+    from ray.util.placement_group import PlacementGroup
 
 logger = init_logger(__name__)
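These guarded imports keep Ray an optional dependency: Ray-only symbols are imported only when `ray` resolved successfully, and `PlacementGroup` is visible only to type checkers. A self-contained sketch of the same pattern (the `describe` helper is illustrative, not part of vLLM):

    from typing import Optional, TYPE_CHECKING

    try:
        import ray
    except ImportError:
        ray = None  # Single-GPU runs work without Ray installed.

    if TYPE_CHECKING:
        # Annotation-only import; no runtime dependency on Ray.
        from ray.util.placement_group import PlacementGroup

    def describe(pg: Optional["PlacementGroup"]) -> str:
        return "no placement group" if pg is None else str(pg.bundle_specs)

    print(describe(None))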
@@ -54,7 +61,7 @@ class LLMEngine:
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
         distributed_init_method: str,
-        stage_devices: List[List[DeviceID]],
+        placement_group: Optional["PlacementGroup"],
         log_stats: bool,
     ) -> None:
         logger.info(
@@ -85,31 +92,73 @@ class LLMEngine:
         self.seq_counter = Counter()
 
         # Create the parallel GPU workers.
-        self.workers: List[Worker] = []
-        assert len(stage_devices) == 1, "Only support one stage for now."
-        for rank, node_resource, _ in stage_devices[0]:
-            worker_cls = Worker
-            if self.parallel_config.worker_use_ray:
-                worker_cls = ray.remote(
-                    num_cpus=0,
-                    num_gpus=1,
-                    resources={node_resource: 1e-3},
-                )(worker_cls).remote
-
-            worker = worker_cls(
-                model_config,
-                parallel_config,
-                scheduler_config,
-                rank,
-                distributed_init_method,
-            )
-            self.workers.append(worker)
+        if self.parallel_config.worker_use_ray:
+            self._init_workers_ray(placement_group)
+        else:
+            self._init_workers(distributed_init_method)
 
         # Profile the memory usage and initialize the cache.
         self._init_cache()
 
         # Create the scheduler.
         self.scheduler = Scheduler(scheduler_config, cache_config, log_stats)
 
+    def _init_workers(self, distributed_init_method: str):
+        # Lazy import the Worker to avoid importing torch.cuda/xformers
+        # before CUDA_VISIBLE_DEVICES is set in the Worker
+        from vllm.worker.worker import Worker  # pylint: disable=import-outside-toplevel
+
+        assert self.parallel_config.world_size == 1, (
+            "Ray is required if parallel_config.world_size > 1.")
+
+        self.workers: List[Worker] = []
+        worker = Worker(
+            self.model_config,
+            self.parallel_config,
+            self.scheduler_config,
+            0,
+            distributed_init_method,
+        )
+        self.workers.append(worker)
+        self._run_workers(
+            "init_model",
+            get_all_outputs=True,
+        )
+
+    def _init_workers_ray(self, placement_group: "PlacementGroup"):
+        # Lazy import the Worker to avoid importing torch.cuda/xformers
+        # before CUDA_VISIBLE_DEVICES is set in the Worker
+        from vllm.worker.worker import Worker  # pylint: disable=import-outside-toplevel
+
+        self.workers: List[Worker] = []
+        for bundle in placement_group.bundle_specs:
+            if not bundle.get("GPU", 0):
+                continue
+            worker = ray.remote(
+                num_cpus=0,
+                num_gpus=1,
+                scheduling_strategy=PlacementGroupSchedulingStrategy(
+                    placement_group=placement_group,
+                    placement_group_capture_child_tasks=True),
+            )(RayWorker).remote()
+            self.workers.append(worker)
+
+        # Initialize torch distributed process group for the workers.
+        init_torch_dist_process_group(self.workers, backend="nccl")
+        self._run_workers("init_worker",
+                          get_all_outputs=True,
+                          worker_init_fn=lambda: Worker(
+                              self.model_config,
+                              self.parallel_config,
+                              self.scheduler_config,
+                              None,
+                              None,
+                          ))
+        self._run_workers(
+            "init_model",
+            get_all_outputs=True,
+        )
+
     def _verify_args(self) -> None:
         self.model_config.verify_with_parallel_config(self.parallel_config)
         self.cache_config.verify_with_parallel_config(self.parallel_config)
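`_init_workers_ray` creates one bare `RayWorker` actor per GPU bundle and pins each actor to the placement group through its scheduling strategy, so Ray (not vLLM) decides which physical GPU each worker sees. A standalone sketch of that scheduling pattern, assuming Ray is installed and at least two GPUs are visible (the `DummyWorker` actor is illustrative, not vLLM's `RayWorker`):

    import os
    import ray
    from ray.util.placement_group import placement_group
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

    ray.init()

    # One single-GPU bundle per worker we want to schedule.
    pg = placement_group([{"GPU": 1}] * 2)
    ray.get(pg.ready())

    @ray.remote(num_cpus=0, num_gpus=1)
    class DummyWorker:
        def device_ids(self):
            # Ray sets CUDA_VISIBLE_DEVICES before any user code runs.
            return os.environ.get("CUDA_VISIBLE_DEVICES")

    workers = [
        DummyWorker.options(
            scheduling_strategy=PlacementGroupSchedulingStrategy(
                placement_group=pg,
                placement_group_capture_child_tasks=True),
        ).remote()
        for _ in pg.bundle_specs
    ]
    print(ray.get([w.device_ids.remote() for w in workers]))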
@@ -152,11 +201,12 @@ class LLMEngine:
         engine_configs = engine_args.create_engine_configs()
         parallel_config = engine_configs[2]
         # Initialize the cluster.
-        distributed_init_method, devices = initialize_cluster(parallel_config)
+        distributed_init_method, placement_group = initialize_cluster(
+            parallel_config)
         # Create the LLM engine.
         engine = cls(*engine_configs,
                      distributed_init_method,
-                     devices,
+                     placement_group,
                      log_stats=not engine_args.disable_log_stats)
         return engine
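From the offline-inference side the effect is the same: tensor-parallel runs now get a placement group created (or reused) for them automatically. A hedged usage sketch (model name and sizes are placeholders, not part of this commit):

    from vllm import LLM, SamplingParams

    # With tensor_parallel_size > 1, initialize_cluster() builds a Ray
    # placement group with one GPU bundle per worker behind the scenes.
    llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2)
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=16))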
@@ -326,9 +376,10 @@ class LLMEngine:
         """Runs the given method on all workers."""
         all_outputs = []
         for worker in self.workers:
-            executor = getattr(worker, method)
             if self.parallel_config.worker_use_ray:
-                executor = executor.remote
+                executor = partial(worker.execute_method.remote, method)
+            else:
+                executor = getattr(worker, method)
 
             output = executor(*args, **kwargs)
             all_outputs.append(output)
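The `partial` indirection keeps one call site for both Ray actors and in-process workers: the method name is forwarded to `RayWorker.execute_method`, which looks it up on the wrapped `Worker`. A minimal sketch of that dispatch shape (the classes here are illustrative stand-ins):

    from functools import partial

    class LocalWorker:
        def init_model(self, seed: int = 0):
            return f"local init, seed={seed}"

    class ActorLikeWorker:
        """Stands in for a Ray actor handle exposing execute_method."""
        def __init__(self, inner):
            self._inner = inner
        def execute_method(self, method, *args, **kwargs):
            return getattr(self._inner, method)(*args, **kwargs)

    def run_on(worker, use_ray: bool, method: str, *args, **kwargs):
        if use_ray:
            # With a real Ray actor this would be worker.execute_method.remote.
            executor = partial(worker.execute_method, method)
        else:
            executor = getattr(worker, method)
        return executor(*args, **kwargs)

    print(run_on(LocalWorker(), False, "init_model", seed=1))
    print(run_on(ActorLikeWorker(LocalWorker()), True, "init_model", seed=2))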
vllm/engine/ray_utils.py:
@@ -1,15 +1,35 @@
 import socket
-from typing import List, Optional, Tuple
-
-try:
-    import ray
-except ImportError:
-    ray = None
+from typing import Optional, Tuple, TYPE_CHECKING
 
 from vllm.config import ParallelConfig
 
-# rank, node resource (node IP), device id
-DeviceID = Tuple[int, Optional[str], int]
+try:
+    import ray
+    from ray.air.util.torch_dist import TorchDistributedWorker
+
+    class RayWorker(TorchDistributedWorker):
+        """Ray wrapper for vllm.worker.Worker, allowing Worker to be
+        lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""
+
+        def __init__(self) -> None:
+            self.worker = None
+
+        def init_worker(self, worker_init_fn):
+            self.worker = worker_init_fn()
+
+        def __getattr__(self, name):
+            return getattr(self.worker, name)
+
+        def execute_method(self, method, *args, **kwargs):
+            executor = getattr(self, method)
+            return executor(*args, **kwargs)
+
+except ImportError:
+    ray = None
+    TorchDistributedWorker = None
+
+if TYPE_CHECKING:
+    from ray.util.placement_group import PlacementGroup
 
 
 def get_open_port():
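`RayWorker` is deliberately empty at construction time: Ray first places the actor and assigns its GPU (setting `CUDA_VISIBLE_DEVICES`), and only then does `init_worker` build the real `Worker` from the supplied closure, with attribute access delegated to it. A sketch of that lazy-initialization idea outside of Ray (the `make_worker` factory is a placeholder):

    import os

    class LazyWrapper:
        """Holds nothing until init_worker() runs inside the target process."""
        def __init__(self) -> None:
            self.worker = None

        def init_worker(self, worker_init_fn):
            # By the time this runs, the environment (e.g. CUDA_VISIBLE_DEVICES)
            # has already been prepared by the scheduler.
            self.worker = worker_init_fn()

        def __getattr__(self, name):
            # Delegate everything else to the wrapped worker.
            return getattr(self.worker, name)

    def make_worker():
        visible = os.environ.get("CUDA_VISIBLE_DEVICES", "<unset>")
        return type("W", (), {"visible": visible, "ping": lambda self: "pong"})()

    w = LazyWrapper()
    w.init_worker(make_worker)
    print(w.visible, w.ping())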
@@ -22,7 +42,7 @@ def initialize_cluster(
     parallel_config: ParallelConfig,
     engine_use_ray: bool = False,
     ray_address: Optional[str] = None,
-) -> Tuple[str, List[List[DeviceID]]]:
+) -> Tuple[str, Optional["PlacementGroup"]]:
     """Initialize the distributed cluster probably with Ray.
 
     Args:
@@ -52,63 +72,36 @@ def initialize_cluster(
         # We need to setup the distributed init method to make sure
         # the distributed megatron code (e.g., get world size) works correctly.
         distributed_init_method = f"tcp://localhost:{port}"
-        all_stage_devices = [[(0, None, 0)]]
-        return distributed_init_method, all_stage_devices
+        return distributed_init_method, None
 
-    # Assume we have a uniform cluster that each node has the same number of
-    # GPUs for now.
-    valid_node_resources = []
-    num_devices_per_node = None
-    for node in ray.nodes():
-        if (not node["Alive"]) or node["Resources"]["GPU"] <= 0:
-            continue
-        if num_devices_per_node is None:
-            num_devices_per_node = node["Resources"]["GPU"]
-        else:
-            assert num_devices_per_node == node["Resources"]["GPU"], (
-                "The number of GPUs per node is not uniform.")
-        for key in node["Resources"]:
-            if key.startswith("node:"):
-                valid_node_resources.append(key)
-
-    # Verify the parallel config.
-    num_nodes = len(valid_node_resources)
-    if parallel_config.world_size > num_nodes * num_devices_per_node:
-        raise ValueError(
-            "The number of required GPUs exceeds the total number of "
-            "available GPUs.")
-    if parallel_config.tensor_parallel_size >= num_devices_per_node:
-        if parallel_config.tensor_parallel_size % num_devices_per_node != 0:
-            raise ValueError(
-                "The number of tensor parallelism is not divisible by the "
-                "number of GPUs per node.")
-    else:
-        if num_devices_per_node % parallel_config.tensor_parallel_size != 0:
-            raise ValueError(
-                "The number of GPUs per node is not divisible by the number "
-                "of tensor parallelism.")
-
-    # Assign GPUs to pipeline stages.
-    rank = 0
-    current_node_id = 0
-    current_device_id = 0
-    distributed_init_method = None
-    all_stage_devices = []
-
-    for _ in range(parallel_config.pipeline_parallel_size):
-        stage_devices = []
-        for _ in range(parallel_config.tensor_parallel_size):
-            node_resource = valid_node_resources[current_node_id]
-            stage_devices.append((rank, node_resource, current_device_id))
-            if distributed_init_method is None:
-                ip = node_resource.split("node:")[-1]
-                port = get_open_port()
-                distributed_init_method = f"tcp://{ip}:{port}"
-            rank += 1
-            current_device_id += 1
-            if current_device_id >= num_devices_per_node:
-                current_node_id += 1
-                current_device_id = 0
-        all_stage_devices.append(stage_devices)
-
-    return distributed_init_method, all_stage_devices
+    current_placement_group = ray.util.get_current_placement_group()
+    if current_placement_group:
+        # We are in a placement group
+        bundles = current_placement_group.bundle_specs
+        # Verify that we can use the placement group.
+        gpu_bundles = 0
+        for bundle in bundles:
+            assert bundle.get("GPU", 0) <= 1, (
+                "Placement group bundles cannot have more than 1 GPU")
+            if bundle.get("GPU", 0):
+                gpu_bundles += 1
+        if parallel_config.world_size > gpu_bundles:
+            raise ValueError(
+                "The number of required GPUs exceeds the total number of "
+                "available GPUs in the placement group.")
+    else:
+        num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
+        if parallel_config.world_size > num_gpus_in_cluster:
+            raise ValueError(
+                "The number of required GPUs exceeds the total number of "
+                "available GPUs in the cluster.")
+        # Create a new placement group
+        current_placement_group = ray.util.placement_group([{
+            "GPU": 1
+        }] * parallel_config.world_size)
+        # Wait until PG is ready - this will block until all
+        # requested resources are available, and will timeout
+        # if they cannot be provisioned.
+        ray.get(current_placement_group.ready(), timeout=1800)
+
+    return None, current_placement_group
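The new logic replaces manual node/GPU bookkeeping with Ray's placement group API: reuse the placement group the driver is already running inside, or create one single-GPU bundle per worker and block until the bundles are reserved. A condensed, standalone sketch of the same flow, assuming a Ray cluster with enough GPUs (the world_size value is an example):

    import ray
    from ray.util import get_current_placement_group, placement_group

    ray.init()
    world_size = 2  # e.g. tensor_parallel_size * pipeline_parallel_size

    pg = get_current_placement_group()
    if pg is None:
        if ray.cluster_resources().get("GPU", 0) < world_size:
            raise ValueError("Not enough GPUs in the Ray cluster.")
        # One single-GPU bundle per worker.
        pg = placement_group([{"GPU": 1}] * world_size)
        # Blocks until all bundles are reserved (or raises on timeout).
        ray.get(pg.ready(), timeout=1800)

    gpu_bundles = sum(1 for b in pg.bundle_specs if b.get("GPU", 0))
    assert gpu_bundles >= world_size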
vllm/worker/worker.py:
@@ -1,7 +1,9 @@
 """A GPU worker class."""
-from typing import Dict, List, Tuple
+import os
+from typing import Dict, List, Tuple, Optional
 
 import torch
 import torch.distributed
 
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
@@ -27,8 +29,8 @@ class Worker:
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
-        rank: int,
-        distributed_init_method: str,
+        rank: Optional[int] = None,
+        distributed_init_method: Optional[str] = None,
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
@@ -36,19 +38,6 @@ class Worker:
         self.rank = rank
         self.distributed_init_method = distributed_init_method
 
-        # Initialize the distributed environment.
-        _init_distributed_environment(parallel_config, rank,
-                                      distributed_init_method)
-
-        # Initialize the model.
-        set_random_seed(self.model_config.seed)
-        self.model = get_model(model_config)
-        initialize_all_reduce_launcher(
-            self.scheduler_config.max_num_batched_tokens,
-            self.model_config.get_hidden_size(),
-            self.model_config.dtype,
-        )
-
         # Uninitialized cache engine. Will be initialized by
         # self.init_cache_engine().
         self.cache_config = None
@@ -57,6 +46,31 @@ class Worker:
         self.cache_events = None
         self.gpu_cache = None
 
+    def init_model(self):
+        # This env var set by Ray causes exceptions with graph building.
+        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
+        # Env vars will be set by Ray.
+        self.rank = self.rank if self.rank is not None else int(
+            os.getenv("RANK", "-1"))
+        local_rank = int(os.getenv("LOCAL_RANK", "0"))
+        self.device = torch.device(f"cuda:{local_rank}")
+        if self.rank < 0:
+            raise ValueError("Invalid or unspecified rank.")
+        torch.cuda.set_device(self.device)
+
+        # Initialize the distributed environment.
+        _init_distributed_environment(self.parallel_config, self.rank,
+                                      self.distributed_init_method)
+
+        # Initialize the model.
+        set_random_seed(self.model_config.seed)
+        self.model = get_model(self.model_config)
+        initialize_all_reduce_launcher(
+            self.scheduler_config.max_num_batched_tokens,
+            self.model_config.get_hidden_size(),
+            self.model_config.dtype,
+        )
+
     @torch.inference_mode()
     def profile_num_available_blocks(
         self,
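Model and process-group setup moves out of `__init__` into `init_model`, which is only called after Ray has placed the actor; rank and device come from the usual torch.distributed environment variables when no explicit rank was passed. A small sketch of that convention (values are illustrative):

    import os

    def resolve_rank_and_device(explicit_rank=None):
        # RANK/LOCAL_RANK are exported by the launcher (Ray's
        # TorchDistributedWorker, torchrun, etc.).
        rank = explicit_rank if explicit_rank is not None else int(
            os.getenv("RANK", "-1"))
        if rank < 0:
            raise ValueError("Invalid or unspecified rank.")
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        return rank, f"cuda:{local_rank}"

    os.environ.setdefault("RANK", "0")
    print(resolve_rank_and_device())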
@@ -294,15 +308,28 @@
 def _init_distributed_environment(
     parallel_config: ParallelConfig,
     rank: int,
-    distributed_init_method: str,
+    distributed_init_method: Optional[str] = None,
 ) -> None:
     """Initialize the distributed environment."""
-    torch.distributed.init_process_group(
-        backend="nccl",
-        world_size=parallel_config.world_size,
-        rank=rank,
-        init_method=distributed_init_method,
-    )
+    if torch.distributed.is_initialized():
+        torch_world_size = torch.distributed.get_world_size()
+        if torch_world_size != parallel_config.world_size:
+            raise RuntimeError(
+                "torch.distributed is already initialized but the torch world "
+                "size does not match parallel_config.world_size "
+                f"({torch_world_size} vs. {parallel_config.world_size}).")
+    elif not distributed_init_method:
+        raise ValueError(
+            "distributed_init_method must be set if torch.distributed "
+            "is not already initialized")
+    else:
+        torch.distributed.init_process_group(
+            backend="nccl",
+            world_size=parallel_config.world_size,
+            rank=rank,
+            init_method=distributed_init_method,
+        )
 
     # A small all_reduce for warmup.
     torch.distributed.all_reduce(torch.zeros(1).cuda())
     initialize_model_parallel(parallel_config.tensor_parallel_size,
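Tolerating an already-initialized process group is what lets Ray's `init_torch_dist_process_group` own the NCCL setup while the non-Ray path still initializes it itself. The guard pattern in isolation, using the gloo backend so the sketch runs without GPUs (port and addresses are placeholders):

    from typing import Optional
    import torch.distributed as dist

    def ensure_process_group(world_size: int, rank: int,
                             init_method: Optional[str]) -> None:
        if dist.is_initialized():
            if dist.get_world_size() != world_size:
                raise RuntimeError("Existing group has the wrong world size.")
            return  # Someone else (e.g. Ray) already initialized it.
        if not init_method:
            raise ValueError("init_method is required when no group exists.")
        dist.init_process_group(backend="gloo", world_size=world_size,
                                rank=rank, init_method=init_method)

    ensure_process_group(1, 0, "tcp://127.0.0.1:29500")
    print(dist.get_world_size())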