diff --git a/tests/conftest.py b/tests/conftest.py
index 281c9161d301..686303e6bebb 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -156,8 +156,8 @@ class VllmRunner:
     ) -> List[Tuple[List[int], str]]:
         greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
         outputs = self.generate(prompts, greedy_params)
-        return [(output_ids[0], output_str[0]) for output_ids, output_str in
-                outputs]
+        return [(output_ids[0], output_str[0])
+                for output_ids, output_str in outputs]
 
     def generate_beam_search(
         self,
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index cb987bd64d80..a706650eeccb 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -1,7 +1,7 @@
 import asyncio
 import time
 from functools import partial
-from typing import Any, Dict, Iterable, List, Optional, Set, Type, Union
+from typing import Any, Dict, Iterable, List, Optional, Type, Union
 
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -152,7 +152,7 @@ class AsyncLLMEngine:
 
         # Request id -> stream.
         self.request_streams: Dict[str, AsyncStream] = {}
-        self.finished_requests: Set[str] = set()
+        self.finished_requests: asyncio.Queue[str] = asyncio.Queue()
         self.background_loop = None
         if start_engine_loop:
             self.start_background_loop()
@@ -194,12 +194,14 @@ class AsyncLLMEngine:
             if self.log_requests:
                 logger.info(f"Finished request {request_id}.")
             self.request_streams[request_id].finish()
-            self.finished_requests.add(request_id)
+            self.finished_requests.put_nowait(request_id)
 
-        await self._engine_abort(self.finished_requests)
-        for request_id in self.finished_requests:
+        finished_request = set()
+        while not self.finished_requests.empty():
+            finished_request.add(self.finished_requests.get_nowait())
+        await self._engine_abort(finished_request)
+        for request_id in finished_request:
             del self.request_streams[request_id]
-        self.finished_requests.clear()
 
     async def _engine_abort(self, request_ids: Iterable[str]):
         if self.engine_use_ray:
@@ -226,6 +228,8 @@ class AsyncLLMEngine:
                         f"sampling params: {sampling_params}, "
                         f"prompt token ids: {prompt_token_ids}.")
 
+        if request_id in self.request_streams:
+            raise KeyError(f"Request {request_id} already exists.")
         stream = AsyncStream(request_id)
         self.request_streams[request_id] = stream
 
@@ -316,7 +320,7 @@ class AsyncLLMEngine:
             logger.info(f"Aborted request {request_id}.")
 
         self.request_streams[request_id].finish()
-        self.finished_requests.add(request_id)
+        self.finished_requests.put_nowait(request_id)
 
     async def get_model_config(self) -> ModelConfig:
         """Get the model configuration of the vLLM engine."""
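
For context, the standalone sketch below is not part of the patch; the class and method names (FinishedRequestTracker, drain_finished) are invented for illustration. It shows the pattern the async_llm_engine.py change adopts: finished request ids are pushed onto an asyncio.Queue with put_nowait() and drained into a local set once per engine step, so ids enqueued while a step is in flight are picked up on the next drain rather than being wiped out by a shared set's clear().

# Minimal sketch (assumed names, not the vLLM implementation) of queue-based
# tracking of finished requests. Requires Python 3.9+ for asyncio.Queue[str].
import asyncio
from typing import Dict, Set


class FinishedRequestTracker:
    """Tracks request streams and the ids finished since the last step."""

    def __init__(self) -> None:
        self.request_streams: Dict[str, object] = {}
        self.finished_requests: asyncio.Queue[str] = asyncio.Queue()

    def add_request(self, request_id: str) -> None:
        # Mirror the patch's duplicate-id guard.
        if request_id in self.request_streams:
            raise KeyError(f"Request {request_id} already exists.")
        self.request_streams[request_id] = object()  # stand-in for AsyncStream

    def finish_request(self, request_id: str) -> None:
        # put_nowait never blocks, so producers can mark a request finished
        # at any point, even while a step is being processed.
        self.finished_requests.put_nowait(request_id)

    def drain_finished(self) -> Set[str]:
        # Snapshot only what has been queued so far; ids enqueued after this
        # loop are left for the next drain instead of being cleared away.
        finished: Set[str] = set()
        while not self.finished_requests.empty():
            finished.add(self.finished_requests.get_nowait())
        for request_id in finished:
            del self.request_streams[request_id]
        return finished


async def main() -> None:
    tracker = FinishedRequestTracker()
    tracker.add_request("req-0")
    tracker.finish_request("req-0")
    print(tracker.drain_finished())  # {'req-0'}


if __name__ == "__main__":
    asyncio.run(main())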