Add distributed model executor abstraction (#3191)

commit 4c922709b6, parent 657061fdce

docs/source/dev/engine/llm_engine.rst
@@ -2,5 +2,5 @@ LLMEngine
 =================================
 
 .. autoclass:: vllm.engine.llm_engine.LLMEngine
-    :members: add_request, abort_request, step, _init_cache
+    :members: add_request, abort_request, step
     :show-inheritance:

format.sh
@@ -95,13 +95,17 @@ echo 'vLLM yapf: Done'
 # echo 'vLLM mypy:'
 # mypy
 
+CODESPELL_EXCLUDES=(
+    '--skip' '*docs/source/_build/**'
+)
+
 # check spelling of specified files
 spell_check() {
     codespell "$@"
 }
 
 spell_check_all(){
-  codespell --toml pyproject.toml
+  codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}"
 }
 
 # Spelling check of files that differ from main branch.
@@ -116,7 +120,7 @@ spell_check_changed() {
 
     if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
         git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
-             codespell
+             codespell "${CODESPELL_EXCLUDES[@]}"
     fi
 }

tests/lora/conftest.py
@@ -152,4 +152,5 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module:
 @pytest.fixture
 def llama_2_7b_model_extra_embeddings(
         llama_2_7b_engine_extra_embeddings) -> nn.Module:
-    yield llama_2_7b_engine_extra_embeddings.driver_worker.model_runner.model
+    yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
+           model_runner.model)

vllm/__init__.py
@@ -3,7 +3,7 @@
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.llm_engine import LLMEngine
-from vllm.engine.ray_utils import initialize_cluster
+from vllm.engine.ray_utils import initialize_ray_cluster
 from vllm.entrypoints.llm import LLM
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import SamplingParams
@@ -19,5 +19,5 @@ __all__ = [
     "EngineArgs",
     "AsyncLLMEngine",
     "AsyncEngineArgs",
-    "initialize_cluster",
+    "initialize_ray_cluster",
 ]

vllm/config.py
@@ -1,4 +1,4 @@
-from typing import Optional, Union, ClassVar
+from typing import TYPE_CHECKING, Optional, Union, ClassVar
 from dataclasses import dataclass
 import os
 from packaging.version import Version
@@ -10,6 +10,9 @@ from vllm.logger import init_logger
 from vllm.transformers_utils.config import get_config
 from vllm.utils import get_cpu_memory, is_hip, is_neuron, get_nvcc_cuda_version
 
+if TYPE_CHECKING:
+    from ray.util.placement_group import PlacementGroup
+
 logger = init_logger(__name__)
 
 _GB = 1 << 30
@@ -397,6 +400,7 @@ class ParallelConfig:
         max_parallel_loading_workers: Optional[int] = None,
         disable_custom_all_reduce: bool = False,
         ray_workers_use_nsight: bool = False,
+        placement_group: Optional["PlacementGroup"] = None,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
         if is_neuron():
@@ -412,6 +416,7 @@ class ParallelConfig:
         self.max_parallel_loading_workers = max_parallel_loading_workers
         self.disable_custom_all_reduce = disable_custom_all_reduce
         self.ray_workers_use_nsight = ray_workers_use_nsight
+        self.placement_group = placement_group
 
         self.world_size = pipeline_parallel_size * self.tensor_parallel_size
         # Ray worker is not supported for Neuron backend.

vllm/engine/async_llm_engine.py
@@ -2,8 +2,8 @@ import asyncio
 import os
 import time
 from functools import partial
-from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
-                    Union, AsyncIterator, Callable)
+from typing import (Callable, Dict, Iterable, List, Optional, Set, Tuple, Type,
+                    Union, AsyncIterator)
 
 from transformers import PreTrainedTokenizer
 
@@ -11,7 +11,7 @@ from vllm.lora.request import LoRARequest
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.llm_engine import LLMEngine
-from vllm.engine.ray_utils import initialize_cluster, ray
+from vllm.engine.ray_utils import initialize_ray_cluster, ray
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
@@ -208,17 +208,10 @@ class _AsyncLLMEngine(LLMEngine):
 
         if not scheduler_outputs.is_empty():
             # Execute the model.
-            all_outputs = await self._run_workers_async(
-                "execute_model",
-                driver_kwargs={
-                    "seq_group_metadata_list": seq_group_metadata_list,
-                    "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
-                    "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
-                    "blocks_to_copy": scheduler_outputs.blocks_to_copy,
-                })
-
-            # Only the driver worker returns the sampling results.
-            output = all_outputs[0]
+            output = await self.model_executor.execute_model_async(
+                seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
+                scheduler_outputs.blocks_to_swap_out,
+                scheduler_outputs.blocks_to_copy)
         else:
             output = []
 
@@ -268,37 +261,8 @@ class _AsyncLLMEngine(LLMEngine):
             lora_request=lora_request,
         )
 
-    async def _run_workers_async(
-        self,
-        method: str,
-        *args,
-        driver_args: Optional[List[Any]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> Any:
-        """Runs the given method on all workers."""
-        coros = []
-
-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
-
-        # Run the driver worker asynchronously.
-        driver_executor = getattr(self.driver_worker, method)
-        coros.append(asyncio.get_event_loop().run_in_executor(
-            None, partial(driver_executor, *driver_args, **driver_kwargs)))
-
-        # Run the ray workers asynchronously.
-        for worker in self.workers:
-            coros.append(worker.execute_method.remote(method, *args, **kwargs))
-
-        all_outputs = await asyncio.gather(*coros)
-        return all_outputs
-
-    async def check_health_async(self):
-        """Raises an error if engine is unhealthy."""
-        self._check_if_any_actor_is_dead()
+    async def check_health_async(self) -> None:
+        self.model_executor.check_health()
 
 
 class AsyncLLMEngine:
@@ -353,6 +317,34 @@ class AsyncLLMEngine:
         self._request_tracker: Optional[RequestTracker] = None
         self._errored_with: Optional[BaseException] = None
 
+    @classmethod
+    def from_engine_args(cls,
+                         engine_args: AsyncEngineArgs,
+                         start_engine_loop: bool = True) -> "AsyncLLMEngine":
+        """Creates an async LLM engine from the engine arguments."""
+        # Create the engine configs.
+        engine_configs = engine_args.create_engine_configs()
+        parallel_config = engine_configs[2]
+        if parallel_config.worker_use_ray or engine_args.engine_use_ray:
+            initialize_ray_cluster(parallel_config)
+            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
+            executor_class = RayGPUExecutorAsync
+        else:
+            assert parallel_config.world_size == 1, (
+                "Ray is required if parallel_config.world_size > 1.")
+            from vllm.executor.gpu_executor import GPUExecutorAsync
+            executor_class = GPUExecutorAsync
+        # Create the async LLM engine.
+        engine = cls(parallel_config.worker_use_ray,
+                     engine_args.engine_use_ray,
+                     *engine_configs,
+                     executor_class,
+                     log_requests=not engine_args.disable_log_requests,
+                     log_stats=not engine_args.disable_log_stats,
+                     max_log_len=engine_args.max_log_len,
+                     start_engine_loop=start_engine_loop)
+        return engine
+
     @property
     def is_running(self) -> bool:
         return (self.background_loop is not None
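
For orientation, a minimal usage sketch of the new async construction path (the model name is only an example): with no Ray flags set, from_engine_args selects GPUExecutorAsync.

    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine

    # Single-GPU example: worker_use_ray and engine_use_ray are False,
    # so from_engine_args picks GPUExecutorAsync.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))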
@@ -670,35 +662,13 @@ class AsyncLLMEngine:
         else:
             return self.engine.get_model_config()
 
-    @classmethod
-    def from_engine_args(cls,
-                         engine_args: AsyncEngineArgs,
-                         start_engine_loop: bool = True) -> "AsyncLLMEngine":
-        """Creates an async LLM engine from the engine arguments."""
-        # Create the engine configs.
-        engine_configs = engine_args.create_engine_configs()
-        parallel_config = engine_configs[2]
-        # Initialize the cluster.
-        placement_group = initialize_cluster(parallel_config,
-                                             engine_args.engine_use_ray)
-        # Create the async LLM engine.
-        engine = cls(parallel_config.worker_use_ray,
-                     engine_args.engine_use_ray,
-                     *engine_configs,
-                     placement_group,
-                     log_requests=not engine_args.disable_log_requests,
-                     log_stats=not engine_args.disable_log_stats,
-                     max_log_len=engine_args.max_log_len,
-                     start_engine_loop=start_engine_loop)
-        return engine
-
     async def do_log_stats(self) -> None:
         if self.engine_use_ray:
             await self.engine.do_log_stats.remote()
         else:
             self.engine.do_log_stats()
 
-    async def check_health(self):
+    async def check_health(self) -> None:
         """Raises an error if engine is unhealthy."""
         t = time.perf_counter()
         logger.debug("Starting health check...")

vllm/engine/llm_engine.py
@@ -1,11 +1,5 @@
-import copy
-from collections import defaultdict
-import os
 import time
-import pickle
-import importlib
-from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
-                    Union)
+from typing import Dict, Iterable, List, Optional, Tuple, Type, Union
 
 from transformers import PreTrainedTokenizer
 
@@ -15,8 +9,9 @@ from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, LoRAConfig)
 from vllm.core.scheduler import Scheduler, SchedulerOutputs
 from vllm.engine.arg_utils import EngineArgs
+from vllm.executor.executor_base import ExecutorBase
 from vllm.engine.metrics import StatLogger, Stats
-from vllm.engine.ray_utils import RayWorkerVllm, initialize_cluster, ray
+from vllm.engine.ray_utils import initialize_ray_cluster
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
@@ -24,29 +19,11 @@ from vllm.sequence import (Logprob, SamplerOutput, Sequence, SequenceGroup,
                            SequenceGroupOutput, SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                TokenizerGroup)
-from vllm.utils import (Counter, set_cuda_visible_devices, get_ip,
-                        get_open_port, get_distributed_init_method)
-
-if ray:
-    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
-
-if TYPE_CHECKING:
-    from ray.util.placement_group import PlacementGroup
+from vllm.utils import Counter
 
 logger = init_logger(__name__)
 _LOCAL_LOGGING_INTERVAL_SEC = 5
 
-# A map between the device type (in device config) to its worker module.
-DEVICE_TO_WORKER_MODULE_MAP = {
-    "cuda": "vllm.worker.worker",
-    "neuron": "vllm.worker.neuron_worker",
-}
-
-# If the env var is set, it uses the Ray's compiled DAG API
-# which optimizes the control plane overhead.
-# Run VLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
-USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
-
 
 class LLMEngine:
     """An LLM engine that receives requests and generates texts.
@@ -71,8 +48,8 @@ class LLMEngine:
         parallel_config: The configuration related to distributed execution.
         scheduler_config: The configuration related to the request scheduler.
         device_config: The configuration related to the device.
-        placement_group: Ray placement group for distributed execution.
-            Required for distributed execution.
+        executor_class: The model executor class for managing distributed
+            execution.
         log_stats: Whether to log statistics.
     """
@@ -84,7 +61,7 @@ class LLMEngine:
         scheduler_config: SchedulerConfig,
         device_config: DeviceConfig,
         lora_config: Optional[LoRAConfig],
-        placement_group: Optional["PlacementGroup"],
+        executor_class: Type[ExecutorBase],
         log_stats: bool,
     ) -> None:
         logger.info(
@@ -121,33 +98,13 @@ class LLMEngine:
         self._init_tokenizer()
         self.seq_counter = Counter()
 
-        # Create the parallel GPU workers.
-        if self.parallel_config.worker_use_ray:
-            # Disable Ray usage stats collection.
-            ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
-            if ray_usage != "1":
-                os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
-            # Pass additional arguments to initialize the worker
-            additional_ray_args = {}
-            if self.parallel_config.ray_workers_use_nsight:
-                logger.info("Configuring Ray workers to use nsight.")
-                additional_ray_args = {
-                    "runtime_env": {
-                        "nsight": {
-                            "t": "cuda,cudnn,cublas",
-                            "o": "'worker_process_%p'",
-                            "cuda-graph-trace": "node",
-                        }
-                    }
-                }
-            self._init_workers_ray(placement_group, **additional_ray_args)
-        else:
-            self._init_workers()
-
-        # Profile the memory usage and initialize the cache.
-        self._init_cache()
+        self.model_executor = executor_class(model_config, cache_config,
+                                             parallel_config, scheduler_config,
+                                             device_config, lora_config)
 
         # Create the scheduler.
+        # NOTE: the cache_config here have been updated with the numbers of
+        # GPU and CPU blocks, which are profiled in the distributed executor.
         self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
 
         # Metric Logging.
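
The NOTE added above encodes an ordering contract: by the time executor_class(...) returns, the executor must have profiled memory and filled in cache_config.num_gpu_blocks / num_cpu_blocks, because the Scheduler created right after reads them. A hedged sketch of that sequence, unbundling what the engine does (not verbatim engine code; config tuple order follows the diff):

    from vllm.core.scheduler import Scheduler
    from vllm.engine.arg_utils import EngineArgs
    from vllm.executor.gpu_executor import GPUExecutor

    configs = EngineArgs(model="facebook/opt-125m").create_engine_configs()
    (model_config, cache_config, parallel_config, scheduler_config,
     device_config, lora_config) = configs

    # The executor profiles GPU memory in its constructor and writes the
    # resulting block counts back into cache_config...
    executor = GPUExecutor(model_config, cache_config, parallel_config,
                           scheduler_config, device_config, lora_config)
    assert cache_config.num_gpu_blocks is not None
    # ...which the scheduler then relies on.
    scheduler = Scheduler(scheduler_config, cache_config, lora_config)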
@@ -157,9 +114,29 @@ class LLMEngine:
             labels=dict(model_name=model_config.model))
         self.stat_logger.info("cache_config", self.cache_config)
 
-        self.forward_dag = None
-        if USE_RAY_COMPILED_DAG:
-            self.forward_dag = self._compiled_ray_dag()
+    @classmethod
+    def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
+        """Creates an LLM engine from the engine arguments."""
+        # Create the engine configs.
+        engine_configs = engine_args.create_engine_configs()
+        parallel_config = engine_configs[2]
+
+        # Initialize the cluster and specify the executor class.
+        if parallel_config.worker_use_ray:
+            initialize_ray_cluster(parallel_config)
+            from vllm.executor.ray_gpu_executor import RayGPUExecutor
+            executor_class = RayGPUExecutor
+        else:
+            assert parallel_config.world_size == 1, (
+                "Ray is required if parallel_config.world_size > 1.")
+            from vllm.executor.gpu_executor import GPUExecutor
+            executor_class = GPUExecutor
+
+        # Create the LLM engine.
+        engine = cls(*engine_configs,
+                     executor_class=executor_class,
+                     log_stats=not engine_args.disable_log_stats)
+        return engine
 
     def __reduce__(self):
         # This is to ensure that the LLMEngine is not referenced in
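
A minimal usage sketch of the synchronous counterpart (model name is only an example):

    from vllm import EngineArgs, LLMEngine

    # With tensor_parallel_size=1, worker_use_ray stays False and the
    # single-GPU executor is selected; larger values require Ray.
    engine = LLMEngine.from_engine_args(
        EngineArgs(model="facebook/opt-125m", tensor_parallel_size=1))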
@@ -173,39 +150,6 @@ class LLMEngine:
                                sequence: Sequence) -> "PreTrainedTokenizer":
         return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
 
-    def _dispatch_worker(self):
-        worker_module = DEVICE_TO_WORKER_MODULE_MAP[
-            self.device_config.device_type]
-        imported_worker = importlib.import_module(worker_module)
-        Worker = imported_worker.Worker
-        return Worker
-
-    def _init_workers(self):
-        # Lazy import the Worker to avoid importing torch.cuda/xformers
-        # before CUDA_VISIBLE_DEVICES is set in the Worker
-        Worker = self._dispatch_worker()
-
-        assert self.parallel_config.world_size == 1, (
-            "Ray is required if parallel_config.world_size > 1.")
-
-        self.workers: List[Worker] = []
-        distributed_init_method = get_distributed_init_method(
-            get_ip(), get_open_port())
-        self.driver_worker = Worker(
-            self.model_config,
-            self.parallel_config,
-            self.scheduler_config,
-            self.device_config,
-            local_rank=0,
-            rank=0,
-            distributed_init_method=distributed_init_method,
-            lora_config=self.lora_config,
-            kv_cache_dtype=self.cache_config.cache_dtype,
-            is_driver_worker=True,
-        )
-        self._run_workers("init_model")
-        self._run_workers("load_model")
-
     def _init_tokenizer(self, **tokenizer_init_kwargs):
         init_kwargs = dict(
             enable_lora=bool(self.lora_config),
@@ -218,126 +162,6 @@ class LLMEngine:
         self.tokenizer: TokenizerGroup = TokenizerGroup(
             self.model_config.tokenizer, **init_kwargs)
 
-    def _init_workers_ray(self, placement_group: "PlacementGroup",
-                          **ray_remote_kwargs):
-        if self.parallel_config.tensor_parallel_size == 1:
-            num_gpus = self.cache_config.gpu_memory_utilization
-        else:
-            num_gpus = 1
-
-        self.driver_dummy_worker: RayWorkerVllm = None
-        self.workers: List[RayWorkerVllm] = []
-
-        driver_ip = get_ip()
-        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
-            if not bundle.get("GPU", 0):
-                continue
-            scheduling_strategy = PlacementGroupSchedulingStrategy(
-                placement_group=placement_group,
-                placement_group_capture_child_tasks=True,
-                placement_group_bundle_index=bundle_id,
-            )
-            worker = ray.remote(
-                num_cpus=0,
-                num_gpus=num_gpus,
-                scheduling_strategy=scheduling_strategy,
-                **ray_remote_kwargs,
-            )(RayWorkerVllm).remote(self.model_config.trust_remote_code)
-
-            worker_ip = ray.get(worker.get_node_ip.remote())
-            if worker_ip == driver_ip and self.driver_dummy_worker is None:
-                # If the worker is on the same node as the driver, we use it
-                # as the resource holder for the driver process.
-                self.driver_dummy_worker = worker
-            else:
-                self.workers.append(worker)
-
-        if self.driver_dummy_worker is None:
-            raise ValueError(
-                "Ray does not allocate any GPUs on the driver node. Consider "
-                "adjusting the Ray placement group or running the driver on a "
-                "GPU node.")
-
-        driver_node_id, driver_gpu_ids = ray.get(
-            self.driver_dummy_worker.get_node_and_gpu_ids.remote())
-        worker_node_and_gpu_ids = ray.get(
-            [worker.get_node_and_gpu_ids.remote() for worker in self.workers])
-
-        node_workers = defaultdict(list)
-        node_gpus = defaultdict(list)
-
-        node_workers[driver_node_id].append(0)
-        node_gpus[driver_node_id].extend(driver_gpu_ids)
-        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
-                                               start=1):
-            node_workers[node_id].append(i)
-            node_gpus[node_id].extend(gpu_ids)
-        for node_id, gpu_ids in node_gpus.items():
-            node_gpus[node_id] = sorted(gpu_ids)
-
-        # Set CUDA_VISIBLE_DEVICES for the driver.
-        set_cuda_visible_devices(node_gpus[driver_node_id])
-        for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
-            worker.set_cuda_visible_devices.remote(node_gpus[node_id])
-
-        distributed_init_method = get_distributed_init_method(
-            driver_ip, get_open_port())
-
-        # Lazy import the Worker to avoid importing torch.cuda/xformers
-        # before CUDA_VISIBLE_DEVICES is set in the Worker
-        Worker = self._dispatch_worker()
-
-        # Initialize torch distributed process group for the workers.
-        model_config = copy.deepcopy(self.model_config)
-        parallel_config = copy.deepcopy(self.parallel_config)
-        scheduler_config = copy.deepcopy(self.scheduler_config)
-        device_config = copy.deepcopy(self.device_config)
-        lora_config = copy.deepcopy(self.lora_config)
-        kv_cache_dtype = self.cache_config.cache_dtype
-
-        for rank, (worker, (node_id,
-                            _)) in enumerate(zip(self.workers,
-                                                 worker_node_and_gpu_ids),
-                                             start=1):
-            local_rank = node_workers[node_id].index(rank)
-            worker.init_worker.remote(
-                lambda rank=rank, local_rank=local_rank: Worker(
-                    model_config,
-                    parallel_config,
-                    scheduler_config,
-                    device_config,
-                    local_rank,
-                    rank,
-                    distributed_init_method,
-                    lora_config=lora_config,
-                    kv_cache_dtype=kv_cache_dtype,
-                ))
-
-        driver_rank = 0
-        driver_local_rank = node_workers[driver_node_id].index(driver_rank)
-        self.driver_worker = Worker(
-            self.model_config,
-            self.parallel_config,
-            self.scheduler_config,
-            self.device_config,
-            driver_local_rank,
-            driver_rank,
-            distributed_init_method,
-            lora_config=self.lora_config,
-            kv_cache_dtype=kv_cache_dtype,
-            is_driver_worker=True,
-        )
-
-        # don't use cupy for eager mode
-        self._run_workers("init_model",
-                          cupy_port=get_open_port()
-                          if not model_config.enforce_eager else None)
-        self._run_workers(
-            "load_model",
-            max_concurrent_workers=self.parallel_config.
-            max_parallel_loading_workers,
-        )
-
     def _verify_args(self) -> None:
         self.model_config.verify_with_parallel_config(self.parallel_config)
         self.cache_config.verify_with_parallel_config(self.parallel_config)
@@ -346,81 +170,6 @@ class LLMEngine:
             self.lora_config.verify_with_scheduler_config(
                 self.scheduler_config)
 
-    def _init_cache(self) -> None:
-        """Profiles the memory usage and initializes the KV cache.
-
-        The engine will first conduct a profiling of the existing memory usage.
-        Then, it calculate the maximum possible number of GPU and CPU blocks
-        that can be allocated with the remaining free memory.
-        More details can be found in the
-        :meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method
-        from class :class:`~vllm.worker.Worker`.
-
-        Afterwards, as there may be multiple workers,
-        we take the minimum number of blocks across all workers
-        to ensure this can be applied to all of them.
-
-        Finally, the engine will initialize the KV cache
-        with the calculated number of blocks.
-
-        .. tip::
-            You may limit the usage of GPU memory
-            by adjusting the `gpu_memory_utilization` parameters.
-        """
-        # Get the maximum number of blocks that can be allocated on GPU and CPU.
-        num_blocks = self._run_workers(
-            "profile_num_available_blocks",
-            block_size=self.cache_config.block_size,
-            gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
-            cpu_swap_space=self.cache_config.swap_space_bytes,
-            cache_dtype=self.cache_config.cache_dtype,
-        )
-
-        # Since we use a shared centralized controller, we take the minimum
-        # number of blocks across all workers to make sure all the memory
-        # operators can be applied to all workers.
-        num_gpu_blocks = min(b[0] for b in num_blocks)
-        num_cpu_blocks = min(b[1] for b in num_blocks)
-        # FIXME(woosuk): Change to debug log.
-        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
-                    f"# CPU blocks: {num_cpu_blocks}")
-
-        if num_gpu_blocks <= 0:
-            raise ValueError("No available memory for the cache blocks. "
-                             "Try increasing `gpu_memory_utilization` when "
-                             "initializing the engine.")
-        max_seq_len = self.cache_config.block_size * num_gpu_blocks
-        if self.model_config.max_model_len > max_seq_len:
-            raise ValueError(
-                f"The model's max seq len ({self.model_config.max_model_len}) "
-                "is larger than the maximum number of tokens that can be "
-                f"stored in KV cache ({max_seq_len}). Try increasing "
-                "`gpu_memory_utilization` or decreasing `max_model_len` when "
-                "initializing the engine.")
-
-        self.cache_config.num_gpu_blocks = num_gpu_blocks
-        self.cache_config.num_cpu_blocks = num_cpu_blocks
-
-        # Initialize the cache.
-        self._run_workers("init_cache_engine", cache_config=self.cache_config)
-        # Warm up the model. This includes capturing the model into CUDA graph
-        # if enforce_eager is False.
-        self._run_workers("warm_up_model")
-
-    @classmethod
-    def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
-        """Creates an LLM engine from the engine arguments."""
-        # Create the engine configs.
-        engine_configs = engine_args.create_engine_configs()
-        parallel_config = engine_configs[2]
-        # Initialize the cluster.
-        placement_group = initialize_cluster(parallel_config)
-        # Create the LLM engine.
-        engine = cls(*engine_configs,
-                     placement_group,
-                     log_stats=not engine_args.disable_log_stats)
-        return engine
-
     def encode_request(
         self,
         request_id: str,  # pylint: disable=unused-argument
@@ -826,7 +575,7 @@ class LLMEngine:
             - A Sequence Group (SG) refer to a group of sequences
               that are generated from the same prompt.
 
-        - Step 2: Calls the workers to execute the model.
+        - Step 2: Calls the distributed executor to execute the model.
         - Step 3: Processes the model output. This mainly includes:
 
             - Decodes the relevant outputs.
@@ -862,19 +611,10 @@ class LLMEngine:
         seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
 
         if not scheduler_outputs.is_empty():
-            # Execute the model.
-            all_outputs = self._run_workers(
-                "execute_model",
-                driver_kwargs={
-                    "seq_group_metadata_list": seq_group_metadata_list,
-                    "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
-                    "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
-                    "blocks_to_copy": scheduler_outputs.blocks_to_copy,
-                },
-                use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
-
-            # Only the driver worker returns the sampling results.
-            output = all_outputs[0]
+            output = self.model_executor.execute_model(
+                seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
+                scheduler_outputs.blocks_to_swap_out,
+                scheduler_outputs.blocks_to_copy)
         else:
             output = []
 
@@ -1043,111 +783,13 @@ class LLMEngine:
             seq.output_text = seq.output_text[:-len(stop_string)]
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
-        return self._run_workers(
-            "add_lora",
-            lora_request=lora_request,
-        )
+        return self.model_executor.add_lora(lora_request)
 
     def remove_lora(self, lora_id: int) -> bool:
         assert lora_id > 0, "lora_id must be greater than 0."
-        return self._run_workers(
-            "remove_lora",
-            lora_id=lora_id,
-        )
+        return self.model_executor.remove_lora(lora_id)
 
     def list_loras(self) -> List[int]:
-        return self._run_workers("list_loras")
-
-    def _run_workers(
-        self,
-        method: str,
-        *args,
-        driver_args: Optional[List[Any]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
-        max_concurrent_workers: Optional[int] = None,
-        use_ray_compiled_dag: bool = False,
-        **kwargs,
-    ) -> Any:
-        """Runs the given method on all workers."""
-
-        if max_concurrent_workers:
-            raise NotImplementedError(
-                "max_concurrent_workers is not supported yet.")
-
-        if use_ray_compiled_dag:
-            # Right now, compiled DAG can only accept a single
-            # input. TODO(sang): Fix it.
-            output_channels = self.forward_dag.execute(1)
-        else:
-            # Start the ray workers first.
-            ray_worker_outputs = [
-                worker.execute_method.remote(method, *args, **kwargs)
-                for worker in self.workers
-            ]
-
-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
-
-        # Start the driver worker after all the ray workers.
-        driver_worker_output = getattr(self.driver_worker,
-                                       method)(*driver_args, **driver_kwargs)
-
-        # Get the results of the ray workers.
-        if self.workers:
-            if use_ray_compiled_dag:
-                try:
-                    ray_worker_outputs = [
-                        pickle.loads(chan.begin_read())
-                        for chan in output_channels
-                    ]
-                finally:
-                    # Has to call end_read in order to reuse the DAG.
-                    for chan in output_channels:
-                        chan.end_read()
-            else:
-                ray_worker_outputs = ray.get(ray_worker_outputs)
-
-        return [driver_worker_output] + ray_worker_outputs
-
-    def _compiled_ray_dag(self):
-        import pkg_resources
-        required_version = "2.9"
-        current_version = pkg_resources.get_distribution("ray").version
-        if current_version < required_version:
-            raise ValueError(f"Ray version {required_version} or greater is "
-                             f"required, but found {current_version}")
-
-        from ray.dag import MultiOutputNode, InputNode
-        assert self.parallel_config.worker_use_ray
-
-        # Right now, compiled DAG requires at least 1 arg. We send
-        # a dummy value for now. It will be fixed soon.
-        with InputNode() as input_data:
-            forward_dag = MultiOutputNode([
-                worker.execute_model_compiled_dag_remote.bind(input_data)
-                for worker in self.workers
-            ])
-        return forward_dag.experimental_compile()
+        return self.model_executor.list_loras()
 
     def check_health(self) -> None:
         """Raises an error if engine is unhealthy."""
-        self._check_if_any_actor_is_dead()
-
-    def _check_if_any_actor_is_dead(self):
-        if not self.parallel_config.worker_use_ray:
-            return
-
-        if not self.workers:
-            return
-
-        dead_actors = []
-        for actor in self.workers:
-            actor_state = ray.state.actors(actor._ray_actor_id.hex())  # pylint: disable=protected-access
-            if actor_state["State"] == "DEAD":
-                dead_actors.append(actor)
-        if dead_actors:
-            raise RuntimeError("At least one Worker is dead. "
-                               f"Dead Workers: {dead_actors}. ")
+        self.model_executor.check_health()

vllm/engine/ray_utils.py
@@ -1,6 +1,6 @@
 import pickle
 
-from typing import Optional, List, Tuple, TYPE_CHECKING
+from typing import Optional, List, Tuple
 
 from vllm.config import ParallelConfig
 from vllm.logger import init_logger
@@ -65,45 +65,38 @@ except ImportError as e:
     ray = None
     RayWorkerVllm = None
 
-if TYPE_CHECKING:
-    from ray.util.placement_group import PlacementGroup
-
 
-def initialize_cluster(
+def initialize_ray_cluster(
     parallel_config: ParallelConfig,
-    engine_use_ray: bool = False,
     ray_address: Optional[str] = None,
-) -> Optional["PlacementGroup"]:
-    """Initialize the distributed cluster probably with Ray.
+):
+    """Initialize the distributed cluster with Ray.
 
     it will connect to the Ray cluster and create a placement group
     for the workers, which includes the specification of the resources
    for each distributed worker.
 
     Args:
         parallel_config: The configurations for parallel execution.
-        engine_use_ray: Whether to use Ray for async engine.
         ray_address: The address of the Ray cluster. If None, uses
             the default Ray cluster address.
-
-    Returns:
-        An optional `PlacementGroup`. It includes the specification
-        of the resources for each distributed worker. None if Ray is
-        not used.
     """
-    if parallel_config.worker_use_ray or engine_use_ray:
-        if ray is None:
-            raise ImportError(
-                "Ray is not installed. Please install Ray to use distributed "
-                "serving.")
-        # Connect to a ray cluster.
-        if is_hip():
-            ray.init(address=ray_address,
-                     ignore_reinit_error=True,
-                     num_gpus=parallel_config.world_size)
-        else:
-            ray.init(address=ray_address, ignore_reinit_error=True)
+    if ray is None:
+        raise ImportError(
+            "Ray is not installed. Please install Ray to use distributed "
+            "serving.")
 
-    if not parallel_config.worker_use_ray:
-        assert parallel_config.world_size == 1, (
-            "Ray is required if parallel_config.world_size > 1.")
-        return None
+    # Connect to a ray cluster.
+    if is_hip():
+        ray.init(address=ray_address,
+                 ignore_reinit_error=True,
+                 num_gpus=parallel_config.world_size)
+    else:
+        ray.init(address=ray_address, ignore_reinit_error=True)
+
+    if parallel_config.placement_group:
+        # Placement group is already set.
+        return
 
     # Create placement group for worker processes
     current_placement_group = ray.util.get_current_placement_group()
@@ -138,4 +131,5 @@ def initialize_cluster(
     # if they cannot be provisioned.
     ray.get(current_placement_group.ready(), timeout=1800)
 
-    return current_placement_group
+    # Set the placement group in the parallel config
+    parallel_config.placement_group = current_placement_group
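
For context, a minimal standalone sketch of creating a Ray placement group with one GPU bundle per distributed worker, as this function arranges (the bundle shape and world size of 4 are assumptions for illustration):

    import ray
    from ray.util.placement_group import placement_group

    ray.init()
    # One GPU bundle per distributed worker (world_size = 4, hypothetical).
    pg = placement_group([{"GPU": 1, "CPU": 1}] * 4, strategy="PACK")
    ray.get(pg.ready(), timeout=1800)  # block until all bundles are provisioned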

vllm/executor/__init__.py (new file, 0 lines)

vllm/executor/executor_base.py (new file, 75 lines)
@@ -0,0 +1,75 @@
+from abc import ABC, abstractmethod
+from typing import Dict, List, Optional
+
+from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
+                         ParallelConfig, SchedulerConfig, LoRAConfig)
+from vllm.lora.request import LoRARequest
+from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+
+
+class ExecutorBase(ABC):
+    """Base class for all executors.
+
+    An executor is responsible for executing the model on a specific device
+    type (e.g., CPU, GPU, Neuron, etc.). Or it can be a distributed executor
+    that can execute the model on multiple devices.
+    """
+
+    @abstractmethod
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        cache_config: CacheConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        lora_config: Optional[LoRAConfig],
+    ) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def execute_model(self,
+                      seq_group_metadata_list: List[SequenceGroupMetadata],
+                      blocks_to_swap_in: Dict[int, int],
+                      blocks_to_swap_out: Dict[int, int],
+                      blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
+        """Executes one model step on the given sequences."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        raise NotImplementedError
+
+    @abstractmethod
+    def remove_lora(self, lora_id: int) -> bool:
+        raise NotImplementedError
+
+    @abstractmethod
+    def list_loras(self) -> List[int]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def check_health(self) -> None:
+        """Checks if the executor is healthy. If not, it should raise an
+        exception."""
+        raise NotImplementedError
+
+
+class ExecutorAsyncBase(ExecutorBase):
+
+    @abstractmethod
+    async def execute_model_async(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> SamplerOutput:
+        """Executes one model step on the given sequences."""
+        raise NotImplementedError
+
+    @abstractmethod
+    async def check_health_async(self) -> None:
+        """Checks if the executor is healthy. If not, it should raise an
+        exception."""
+        raise NotImplementedError
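
To illustrate the pluggability this interface buys, here is a hedged sketch of a hypothetical executor that wraps the single-GPU executor and logs each step; LoggingExecutor is invented for illustration and is not part of vLLM:

    from vllm.executor.executor_base import ExecutorBase


    class LoggingExecutor(ExecutorBase):
        """Hypothetical executor that delegates to GPUExecutor and logs steps."""

        def __init__(self, model_config, cache_config, parallel_config,
                     scheduler_config, device_config, lora_config):
            # Delegate the real work to the single-GPU executor.
            from vllm.executor.gpu_executor import GPUExecutor
            self._inner = GPUExecutor(model_config, cache_config,
                                      parallel_config, scheduler_config,
                                      device_config, lora_config)

        def execute_model(self, seq_group_metadata_list, blocks_to_swap_in,
                          blocks_to_swap_out, blocks_to_copy):
            print(f"executing step for {len(seq_group_metadata_list)} groups")
            return self._inner.execute_model(seq_group_metadata_list,
                                             blocks_to_swap_in,
                                             blocks_to_swap_out, blocks_to_copy)

        def add_lora(self, lora_request):
            return self._inner.add_lora(lora_request)

        def remove_lora(self, lora_id):
            return self._inner.remove_lora(lora_id)

        def list_loras(self):
            return self._inner.list_loras()

        def check_health(self):
            self._inner.check_health()

Any such class can be handed to LLMEngine as executor_class, since the engine only calls through the ExecutorBase interface.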

vllm/executor/gpu_executor.py (new file, 163 lines)
@@ -0,0 +1,163 @@
+import importlib
+from typing import Dict, List, Optional
+
+from vllm.lora.request import LoRARequest
+from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
+                         ParallelConfig, SchedulerConfig, LoRAConfig)
+from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
+from vllm.executor.utils import check_block_size_valid
+from vllm.logger import init_logger
+from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.utils import (get_ip, get_open_port, get_distributed_init_method,
+                        make_async)
+
+logger = init_logger(__name__)
+
+# A map between the device type (in device config) to its worker module.
+DEVICE_TO_WORKER_MODULE_MAP = {
+    "cuda": "vllm.worker.worker",
+    "neuron": "vllm.worker.neuron_worker",
+}
+
+
+class GPUExecutor(ExecutorBase):
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        cache_config: CacheConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        lora_config: Optional[LoRAConfig],
+    ) -> None:
+        self.model_config = model_config
+        self.cache_config = cache_config
+        self.lora_config = lora_config
+        self.parallel_config = parallel_config
+        self.scheduler_config = scheduler_config
+        self.device_config = device_config
+
+        # Instantiate the worker and load the model to GPU.
+        self._init_worker()
+
+        # Profile the memory usage and initialize the cache.
+        self._init_cache()
+
+    def _dispatch_worker(self):
+        worker_module = DEVICE_TO_WORKER_MODULE_MAP[
+            self.device_config.device_type]
+        imported_worker = importlib.import_module(worker_module)
+        Worker = imported_worker.Worker
+        return Worker
+
+    def _init_worker(self):
+        # Lazy import the Worker to avoid importing torch.cuda/xformers
+        # before CUDA_VISIBLE_DEVICES is set in the Worker
+        Worker = self._dispatch_worker()
+
+        assert self.parallel_config.world_size == 1, (
+            "GPUExecutor only supports single GPU.")
+
+        distributed_init_method = get_distributed_init_method(
+            get_ip(), get_open_port())
+        self.driver_worker = Worker(
+            self.model_config,
+            self.parallel_config,
+            self.scheduler_config,
+            self.device_config,
+            local_rank=0,
+            rank=0,
+            distributed_init_method=distributed_init_method,
+            lora_config=self.lora_config,
+            kv_cache_dtype=self.cache_config.cache_dtype,
+            is_driver_worker=True,
+        )
+        self.driver_worker.init_model()
+        self.driver_worker.load_model()
+
+    def _init_cache(self) -> None:
+        """Profiles the memory usage and initializes the KV cache.
+
+        The engine first profiles the existing memory usage.
+        Then, it allocates the remaining memory for KV blocks.
+
+        .. tip::
+            You may limit the usage of GPU memory
+            by adjusting the `gpu_memory_utilization` parameter.
+        """
+        # Get the maximum number of blocks that can be allocated on GPU and CPU.
+        num_gpu_blocks, num_cpu_blocks = (
+            self.driver_worker.profile_num_available_blocks(
+                block_size=self.cache_config.block_size,
+                gpu_memory_utilization=self.cache_config.
+                gpu_memory_utilization,
+                cpu_swap_space=self.cache_config.swap_space_bytes,
+                cache_dtype=self.cache_config.cache_dtype,
+            ))
+
+        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
+                    f"# CPU blocks: {num_cpu_blocks}")
+
+        check_block_size_valid(num_gpu_blocks, self.cache_config.block_size,
+                               self.model_config.max_model_len)
+
+        self.cache_config.num_gpu_blocks = num_gpu_blocks
+        self.cache_config.num_cpu_blocks = num_cpu_blocks
+
+        # Initialize the cache.
+        self.driver_worker.init_cache_engine(cache_config=self.cache_config)
+        # Warm up the model. This includes capturing the model into CUDA graph
+        # if enforce_eager is False.
+        self.driver_worker.warm_up_model()
+
+    def execute_model(self,
+                      seq_group_metadata_list: List[SequenceGroupMetadata],
+                      blocks_to_swap_in: Dict[int, int],
+                      blocks_to_swap_out: Dict[int, int],
+                      blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
+        output = self.driver_worker.execute_model(
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=blocks_to_swap_in,
+            blocks_to_swap_out=blocks_to_swap_out,
+            blocks_to_copy=blocks_to_copy,
+        )
+        return output
+
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
+        return self.driver_worker.add_lora(lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        assert lora_id > 0, "lora_id must be greater than 0."
+        return self.driver_worker.remove_lora(lora_id)
+
+    def list_loras(self) -> List[int]:
+        return self.driver_worker.list_loras()
+
+    def check_health(self) -> None:
+        # GPUExecutor will always be healthy as long as
+        # it's running.
+        return
+
+
+class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
+
+    async def execute_model_async(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> SamplerOutput:
+        output = await make_async(self.driver_worker.execute_model)(
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=blocks_to_swap_in,
+            blocks_to_swap_out=blocks_to_swap_out,
+            blocks_to_copy=blocks_to_copy)
+        return output
+
+    async def check_health_async(self) -> None:
+        # GPUExecutor will always be healthy as long as
+        # it's running.
+        return
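
GPUExecutorAsync leans on vllm.utils.make_async to keep the blocking execute_model call off the event loop. A minimal sketch of what such a helper could look like, as an assumption about its shape rather than the verbatim vLLM source:

    import asyncio
    from functools import partial
    from typing import Callable


    def make_async(func: Callable) -> Callable:
        """Wrap a blocking function so awaiting it runs in a thread pool."""

        async def wrapper(*args, **kwargs):
            loop = asyncio.get_event_loop()
            # Run the synchronous call in the default executor (thread pool)
            # so the event loop stays free to serve other coroutines.
            return await loop.run_in_executor(
                None, partial(func, *args, **kwargs))

        return wrapper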

vllm/executor/ray_gpu_executor.py (new file, 442 lines)
@@ -0,0 +1,442 @@
+import asyncio
+import copy
+from collections import defaultdict
+import os
+import pickle
+import importlib
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
+                         ParallelConfig, SchedulerConfig, LoRAConfig)
+from vllm.engine.ray_utils import RayWorkerVllm, ray
+from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
+from vllm.executor.utils import check_block_size_valid
+from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
+from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.utils import (set_cuda_visible_devices, get_ip, get_open_port,
+                        get_distributed_init_method, make_async)
+
+if ray is not None:
+    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+
+if TYPE_CHECKING:
+    from ray.util.placement_group import PlacementGroup
+
+logger = init_logger(__name__)
+
+# A map between the device type (in device config) to its worker module.
+DEVICE_TO_WORKER_MODULE_MAP = {
+    "cuda": "vllm.worker.worker",
+    "neuron": "vllm.worker.neuron_worker",
+}
+
+# If the env var is set, it uses the Ray's compiled DAG API
+# which optimizes the control plane overhead.
+# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
+USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
+
+
+class RayGPUExecutor(ExecutorBase):
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        cache_config: CacheConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        lora_config: Optional[LoRAConfig],
+    ) -> None:
+        self.model_config = model_config
+        self.cache_config = cache_config
+        self.lora_config = lora_config
+        self.parallel_config = parallel_config
+        self.scheduler_config = scheduler_config
+        self.device_config = device_config
+
+        assert self.parallel_config.worker_use_ray
+        placement_group = self.parallel_config.placement_group
+
+        # Disable Ray usage stats collection.
+        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
+        if ray_usage != "1":
+            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
+
+        # Create the parallel GPU workers.
+        self._init_workers_ray(placement_group)
+
+        # Profile the memory usage and initialize the cache.
+        self._init_cache()
+
+        self.forward_dag = None
+        if USE_RAY_COMPILED_DAG:
+            self.forward_dag = self._compiled_ray_dag()
+
+    def _dispatch_worker(self):
+        worker_module = DEVICE_TO_WORKER_MODULE_MAP[
+            self.device_config.device_type]
+        imported_worker = importlib.import_module(worker_module)
+        Worker = imported_worker.Worker
+        return Worker
+
+    def _init_workers_ray(self, placement_group: "PlacementGroup",
+                          **ray_remote_kwargs):
+        if self.parallel_config.tensor_parallel_size == 1:
+            # For single GPU case, we use a ray worker with constrained memory.
+            num_gpus = self.cache_config.gpu_memory_utilization
+        else:
+            # Otherwise, the ray workers are allocated with a full GPU.
+            num_gpus = 1
+
+        # The driver dummy worker does not actually use any resources.
+        # It holds the resource for the driver worker.
+        self.driver_dummy_worker: RayWorkerVllm = None
+        # The remaining workers are the actual ray actors.
+        self.workers: List[RayWorkerVllm] = []
+
+        # Create the workers.
+        driver_ip = get_ip()
+        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
+            if not bundle.get("GPU", 0):
+                continue
+            scheduling_strategy = PlacementGroupSchedulingStrategy(
+                placement_group=placement_group,
+                placement_group_capture_child_tasks=True,
+                placement_group_bundle_index=bundle_id,
+            )
+            worker = ray.remote(
+                num_cpus=0,
+                num_gpus=num_gpus,
+                scheduling_strategy=scheduling_strategy,
+                **ray_remote_kwargs,
+            )(RayWorkerVllm).remote(self.model_config.trust_remote_code)
+
+            worker_ip = ray.get(worker.get_node_ip.remote())
+            if worker_ip == driver_ip and self.driver_dummy_worker is None:
+                # If the worker is on the same node as the driver, we use it
+                # as the resource holder for the driver process.
+                self.driver_dummy_worker = worker
+            else:
+                # Else, added to the list of workers.
+                self.workers.append(worker)
+
+        if self.driver_dummy_worker is None:
+            raise ValueError(
+                "Ray does not allocate any GPUs on the driver node. Consider "
+                "adjusting the Ray placement group or running the driver on a "
+                "GPU node.")
+
+        # Get the set of GPU IDs used on each node.
+        driver_node_id, driver_gpu_ids = ray.get(
+            self.driver_dummy_worker.get_node_and_gpu_ids.remote())
+        worker_node_and_gpu_ids = ray.get(
+            [worker.get_node_and_gpu_ids.remote() for worker in self.workers])
+
+        node_workers = defaultdict(list)
+        node_gpus = defaultdict(list)
+
+        node_workers[driver_node_id].append(0)
+        node_gpus[driver_node_id].extend(driver_gpu_ids)
+        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
+                                               start=1):
+            node_workers[node_id].append(i)
+            node_gpus[node_id].extend(gpu_ids)
+        for node_id, gpu_ids in node_gpus.items():
+            node_gpus[node_id] = sorted(gpu_ids)
+
+        # Set CUDA_VISIBLE_DEVICES for the driver and workers.
+        set_cuda_visible_devices(node_gpus[driver_node_id])
+        for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
+            worker.set_cuda_visible_devices.remote(node_gpus[node_id])
+
+        distributed_init_method = get_distributed_init_method(
+            driver_ip, get_open_port())
+
+        # Lazy import the Worker to avoid importing torch.cuda/xformers
+        # before CUDA_VISIBLE_DEVICES is set in the Worker
+        Worker = self._dispatch_worker()
+
+        model_config = copy.deepcopy(self.model_config)
+        parallel_config = copy.deepcopy(self.parallel_config)
+        scheduler_config = copy.deepcopy(self.scheduler_config)
+        device_config = copy.deepcopy(self.device_config)
+        lora_config = copy.deepcopy(self.lora_config)
+        kv_cache_dtype = self.cache_config.cache_dtype
+
+        # Initialize the actual workers with the Worker class.
+        for rank, (worker, (node_id, _)) in enumerate(
+                zip(self.workers, worker_node_and_gpu_ids),
+                start=1,
+        ):
+            local_rank = node_workers[node_id].index(rank)
+            worker.init_worker.remote(
+                lambda rank=rank, local_rank=local_rank: Worker(
+                    model_config,
+                    parallel_config,
+                    scheduler_config,
+                    device_config,
+                    local_rank,
+                    rank,
+                    distributed_init_method,
+                    lora_config=lora_config,
+                    kv_cache_dtype=kv_cache_dtype,
+                ))
+
+        # Initialize the driver worker with the Worker class.
+        driver_rank = 0
+        driver_local_rank = node_workers[driver_node_id].index(driver_rank)
+        self.driver_worker = Worker(
+            self.model_config,
+            self.parallel_config,
+            self.scheduler_config,
+            self.device_config,
+            driver_local_rank,
+            driver_rank,
+            distributed_init_method,
+            lora_config=self.lora_config,
+            kv_cache_dtype=kv_cache_dtype,
+            is_driver_worker=True,
+        )
+
+        # FIXME(woosuk): We are not properly initializing cupy NCCL when
+        # we have multiple nodes.
+        self._run_workers("init_model",
+                          cupy_port=get_open_port()
+                          if not model_config.enforce_eager else None)
+        self._run_workers(
+            "load_model",
+            max_concurrent_workers=self.parallel_config.
+            max_parallel_loading_workers,
+        )
+
+    def _init_cache(self) -> None:
+        """Profiles the memory usage and initializes the KV cache.
+
+        The engine will first conduct a profiling of the existing memory usage.
+        Then, it calculate the maximum possible number of GPU and CPU blocks
+        that can be allocated with the remaining free memory.
+        More details can be found in the
+        :meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method
+        from class :class:`~vllm.worker.Worker`.
+
+        Afterwards, as there may be multiple workers,
+        we take the minimum number of blocks across all workers
+        to ensure this can be applied to all of them.
+
+        Finally, the engine will initialize the KV cache
+        with the calculated number of blocks.
+
+        .. tip::
+            You may limit the usage of GPU memory
+            by adjusting the `gpu_memory_utilization` parameter.
+        """
+        # Get the maximum number of blocks that can be allocated on GPU and CPU.
+        num_blocks = self._run_workers(
+            "profile_num_available_blocks",
+            block_size=self.cache_config.block_size,
+            gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
+            cpu_swap_space=self.cache_config.swap_space_bytes,
+            cache_dtype=self.cache_config.cache_dtype,
+        )
+
+        # Since we use a shared centralized controller, we take the minimum
+        # number of blocks across all workers to make sure all the memory
+        # operators can be applied to all workers.
+        num_gpu_blocks = min(b[0] for b in num_blocks)
+        num_cpu_blocks = min(b[1] for b in num_blocks)
+        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
+                    f"# CPU blocks: {num_cpu_blocks}")
+
+        check_block_size_valid(num_gpu_blocks, self.cache_config.block_size,
+                               self.model_config.max_model_len)
+
+        self.cache_config.num_gpu_blocks = num_gpu_blocks
+        self.cache_config.num_cpu_blocks = num_cpu_blocks
+
+        # Initialize the cache.
+        self._run_workers("init_cache_engine", cache_config=self.cache_config)
+        # Warm up the model. This includes capturing the model into CUDA graph
+        # if enforce_eager is False.
+        self._run_workers("warm_up_model")
+
+    def execute_model(self,
+                      seq_group_metadata_list: List[SequenceGroupMetadata],
+                      blocks_to_swap_in: Dict[int, int],
+                      blocks_to_swap_out: Dict[int, int],
+                      blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
+        all_outputs = self._run_workers(
+            "execute_model",
+            driver_kwargs={
+                "seq_group_metadata_list": seq_group_metadata_list,
+                "blocks_to_swap_in": blocks_to_swap_in,
+                "blocks_to_swap_out": blocks_to_swap_out,
+                "blocks_to_copy": blocks_to_copy,
+            },
+            use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
+
+        # Only the driver worker returns the sampling results.
+        output = all_outputs[0]
+        return output
+
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
+        return self._run_workers(
+            "add_lora",
+            lora_request=lora_request,
+        )
+
+    def remove_lora(self, lora_id: int) -> bool:
+        assert lora_id > 0, "lora_id must be greater than 0."
+        return self._run_workers(
+            "remove_lora",
+            lora_id=lora_id,
+        )
+
+    def list_loras(self) -> List[int]:
+        return self._run_workers("list_loras")
+
+    def _run_workers(
+        self,
+        method: str,
+        *args,
+        driver_args: Optional[List[Any]] = None,
+        driver_kwargs: Optional[Dict[str, Any]] = None,
+        max_concurrent_workers: Optional[int] = None,
+        use_ray_compiled_dag: bool = False,
+        **kwargs,
+    ) -> Any:
+        """Runs the given method on all workers."""
+
+        if max_concurrent_workers:
+            raise NotImplementedError(
+                "max_concurrent_workers is not supported yet.")
+
+        if use_ray_compiled_dag:
+            # Right now, compiled DAG can only accept a single
+            # input. TODO(sang): Fix it.
+            output_channels = self.forward_dag.execute(1)
+        else:
+            # Start the ray workers first.
+            ray_worker_outputs = [
+                worker.execute_method.remote(method, *args, **kwargs)
+                for worker in self.workers
+            ]
+
+        if driver_args is None:
+            driver_args = args
+        if driver_kwargs is None:
+            driver_kwargs = kwargs
+
+        # Start the driver worker after all the ray workers.
+        driver_worker_output = getattr(self.driver_worker,
+                                       method)(*driver_args, **driver_kwargs)
+
+        # Get the results of the ray workers.
+        if self.workers:
+            if use_ray_compiled_dag:
+                try:
+                    ray_worker_outputs = [
+                        pickle.loads(chan.begin_read())
+                        for chan in output_channels
+                    ]
+                finally:
+                    # Has to call end_read in order to reuse the DAG.
+                    for chan in output_channels:
+                        chan.end_read()
+            else:
+                ray_worker_outputs = ray.get(ray_worker_outputs)
+
+        return [driver_worker_output] + ray_worker_outputs
+
+    def _compiled_ray_dag(self):
+        import pkg_resources
+        required_version = "2.9"
+        current_version = pkg_resources.get_distribution("ray").version
+        if current_version < required_version:
+            raise ValueError(f"Ray version {required_version} or greater is "
+                             f"required, but found {current_version}")
+
+        from ray.dag import MultiOutputNode, InputNode
+        assert self.parallel_config.worker_use_ray
+
+        # Right now, compiled DAG requires at least 1 arg. We send
+        # a dummy value for now. It will be fixed soon.
+        with InputNode() as input_data:
+            forward_dag = MultiOutputNode([
+                worker.execute_model_compiled_dag_remote.bind(input_data)
+                for worker in self.workers
+            ])
+        return forward_dag.experimental_compile()
+
+    def check_health(self) -> None:
+        """Raises an error if engine is unhealthy."""
+        self._check_if_any_actor_is_dead()
+
+    def _check_if_any_actor_is_dead(self):
+        if not self.workers:
+            return
+
+        dead_actors = []
+        for actor in self.workers:
+            actor_state = ray.state.actors(actor._ray_actor_id.hex())  # pylint: disable=protected-access
+            if actor_state["State"] == "DEAD":
+                dead_actors.append(actor)
+        if dead_actors:
+            raise RuntimeError("At least one Worker is dead. "
+                               f"Dead Workers: {dead_actors}. ")
+
+
+class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
+
+    async def _run_workers_async(
+        self,
+        method: str,
+        *args,
+        driver_args: Optional[List[Any]] = None,
+        driver_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ) -> Any:
+        """Runs the given method on all workers."""
+        coros = []
+
+        if driver_args is None:
+            driver_args = args
+        if driver_kwargs is None:
+            driver_kwargs = kwargs
+
+        # Run the driver worker asynchronously.
+        driver_executor = make_async(getattr(self.driver_worker, method))
+        coros.append(driver_executor(*driver_args, **driver_kwargs))
+
+        # Run the ray workers asynchronously.
+        for worker in self.workers:
+            coros.append(worker.execute_method.remote(method, *args, **kwargs))
+
+        all_outputs = await asyncio.gather(*coros)
+        return all_outputs
+
+    async def execute_model_async(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> SamplerOutput:
+        all_outputs = await self._run_workers_async(
+            "execute_model",
+            driver_kwargs={
+                "seq_group_metadata_list": seq_group_metadata_list,
+                "blocks_to_swap_in": blocks_to_swap_in,
+                "blocks_to_swap_out": blocks_to_swap_out,
+                "blocks_to_copy": blocks_to_copy,
+            },
+            use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
+
+        # Only the driver worker returns the sampling results.
+        output = all_outputs[0]
+        return output
+
+    async def check_health_async(self) -> None:
+        """Raises an error if engine is unhealthy."""
+        self._check_if_any_actor_is_dead()
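
The crux of _run_workers above is the driver-plus-workers call pattern: dispatch the remote calls first (non-blocking), run the driver's share in-process, then gather. A stripped-down sketch of the same pattern outside vLLM (all names are illustrative):

    import ray

    @ray.remote
    class EchoWorker:
        def execute_method(self, method: str, *args):
            # Dispatch by name, mirroring RayWorkerVllm.execute_method.
            return getattr(self, method)(*args)

        def ping(self, payload):
            return f"worker saw {payload}"

    workers = [EchoWorker.remote() for _ in range(2)]
    # Start the ray workers first (returns futures immediately)...
    futures = [w.execute_method.remote("ping", "x") for w in workers]
    # ...do the driver's share of the work in-process...
    driver_out = "driver saw x"
    # ...then gather the remote results, driver result first.
    print([driver_out] + ray.get(futures))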

vllm/executor/utils.py (new file, 13 lines)
@@ -0,0 +1,13 @@
+def check_block_size_valid(num_gpu_blocks, block_size, max_model_len) -> None:
+    if num_gpu_blocks <= 0:
+        raise ValueError("No available memory for the cache blocks. "
+                         "Try increasing `gpu_memory_utilization` when "
+                         "initializing the engine.")
+    max_seq_len = block_size * num_gpu_blocks
+    if max_model_len > max_seq_len:
+        raise ValueError(
+            f"The model's max seq len ({max_model_len}) "
+            "is larger than the maximum number of tokens that can be "
+            f"stored in KV cache ({max_seq_len}). Try increasing "
+            "`gpu_memory_utilization` or decreasing `max_model_len` when "
+            "initializing the engine.")
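
A quick worked example of the capacity check (numbers are hypothetical): with block_size = 16 and num_gpu_blocks = 512, the KV cache can hold max_seq_len = 16 * 512 = 8192 tokens, so a model with max_model_len = 4096 passes while one with max_model_len = 16384 raises.

    from vllm.executor.utils import check_block_size_valid

    check_block_size_valid(num_gpu_blocks=512, block_size=16,
                           max_model_len=4096)   # OK: 4096 <= 8192
    # check_block_size_valid(num_gpu_blocks=512, block_size=16,
    #                        max_model_len=16384)  # raises ValueError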