Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Ray] Integration compiled DAG off by default (#2471)
commit 65b89d16ee (parent 931746bc6d)
@@ -2,6 +2,7 @@ import copy
 from collections import defaultdict
 import os
 import time
+import pickle
 from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
                     Union)

@@ -30,6 +31,11 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)

 _LOCAL_LOGGING_INTERVAL_SEC = 5

+# If the env var is set, it uses the Ray's compiled DAG API
+# which optimizes the control plane overhead.
+# Run VLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
+USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
+

 class LLMEngine:
     """An LLM engine that receives requests and generates texts.
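For reference, a minimal sketch of how this opt-in flag would be exercised from user code. The environment variable has to be set before vllm is imported, because USE_RAY_COMPILED_DAG is evaluated at module import time; the model name, prompt, and tensor_parallel_size below are illustrative and not part of this change.

# Sketch (not part of the diff): enabling the compiled-DAG control plane.
import os

# Must be set before importing vllm, since the flag is read at import time.
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"

from vllm import LLM

# The compiled DAG path only matters when Ray workers are used, e.g. with
# tensor parallelism across more than one GPU.
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2)
outputs = llm.generate(["Hello, my name is"])
print(outputs[0].outputs[0].text)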
@@ -124,6 +130,10 @@ class LLMEngine:
             self.stat_logger = StatLogger(
                 local_interval=_LOCAL_LOGGING_INTERVAL_SEC)

+        self.forward_dag = None
+        if USE_RAY_COMPILED_DAG:
+            self.forward_dag = self._compiled_ray_dag()
+
     def get_tokenizer_for_seq(self, sequence: Sequence):
         return self.tokenizer.get_lora_tokenizer(sequence.lora_request)

@@ -806,7 +816,8 @@ class LLMEngine:
                 "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
                 "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
                 "blocks_to_copy": scheduler_outputs.blocks_to_copy,
-            })
+            },
+            use_ray_compiled_dag=USE_RAY_COMPILED_DAG)

         # Only the driver worker returns the sampling results.
         output = all_outputs[0]
@@ -966,6 +977,7 @@ class LLMEngine:
         driver_args: Optional[List[Any]] = None,
         driver_kwargs: Optional[Dict[str, Any]] = None,
         max_concurrent_workers: Optional[int] = None,
+        use_ray_compiled_dag: bool = False,
         **kwargs,
     ) -> Any:
         """Runs the given method on all workers."""
@@ -974,6 +986,11 @@ class LLMEngine:
             raise NotImplementedError(
                 "max_concurrent_workers is not supported yet.")

+        if use_ray_compiled_dag:
+            # Right now, compiled DAG can only accept a single
+            # input. TODO(sang): Fix it.
+            output_channels = self.forward_dag.execute(1)
+        else:
             # Start the ray workers first.
             ray_worker_outputs = [
                 worker.execute_method.remote(method, *args, **kwargs)
@@ -991,6 +1008,37 @@ class LLMEngine:

         # Get the results of the ray workers.
         if self.workers:
+            if use_ray_compiled_dag:
+                try:
+                    ray_worker_outputs = [
+                        pickle.loads(chan.begin_read())
+                        for chan in output_channels
+                    ]
+                finally:
+                    # Has to call end_read in order to reuse the DAG.
+                    for chan in output_channels:
+                        chan.end_read()
+            else:
                 ray_worker_outputs = ray.get(ray_worker_outputs)

         return [driver_worker_output] + ray_worker_outputs

+    def _compiled_ray_dag(self):
+        import pkg_resources
+        required_version = "2.9"
+        current_version = pkg_resources.get_distribution("ray").version
+        if current_version < required_version:
+            raise ValueError(f"Ray version {required_version} or greater is "
+                             f"required, but found {current_version}")
+
+        from ray.dag import MultiOutputNode, InputNode
+        assert self.parallel_config.worker_use_ray
+
+        # Right now, compiled DAG requires at least 1 arg. We send
+        # a dummy value for now. It will be fixed soon.
+        with InputNode() as input_data:
+            forward_dag = MultiOutputNode([
+                worker.execute_model_compiled_dag_remote.bind(input_data)
+                for worker in self.workers
+            ])
+        return forward_dag.experimental_compile()
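For context, a self-contained sketch of the Ray compiled-DAG pattern that _compiled_ray_dag() and the channel reads above rely on, assuming Ray >= 2.9 and its experimental ray.dag API as used in this diff. The EchoActor class is illustrative and not part of vLLM.

# Sketch (not part of the diff): the compiled-DAG build / execute / read cycle.
import pickle

import ray
from ray.dag import InputNode, MultiOutputNode


@ray.remote
class EchoActor:
    def echo(self, value):
        # Return pickled bytes, mirroring how the vLLM workers send their
        # model outputs over the DAG channels.
        return pickle.dumps(value)


actors = [EchoActor.remote() for _ in range(2)]

# Bind each actor's method to the single DAG input and compile once;
# the compiled DAG is then reused for every subsequent call.
with InputNode() as dag_input:
    dag = MultiOutputNode([a.echo.bind(dag_input) for a in actors])
compiled_dag = dag.experimental_compile()

# execute() returns one output channel per actor; begin_read()/end_read()
# must be paired so the channels can be reused on the next call.
channels = compiled_dag.execute(42)
try:
    results = [pickle.loads(chan.begin_read()) for chan in channels]
finally:
    for chan in channels:
        chan.end_read()
print(results)  # -> [42, 42]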
The remaining hunks apply to the Ray worker wrapper used for distributed execution:

@@ -1,3 +1,5 @@
+import pickle
+
 from typing import Optional, List, Tuple, TYPE_CHECKING

 from vllm.config import ParallelConfig
@@ -18,6 +20,11 @@ try:
                 from transformers.dynamic_module_utils import init_hf_modules
                 init_hf_modules()
             self.worker = None
+            # Since the compiled DAG runs a main execution
+            # in a different thread that calls cuda.set_device.
+            # The flag indicates is set_device is called on
+            # that thread.
+            self.compiled_dag_cuda_device_set = False

         def init_worker(self, worker_init_fn):
             self.worker = worker_init_fn()
@@ -40,6 +47,17 @@ try:
         def set_cuda_visible_devices(self, device_ids) -> None:
             set_cuda_visible_devices(device_ids)

+        def execute_model_compiled_dag_remote(self, ignored):
+            """Used only when compiled DAG is enabled."""
+            import torch
+            if not self.compiled_dag_cuda_device_set:
+                torch.cuda.set_device(self.worker.device)
+                self.compiled_dag_cuda_device_set = True
+
+            output = self.worker.execute_model()
+            output = pickle.dumps(output)
+            return output
+
 except ImportError as e:
     logger.warning(f"Failed to import Ray with {e!r}. "
                    "For distributed inference, please install Ray with "
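Background for the compiled_dag_cuda_device_set flag added above: the compiled DAG invokes execute_model_compiled_dag_remote on its own execution thread, and torch's current CUDA device is tracked per thread, so the worker has to re-bind its device the first time that thread runs. A small sketch of the per-thread behaviour (illustrative only, not vLLM code; assumes a host with at least two GPUs):

# Sketch (not part of the diff): CUDA device selection is per thread.
import threading

import torch


def report_device(tag: str) -> None:
    # torch.cuda.current_device() reflects the *calling thread's* selection.
    print(tag, torch.cuda.current_device())


if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    torch.cuda.set_device(1)            # main thread now targets GPU 1
    report_device("main thread:")        # -> 1
    t = threading.Thread(target=report_device, args=("new thread:",))
    t.start()                            # a fresh thread still defaults to GPU 0
    t.join()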