diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index ce62282c2199..6d82714f3cc8 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -687,19 +687,30 @@ class AsyncMicrobatchTokenizer:
         max_length = kwargs.get("max_length")

         if not truncation:
-            return ("encode", add_special_tokens, False, None)
+            return "encode", add_special_tokens, False, None

         model_max = getattr(self.tokenizer, "model_max_length", None)
         if max_length is None or (model_max is not None
                                   and max_length == model_max):
-            return ("encode", add_special_tokens, True, "model_max")
+            return "encode", add_special_tokens, True, "model_max"

-        return ("encode", "other")
+        return "encode", "other"

     def __del__(self):
-        for task in self._batcher_tasks:
-            if not task.done():
-                task.cancel()
+        if ((tasks := getattr(self, "_batcher_tasks", None))
+                and (loop := getattr(self, "_loop", None))
+                and not loop.is_closed()):
+
+            def cancel_tasks():
+                for task in tasks:
+                    task.cancel()
+
+            loop.call_soon_threadsafe(cancel_tasks)
+
+
+def cancel_task_threadsafe(task: Task):
+    if task and not task.done() and not (loop := task.get_loop()).is_closed():
+        loop.call_soon_threadsafe(task.cancel)


 def make_async(
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 7b4ed90fd132..a2706327914c 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -27,7 +27,7 @@ from vllm.transformers_utils.config import (
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import Device, cdiv, deprecate_kwargs
+from vllm.utils import Device, cancel_task_threadsafe, cdiv, deprecate_kwargs
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
@@ -219,8 +219,7 @@ class AsyncLLM(EngineClient):
         if engine_core := getattr(self, "engine_core", None):
             engine_core.shutdown()

-        if handler := getattr(self, "output_handler", None):
-            handler.cancel()
+        cancel_task_threadsafe(getattr(self, "output_handler", None))

     async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         return await self.engine_core.get_supported_tasks_async()
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index 4d30bb6b7446..05b4d7260896 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -23,7 +23,8 @@ from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.tasks import SupportedTask
-from vllm.utils import get_open_port, get_open_zmq_inproc_path, make_zmq_socket
+from vllm.utils import (cancel_task_threadsafe, get_open_port,
+                        get_open_zmq_inproc_path, make_zmq_socket)
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                             EngineCoreRequestType,
                             ReconfigureDistributedRequest, ReconfigureRankType,
@@ -342,10 +343,8 @@ class BackgroundResources:
         if self.coordinator is not None:
             self.coordinator.close()

-        if self.output_queue_task is not None:
-            self.output_queue_task.cancel()
-        if self.stats_update_task is not None:
-            self.stats_update_task.cancel()
+        cancel_task_threadsafe(self.output_queue_task)
+        cancel_task_threadsafe(self.stats_update_task)

         # ZMQ context termination can hang if the sockets
         # aren't explicitly closed first.
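
For context, the new cancel_task_threadsafe helper centralizes one pattern: asyncio tasks are not thread-safe, so code running on another thread (for example a __del__ finalizer invoked during garbage collection or interpreter shutdown) should not call task.cancel() directly, and should skip loops that are already closed. Instead it schedules the cancel on the task's own event loop via loop.call_soon_threadsafe. Below is a minimal, self-contained sketch of that pattern outside vLLM; the names slow_job and run_loop are hypothetical, and only cancel_task_threadsafe mirrors the helper added in the diff.

    # Sketch only: demonstrates cancelling a task owned by a loop running on a
    # different thread, using the same shape as the helper in the diff above.
    import asyncio
    import threading
    import time
    from asyncio import Task


    def cancel_task_threadsafe(task: Task):
        # Skip tasks that already finished and loops that are already closed,
        # then hand the cancel to the loop's own thread.
        if task and not task.done() and not (loop := task.get_loop()).is_closed():
            loop.call_soon_threadsafe(task.cancel)


    async def slow_job():
        try:
            await asyncio.sleep(3600)
        except asyncio.CancelledError:
            print("cancelled on the loop's own thread")
            raise


    loop = asyncio.new_event_loop()
    task = loop.create_task(slow_job())


    def run_loop():
        # Drive the loop until the task finishes; the cancellation is expected.
        try:
            loop.run_until_complete(task)
        except asyncio.CancelledError:
            pass


    # Run the event loop in a background thread, as a server typically would.
    thread = threading.Thread(target=run_loop)
    thread.start()
    time.sleep(0.5)  # give the task a chance to start

    # From the main thread: schedule the cancel on the owning loop rather than
    # calling task.cancel() here, which would not be thread-safe.
    cancel_task_threadsafe(task)
    thread.join()
    loop.close()

The loop.is_closed() guard matters because call_soon_threadsafe raises RuntimeError on a closed loop, which is exactly the situation a late-running finalizer can hit.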