Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2026-04-13 09:07:03 +08:00
[Core] Shut down aDAG workers with clean async llm engine exit (#7224)
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
This commit is contained in:
parent 774cd1d3bf
commit 198d6a2898
@@ -34,9 +34,6 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
         pytest.skip("Skipping multi-node pipeline parallel test for "
                     "multiprocessing distributed backend")
 
-    USE_RAY_ADAG_NCCL = 0
-    USE_RAY_ADAG = 0
-
     pp_args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
@@ -70,14 +67,13 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
         pp_args.append("--enforce-eager")
         tp_args.append("--enforce-eager")
     pp_env = None
-    if USE_RAY_ADAG:
-        assert DIST_BACKEND == "ray", (
-            "Ray ADAG is only supported with Ray distributed backend")
+    if (DIST_BACKEND == "ray" and TP_SIZE == 2 and PP_SIZE == 2
+            and CHUNKED_PREFILL):
+        # Test Ray ADAG for a subset of the tests
         pp_env = {
             "VLLM_USE_RAY_COMPILED_DAG": "1",
             "VLLM_USE_RAY_SPMD_WORKER": "1",
-            "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL":
-            str(int(USE_RAY_ADAG_NCCL)),
+            "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
         }
 
     compare_two_settings(MODEL_NAME, pp_args, tp_args, pp_env)
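Note: the three environment variables set in pp_env above are what opt the pipeline-parallel run into Ray's compiled DAG (aDAG) path with NCCL channels. A minimal sketch of applying such an env when launching a server process for comparison; the helper name and CLI invocation below are illustrative, not the test's actual fixtures:

import os
import subprocess

ADAG_ENV = {
    "VLLM_USE_RAY_COMPILED_DAG": "1",
    "VLLM_USE_RAY_SPMD_WORKER": "1",
    "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
}

def launch_server(cli_args, extra_env=None):
    # The variables must be present in the child process before the engine is
    # constructed, since they are read when the Ray executor is set up.
    env = dict(os.environ)
    env.update(extra_env or {})
    return subprocess.Popen(
        ["python", "-m", "vllm.entrypoints.openai.api_server", *cli_args],
        env=env,
    )

# e.g. launch_server(["--model", MODEL_NAME, "--pipeline-parallel-size", "2"],
#                    extra_env=ADAG_ENV)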
@@ -661,6 +661,20 @@ class AsyncLLMEngine:
             partial(_log_task_completion, error_callback=self._error_callback))
         self.background_loop = asyncio.shield(self._background_loop_unshielded)
 
+    def shutdown_background_loop(self) -> None:
+        """
+        Shut down the background loop.
+
+        This method needs to be called during cleanup to remove
+        references to `self` and properly GC the resources held
+        by the async LLM engine (e.g., the executors as well as
+        their resources).
+        """
+        if self._background_loop_unshielded is not None:
+            self._background_loop_unshielded.cancel()
+            self._background_loop_unshielded = None
+        self.background_loop = None
+
     def _init_engine(self, *args,
                      **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
         if not self.engine_use_ray:
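A minimal usage sketch of the new method (the model name and request details below are placeholders): the shielded background task keeps a reference to the engine, so cleanup code has to break that reference explicitly before the engine, its executor, and the Ray workers can be garbage collected.

import asyncio

from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams

async def main():
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))
    try:
        async for _ in engine.generate("Hello",
                                       SamplingParams(max_tokens=8),
                                       request_id="req-0"):
            pass
    finally:
        # Cancel the background task and drop the references it holds so the
        # engine (and the resources behind it) can actually be GC'd.
        engine.shutdown_background_loop()

asyncio.run(main())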
@@ -245,9 +245,18 @@ class LLMEngine:
         if not self.model_config.skip_tokenizer_init:
             self.tokenizer = self._init_tokenizer()
             self.detokenizer = Detokenizer(self.tokenizer)
+            tokenizer_group = self.get_tokenizer_group()
         else:
             self.tokenizer = None
             self.detokenizer = None
+            tokenizer_group = None
+
+        # Ensure that the function doesn't contain a reference to self,
+        # to avoid engine GC issues
+        def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
+            assert tokenizer_group, ("tokenizer_group cannot be None, "
+                                     "make sure skip_tokenizer_init is False")
+            return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
 
         self.seq_counter = Counter()
         self.generation_config_fields = _load_generation_config_dict(
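The comment above is the heart of the change: a bound method such as self.get_tokenizer_for_seq keeps the whole LLMEngine reachable from whatever stores the callback, whereas a plain closure captures only the tokenizer group it needs. An illustrative sketch with toy classes (not vLLM code) showing the difference:

import weakref

class TokenizerGroup:
    def get_lora_tokenizer(self, lora_request):
        return "tokenizer"

class Engine:
    def __init__(self):
        self.tokenizer_group = TokenizerGroup()

    def get_tokenizer_for_seq(self, seq):
        # Bound method: holds a strong reference to `self` (the engine).
        return self.tokenizer_group.get_lora_tokenizer(None)

    def make_tokenizer_callback(self):
        tokenizer_group = self.tokenizer_group
        # Closure: captures only `tokenizer_group`, not the engine.
        def get_tokenizer_for_seq(seq):
            return tokenizer_group.get_lora_tokenizer(None)
        return get_tokenizer_for_seq

engine = Engine()
ref = weakref.ref(engine)
callback = engine.make_tokenizer_callback()   # engine stays collectable
# callback = engine.get_tokenizer_for_seq     # this would pin the engine alive
del engine
assert ref() is None  # the engine was garbage collected despite the callback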
@@ -356,10 +365,10 @@ class LLMEngine:
                 self.detokenizer,
                 self.scheduler,
                 self.seq_counter,
-                self.get_tokenizer_for_seq,
+                get_tokenizer_for_seq,
                 stop_checker=StopChecker(
                     self.scheduler_config.max_model_len,
-                    self.get_tokenizer_for_seq,
+                    get_tokenizer_for_seq,
                 ),
             ))
@@ -491,10 +500,6 @@ class LLMEngine:
     ) -> AnyTokenizer:
         return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
 
-    def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
-        return self.get_tokenizer_group().get_lora_tokenizer(
-            sequence.lora_request)
-
     def _init_tokenizer(self) -> BaseTokenizerGroup:
         return init_tokenizer_from_configs(
             model_config=self.model_config,
@@ -36,6 +36,7 @@ class AsyncEngineRPCServer:
         """Cleanup all resources."""
         self.socket.close()
         self.context.destroy()
+        self.engine.shutdown_background_loop()
 
     async def get_model_config(self, identity):
         """Send the ModelConfig"""
@@ -60,6 +60,14 @@ class RayGPUExecutor(DistributedGPUExecutor):
         # Create the parallel GPU workers.
         self._init_workers_ray(placement_group)
 
+    def shutdown(self) -> None:
+        if hasattr(self, "forward_dag") and self.forward_dag is not None:
+            self.forward_dag.teardown()
+            import ray
+            for worker in self.workers:
+                ray.kill(worker)
+            self.forward_dag = None
+
     def _configure_ray_workers_use_nsight(self,
                                           ray_remote_kwargs) -> Dict[str, Any]:
         # If nsight profiling is enabled, we need to set the profiling
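A note on the guard at the top of shutdown() above: the method is also reached from __del__ (updated further down), where __init__ may have failed before forward_dag was ever assigned and where a second call must be a no-op. A toy sketch of that idempotent-shutdown pattern, with illustrative names:

class Executor:
    def __init__(self):
        self.workers = ["worker-0", "worker-1"]
        self.forward_dag = object()  # stands in for the compiled DAG

    def shutdown(self) -> None:
        # Safe to call repeatedly, and safe from __del__: if __init__ raised
        # before forward_dag was assigned, hasattr() is False and we return.
        if hasattr(self, "forward_dag") and self.forward_dag is not None:
            # tear down the DAG and release the workers here
            self.workers.clear()
            self.forward_dag = None

    def __del__(self):
        # Finalization is best-effort; callers that need deterministic cleanup
        # invoke shutdown() explicitly instead of relying on GC order.
        self.shutdown()

ex = Executor()
ex.shutdown()   # explicit, deterministic teardown
ex.shutdown()   # second call is a no-op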
@@ -117,7 +125,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
         logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
         # Create the workers.
         driver_ip = get_ip()
-        logger.info("driver_ip: %s", driver_ip)
         worker_wrapper_kwargs = self._get_worker_wrapper_args()
         for bundle_id, bundle in enumerate(placement_group.bundle_specs):
             if not bundle.get("GPU", 0):
@@ -446,11 +453,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
         return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
 
     def __del__(self):
-        if self.forward_dag is not None:
-            self.forward_dag.teardown()
-            import ray
-            for worker in self.workers:
-                ray.kill(worker)
+        self.shutdown()
 
 
 class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
@@ -523,8 +526,4 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
         return await asyncio.gather(*coros)
 
     def __del__(self):
-        if self.forward_dag is not None:
-            self.forward_dag.teardown()
-            import ray
-            for worker in self.workers:
-                ray.kill(worker)
+        self.shutdown()