From ad0297d1139f55cbd602652afd54276a8ae217ce Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Fri, 15 Aug 2025 17:00:36 -0700
Subject: [PATCH] [Misc] Support passing multiple request ids at once to
 `AsyncLLM.abort()` (#22944)

Signed-off-by: Nick Hill
---
 tests/v1/engine/test_async_llm.py     | 77 ++++++++++++++++++++++++++-
 vllm/engine/async_llm_engine.py       |  5 +-
 vllm/engine/multiprocessing/client.py | 10 ++--
 vllm/engine/protocol.py               |  7 +--
 vllm/utils/__init__.py                |  5 ++
 vllm/v1/engine/async_llm.py           | 15 +++---
 6 files changed, 105 insertions(+), 14 deletions(-)

diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py
index 484640233f52..df04a14af70c 100644
--- a/tests/v1/engine/test_async_llm.py
+++ b/tests/v1/engine/test_async_llm.py
@@ -212,6 +212,79 @@ async def test_abort(
     assert not engine.output_processor.has_unfinished_requests()
 
 
+@pytest.mark.parametrize(
+    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
+@pytest.mark.asyncio
+async def test_multi_abort(
+    monkeypatch: pytest.MonkeyPatch,
+    output_kind: RequestOutputKind,
+):
+
+    with monkeypatch.context() as m, ExitStack() as after:
+        m.setenv("VLLM_USE_V1", "1")
+
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
+        after.callback(engine.shutdown)
+
+        NUM_REQUESTS = 50
+        NUM_EXPECTED_TOKENS = 100
+        NUM_EXPECTED_TOKENS_LONG = 50000
+        REQUEST_IDS_TO_ABORT = [5, 10, 15, 20, 25]
+        PARALLEL_SAMPLE_REQ_IDS = [5, 15, 30, 35]
+
+        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]
+
+        # Create concurrent requests.
+        tasks: list[asyncio.Task] = []
+        for idx, request_id in enumerate(request_ids):
+            max_tokens = (NUM_EXPECTED_TOKENS_LONG if
+                          (idx
+                           in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS)
+            n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
+            tasks.append(
+                asyncio.create_task(
+                    generate(engine, request_id, TEXT_PROMPT, output_kind,
+                             max_tokens, n)))
+
+        # Let requests start
+        await asyncio.sleep(0.5)
+
+        # Use multi-abort to abort multiple requests at once
+        abort_request_ids = [request_ids[i] for i in REQUEST_IDS_TO_ABORT]
+        await engine.abort(abort_request_ids)
+
+        # Wait for all tasks to complete
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+
+        # Verify results
+        for idx, result in enumerate(results):
+            if idx in REQUEST_IDS_TO_ABORT:
+                # Aborted requests should return partial results
+                assert isinstance(
+                    result, tuple
+                ), f"Request {idx} should have completed with partial results"
+                num_generated_tokens, request_id = result
+                # Should have generated some tokens before abort
+                assert num_generated_tokens > 0, (
+                    f"Aborted request "
+                    f"{request_id} should have generated some tokens")
+            else:
+                # Non-aborted requests should complete normally
+                assert isinstance(
+                    result,
+                    tuple), f"Request {idx} should have completed successfully"
+                num_generated_tokens, request_id = result
+                n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
+                expected_tokens = NUM_EXPECTED_TOKENS * n
+                assert num_generated_tokens == expected_tokens, (
+                    f"{request_id} generated {num_generated_tokens} but "
+                    f"expected {expected_tokens}")
+
+        # Make sure all aborted requests were cleaned up
+        assert not engine.output_processor.has_unfinished_requests()
+
+
 @pytest.mark.parametrize("n", [1, 3])
 @pytest.mark.parametrize(
     "engine_args,prompt",
@@ -460,7 +533,9 @@ async def test_abort_final_output(
         token_count = sum(
             len(output.outputs[0].token_ids) for output in outputs)
         assert token_count > 0
-        assert len(final_output.outputs[0].token_ids) == 0
+        # This would ordinarily be 0, but could end up > 0 if the
+        # final abort is coalesced with another chunk in the output queue.
+        assert len(final_output.outputs[0].token_ids) >= 0
     else:
         # For FINAL_ONLY, we should only get the final output
         assert len(outputs) == 0
 
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 73726eeab5fc..84ad2299b065 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -998,7 +998,7 @@ class AsyncLLMEngine(EngineClient):
             await self.abort(request_id)
             raise
 
-    async def abort(self, request_id: str) -> None:
+    async def abort(self, request_id: Union[str, Iterable[str]]) -> None:
         """Abort a request.
 
         Abort a submitted request. If the request is finished or not found,
@@ -1007,6 +1007,9 @@ class AsyncLLMEngine(EngineClient):
         Args:
             request_id: The unique id of the request.
         """
+        if not isinstance(request_id, str):
+            raise RuntimeError("Only single-request abort supported in"
+                               " deprecated V0")
         if not self.is_running:
             raise AsyncEngineDeadError(
                 "Background loop is not running. If it was running, "
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index f69f72edf6a5..eca29af50055 100644
--- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
@@ -5,8 +5,8 @@ import asyncio
 import copy
 import pickle
 from contextlib import contextmanager, suppress
-from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
-                    Optional, Union, cast)
+from typing import (Any, AsyncGenerator, Dict, Iterable, Iterator, List,
+                    Mapping, Optional, Union, cast)
 
 import cloudpickle
 import psutil
@@ -404,9 +404,13 @@ class MQLLMEngineClient(EngineClient):
             error_message="Unable to start RPC Server",
             socket=socket)
 
-    async def abort(self, request_id: str):
+    async def abort(self, request_id: Union[str, Iterable[str]]):
         """Send an ABORT_REQUEST signal to the RPC Server"""
 
+        if not isinstance(request_id, str):
+            raise RuntimeError("Only single-request abort supported in"
+                               " deprecated V0")
+
         with suppress(MQClientClosedError):
             await self._send_one_way_rpc_request(
                 request=RPCAbortRequest(request_id), socket=self.input_socket)
diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index 671e9648a3d0..c610fb5eae60 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -3,7 +3,7 @@
 
 import asyncio
 from abc import ABC, abstractmethod
-from typing import AsyncGenerator, Mapping, Optional
+from typing import AsyncGenerator, Iterable, Mapping, Optional, Union
 
 from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import DecodingConfig, ModelConfig, VllmConfig
@@ -229,11 +229,12 @@ class EngineClient(ABC):
         ...
 
     @abstractmethod
-    async def abort(self, request_id: str) -> None:
+    async def abort(self, request_id: Union[str, Iterable[str]]) -> None:
         """Abort a request.
 
         Args:
-            request_id: The unique id of the request.
+            request_id: The unique id of the request,
+                or an iterable of such ids.
         """
         ...
 
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index 40f41893abb6..64f7426bd65d 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -1315,6 +1315,11 @@ def common_broadcastable_dtype(dtypes: Collection[torch.dtype]):
     )
 
 
+def as_list(maybe_list: Iterable[T]) -> list[T]:
+    """Convert iterable to list, unless it's already a list."""
+    return maybe_list if isinstance(maybe_list, list) else list(maybe_list)
+
+
 # `collections` helpers
 def is_list_of(
     value: object,
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index edc2e235c3c3..664fec31a4da 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import asyncio
 import time
-from collections.abc import AsyncGenerator, Mapping
+from collections.abc import AsyncGenerator, Iterable, Mapping
 from copy import copy
 from typing import Any, Optional, Union
 
@@ -27,7 +27,8 @@ from vllm.transformers_utils.config import (
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import Device, cancel_task_threadsafe, cdiv, deprecate_kwargs
+from vllm.utils import (Device, as_list, cancel_task_threadsafe, cdiv,
+                        deprecate_kwargs)
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
@@ -431,14 +432,16 @@ class AsyncLLM(EngineClient):
 
         self.output_handler = asyncio.create_task(output_handler())
 
-    async def abort(self, request_id: str) -> None:
+    async def abort(self, request_id: Union[str, Iterable[str]]) -> None:
         """Abort RequestId in OutputProcessor and EngineCore."""
-        request_ids = self.output_processor.abort_requests((request_id, ))
-        await self.engine_core.abort_requests_async(request_ids)
+        request_ids = (request_id, ) if isinstance(
+            request_id, str) else as_list(request_id)
+        all_request_ids = self.output_processor.abort_requests(request_ids)
+        await self.engine_core.abort_requests_async(all_request_ids)
 
         if self.log_requests:
-            logger.info("Aborted request %s.", request_id)
+            logger.info("Aborted request(s) %s.", ",".join(request_ids))
 
     async def encode(
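
Usage sketch (illustrative only, not part of the committed diff): after this
change, the V1 `AsyncLLM.abort()` accepts either a single request id string or
an iterable of ids, while the deprecated V0 clients (`AsyncLLMEngine`,
`MQLLMEngineClient`) still accept only a single string and raise RuntimeError
otherwise. Assuming `engine` is an AsyncLLM built via
`AsyncLLM.from_engine_args(...)`, as in the test above:

    # Abort several in-flight requests with one call (V1 AsyncLLM).
    await engine.abort(["request-5", "request-10", "request-15"])

    # A single request id string continues to work as before.
    await engine.abort("request-0")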