[Ray] Integration compiled DAG off by default (#2471)
parent 931746bc6d
commit 65b89d16ee
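This change wires Ray's experimental compiled DAG API into LLMEngine as an opt-in fast path for the control plane: when the VLLM_USE_RAY_COMPILED_DAG environment variable is set, execute_model requests to the Ray workers run through a DAG compiled once at engine startup, instead of issuing per-step execute_method.remote() calls. The feature is off by default; run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.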
@@ -2,6 +2,7 @@ import copy
 from collections import defaultdict
 import os
 import time
+import pickle
 from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
                     Union)
 
@@ -30,6 +31,11 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)
 _LOCAL_LOGGING_INTERVAL_SEC = 5
 
+# If the env var is set, vLLM uses Ray's compiled DAG API,
+# which reduces the control plane overhead.
+# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
+USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
+
 
 class LLMEngine:
     """An LLM engine that receives requests and generates texts.
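Note that bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)) treats any non-empty string as truthy, so even VLLM_USE_RAY_COMPILED_DAG=0 would enable the flag (bool("0") is True). A minimal sketch of a stricter parse; the _env_flag helper is hypothetical and not part of this commit:

import os

def _env_flag(name: str, default: bool = False) -> bool:
    # Hypothetical helper: bool() on the raw env string is True for any
    # non-empty value, including "0", so compare against an explicit
    # set of truthy spellings instead.
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes", "on")

USE_RAY_COMPILED_DAG = _env_flag("VLLM_USE_RAY_COMPILED_DAG")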
@@ -124,6 +130,10 @@ class LLMEngine:
         self.stat_logger = StatLogger(
             local_interval=_LOCAL_LOGGING_INTERVAL_SEC)
 
+        self.forward_dag = None
+        if USE_RAY_COMPILED_DAG:
+            self.forward_dag = self._compiled_ray_dag()
+
     def get_tokenizer_for_seq(self, sequence: Sequence):
         return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
 
@@ -806,7 +816,8 @@ class LLMEngine:
                 "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
                 "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
                 "blocks_to_copy": scheduler_outputs.blocks_to_copy,
-            })
+            },
+            use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
 
         # Only the driver worker returns the sampling results.
         output = all_outputs[0]
@@ -966,6 +977,7 @@ class LLMEngine:
         driver_args: Optional[List[Any]] = None,
         driver_kwargs: Optional[Dict[str, Any]] = None,
         max_concurrent_workers: Optional[int] = None,
+        use_ray_compiled_dag: bool = False,
         **kwargs,
     ) -> Any:
         """Runs the given method on all workers."""
@@ -974,11 +986,16 @@ class LLMEngine:
             raise NotImplementedError(
                 "max_concurrent_workers is not supported yet.")
 
-        # Start the ray workers first.
-        ray_worker_outputs = [
-            worker.execute_method.remote(method, *args, **kwargs)
-            for worker in self.workers
-        ]
+        if use_ray_compiled_dag:
+            # Right now, compiled DAG can only accept a single
+            # input. TODO(sang): Fix it.
+            output_channels = self.forward_dag.execute(1)
+        else:
+            # Start the ray workers first.
+            ray_worker_outputs = [
+                worker.execute_method.remote(method, *args, **kwargs)
+                for worker in self.workers
+            ]
 
         if driver_args is None:
             driver_args = args
@@ -991,6 +1008,37 @@ class LLMEngine:
 
         # Get the results of the ray workers.
         if self.workers:
-            ray_worker_outputs = ray.get(ray_worker_outputs)
+            if use_ray_compiled_dag:
+                try:
+                    ray_worker_outputs = [
+                        pickle.loads(chan.begin_read())
+                        for chan in output_channels
+                    ]
+                finally:
+                    # Has to call end_read in order to reuse the DAG.
+                    for chan in output_channels:
+                        chan.end_read()
+            else:
+                ray_worker_outputs = ray.get(ray_worker_outputs)
 
         return [driver_worker_output] + ray_worker_outputs
 
+    def _compiled_ray_dag(self):
+        import pkg_resources
+        required_version = "2.9"
+        current_version = pkg_resources.get_distribution("ray").version
+        # Compare parsed versions, not raw strings ("2.10" < "2.9" as strings).
+        if pkg_resources.parse_version(current_version) < pkg_resources.parse_version(required_version):
+            raise ValueError(f"Ray version {required_version} or greater is "
+                             f"required, but found {current_version}")
+
+        from ray.dag import MultiOutputNode, InputNode
+        assert self.parallel_config.worker_use_ray
+
+        # Right now, compiled DAG requires at least 1 arg. We send
+        # a dummy value for now. It will be fixed soon.
+        with InputNode() as input_data:
+            forward_dag = MultiOutputNode([
+                worker.execute_model_compiled_dag_remote.bind(input_data)
+                for worker in self.workers
+            ])
+        return forward_dag.experimental_compile()
 
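Taken together, the driver-side round trip per step looks like the sketch below: execute() pushes one (dummy) input through the compiled DAG, each worker's output comes back over a channel as pickled bytes, and every channel must be closed with end_read() before the DAG can execute again. A minimal sketch, assuming forward_dag was built by _compiled_ray_dag() above:

import pickle

output_channels = forward_dag.execute(1)  # compiled DAG requires >= 1 arg
try:
    outputs = [pickle.loads(chan.begin_read()) for chan in output_channels]
finally:
    # end_read() must run even on error, or the DAG cannot be reused.
    for chan in output_channels:
        chan.end_read()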
@@ -1,3 +1,5 @@
+import pickle
+
 from typing import Optional, List, Tuple, TYPE_CHECKING
 
 from vllm.config import ParallelConfig
@@ -18,6 +20,11 @@ try:
                 from transformers.dynamic_module_utils import init_hf_modules
                 init_hf_modules()
             self.worker = None
+            # The compiled DAG runs its main execution in a
+            # different thread, which must call cuda.set_device.
+            # This flag indicates whether set_device has already
+            # been called on that thread.
+            self.compiled_dag_cuda_device_set = False
 
         def init_worker(self, worker_init_fn):
             self.worker = worker_init_fn()
@@ -40,6 +47,17 @@ try:
         def set_cuda_visible_devices(self, device_ids) -> None:
             set_cuda_visible_devices(device_ids)
 
+        def execute_model_compiled_dag_remote(self, ignored):
+            """Used only when the compiled DAG is enabled."""
+            import torch
+            if not self.compiled_dag_cuda_device_set:
+                torch.cuda.set_device(self.worker.device)
+                self.compiled_dag_cuda_device_set = True
+
+            output = self.worker.execute_model()
+            output = pickle.dumps(output)
+            return output
+
 except ImportError as e:
     logger.warning(f"Failed to import Ray with {e!r}. "
                    "For distributed inference, please install Ray with "
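On the worker side, execute_model_compiled_dag_remote is the single entry point the compiled DAG binds to. It lazily pins the CUDA device because the DAG invokes it from its own execution thread, distinct from the thread that initialized the worker, and it pickles the model output for transport back over the DAG's output channels, matching the pickle.loads(chan.begin_read()) on the driver side.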