[Core] Shut down aDAG workers with clean async llm engine exit (#7224)

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
This commit is contained in:
Rui Qiao 2024-08-12 17:57:16 -07:00 committed by GitHub
parent 774cd1d3bf
commit 198d6a2898
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 40 additions and 25 deletions

View File

@ -34,9 +34,6 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
pytest.skip("Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend")
USE_RAY_ADAG_NCCL = 0
USE_RAY_ADAG = 0
pp_args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
@ -70,14 +67,13 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
pp_args.append("--enforce-eager")
tp_args.append("--enforce-eager")
pp_env = None
if USE_RAY_ADAG:
assert DIST_BACKEND == "ray", (
"Ray ADAG is only supported with Ray distributed backend")
if (DIST_BACKEND == "ray" and TP_SIZE == 2 and PP_SIZE == 2
and CHUNKED_PREFILL):
# Test Ray ADAG for a subset of the tests
pp_env = {
"VLLM_USE_RAY_COMPILED_DAG": "1",
"VLLM_USE_RAY_SPMD_WORKER": "1",
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL":
str(int(USE_RAY_ADAG_NCCL)),
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
}
compare_two_settings(MODEL_NAME, pp_args, tp_args, pp_env)

View File

@ -661,6 +661,20 @@ class AsyncLLMEngine:
partial(_log_task_completion, error_callback=self._error_callback))
self.background_loop = asyncio.shield(self._background_loop_unshielded)
def shutdown_background_loop(self) -> None:
"""
Shut down the background loop.
This method needs to be called during cleanup to remove
references to `self` and properly GC the resources held
by the async LLM engine (e.g., the executors as well as
their resources).
"""
if self._background_loop_unshielded is not None:
self._background_loop_unshielded.cancel()
self._background_loop_unshielded = None
self.background_loop = None
def _init_engine(self, *args,
**kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
if not self.engine_use_ray:

View File

@ -245,9 +245,18 @@ class LLMEngine:
if not self.model_config.skip_tokenizer_init:
self.tokenizer = self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
tokenizer_group = self.get_tokenizer_group()
else:
self.tokenizer = None
self.detokenizer = None
tokenizer_group = None
# Ensure that the function doesn't contain a reference to self,
# to avoid engine GC issues
def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
assert tokenizer_group, ("tokenizer_group cannot be None, "
"make sure skip_tokenizer_init is False")
return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
self.seq_counter = Counter()
self.generation_config_fields = _load_generation_config_dict(
@ -356,10 +365,10 @@ class LLMEngine:
self.detokenizer,
self.scheduler,
self.seq_counter,
self.get_tokenizer_for_seq,
get_tokenizer_for_seq,
stop_checker=StopChecker(
self.scheduler_config.max_model_len,
self.get_tokenizer_for_seq,
get_tokenizer_for_seq,
),
))
@ -491,10 +500,6 @@ class LLMEngine:
) -> AnyTokenizer:
return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
return self.get_tokenizer_group().get_lora_tokenizer(
sequence.lora_request)
def _init_tokenizer(self) -> BaseTokenizerGroup:
return init_tokenizer_from_configs(
model_config=self.model_config,

View File

@ -36,6 +36,7 @@ class AsyncEngineRPCServer:
"""Cleanup all resources."""
self.socket.close()
self.context.destroy()
self.engine.shutdown_background_loop()
async def get_model_config(self, identity):
"""Send the ModelConfig"""

View File

@ -60,6 +60,14 @@ class RayGPUExecutor(DistributedGPUExecutor):
# Create the parallel GPU workers.
self._init_workers_ray(placement_group)
def shutdown(self) -> None:
if hasattr(self, "forward_dag") and self.forward_dag is not None:
self.forward_dag.teardown()
import ray
for worker in self.workers:
ray.kill(worker)
self.forward_dag = None
def _configure_ray_workers_use_nsight(self,
ray_remote_kwargs) -> Dict[str, Any]:
# If nsight profiling is enabled, we need to set the profiling
@ -117,7 +125,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
# Create the workers.
driver_ip = get_ip()
logger.info("driver_ip: %s", driver_ip)
worker_wrapper_kwargs = self._get_worker_wrapper_args()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("GPU", 0):
@ -446,11 +453,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
def __del__(self):
if self.forward_dag is not None:
self.forward_dag.teardown()
import ray
for worker in self.workers:
ray.kill(worker)
self.shutdown()
class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
@ -523,8 +526,4 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
return await asyncio.gather(*coros)
def __del__(self):
if self.forward_dag is not None:
self.forward_dag.teardown()
import ray
for worker in self.workers:
ray.kill(worker)
self.shutdown()