diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index 462d2c4e50e73..5e3374f9f6a10 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -149,6 +149,33 @@ class EngineClient(ABC):
         """Load a new LoRA adapter into the engine for future requests."""
         ...
 
+    @abstractmethod
+    async def pause_generation(
+        self,
+        *,
+        wait_for_inflight_requests: bool = False,
+        clear_cache: bool = True,
+    ) -> None:
+        """Pause new generation/encoding requests.
+
+        Args:
+            wait_for_inflight_requests: When ``True`` waits for in-flight requests
+                to finish before pausing. When ``False`` (default), aborts in-flight
+                requests immediately.
+            clear_cache: Whether to clear KV and prefix caches after draining.
+        """
+        ...
+
+    @abstractmethod
+    async def resume_generation(self) -> None:
+        """Resume accepting generation/encoding requests."""
+        ...
+
+    @abstractmethod
+    async def is_paused(self) -> bool:
+        """Return whether the engine is currently paused."""
+        ...
+
     async def scale_elastic_ep(
         self, new_data_parallel_size: int, drain_timeout: int = 300
     ) -> None:
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 3974f45a7135c..70174250ceabe 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -394,6 +394,84 @@ async def get_server_load_metrics(request: Request):
     return JSONResponse(content={"server_load": request.app.state.server_load_metrics})
 
 
+@router.post("/pause")
+async def pause_generation(
+    raw_request: Request,
+    wait_for_inflight_requests: bool = Query(False),
+    clear_cache: bool = Query(True),
+) -> JSONResponse:
+    """Pause generation requests to allow weight updates.
+
+    Args:
+        wait_for_inflight_requests: When ``True`` waits for in-flight
+            requests to finish before pausing. When ``False`` (default),
+            aborts any in-flight requests immediately.
+        clear_cache: Whether to clear KV/prefix caches after draining.
+    """
+
+    engine = engine_client(raw_request)
+
+    try:
+        await engine.pause_generation(
+            wait_for_inflight_requests=wait_for_inflight_requests,
+            clear_cache=clear_cache,
+        )
+        return JSONResponse(
+            content={"status": "paused"},
+            status_code=HTTPStatus.OK.value,
+        )
+
+    except ValueError as err:
+        return JSONResponse(
+            content={"error": str(err)},
+            status_code=HTTPStatus.BAD_REQUEST.value,
+        )
+    except Exception as err:  # pragma: no cover - defensive
+        logger.exception("Failed to pause generation")
+        return JSONResponse(
+            content={"error": f"Failed to pause generation: {err}"},
+            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
+        )
+
+
+@router.post("/resume")
+async def resume_generation(raw_request: Request) -> JSONResponse:
+    """Resume generation after a pause."""
+
+    engine = engine_client(raw_request)
+
+    try:
+        await engine.resume_generation()
+        return JSONResponse(
+            content={"status": "resumed"},
+            status_code=HTTPStatus.OK.value,
+        )
+    except Exception as err:  # pragma: no cover - defensive
+        logger.exception("Failed to resume generation")
+        return JSONResponse(
+            content={"error": f"Failed to resume generation: {err}"},
+            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
+        )
+
+
+@router.get("/is_paused")
+async def is_paused(raw_request: Request) -> JSONResponse:
+    """Return the current pause status."""
+
+    engine = engine_client(raw_request)
+
+    try:
+        paused = await engine.is_paused()
+    except Exception as err:  # pragma: no cover - defensive
+        logger.exception("Failed to fetch pause status")
+        return JSONResponse(
+            content={"error": f"Failed to fetch pause status: {err}"},
+            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
+        )
+
+    return JSONResponse(content={"is_paused": paused})
+
+
 @router.post(
     "/tokenize",
     dependencies=[Depends(validate_json_request)],
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index abf2c8cfa4539..c64b3cccfc652 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -152,6 +152,10 @@ class AsyncLLM(EngineClient):
         )
         self.logger_manager.log_engine_initialized()
 
+        # Pause / resume state for async RL workflows.
+        self._pause_cond = asyncio.Condition()
+        self._paused = False
+
         self.output_handler: asyncio.Task | None = None
         try:
             # Start output handler eagerly if we are in the asyncio eventloop.
@@ -404,6 +408,10 @@ class AsyncLLM(EngineClient):
         # to handle startup failure gracefully in the OpenAI server.
         self._run_output_handler()
 
+        # Wait until generation is resumed if the engine is paused.
+        async with self._pause_cond:
+            await self._pause_cond.wait_for(lambda: not self._paused)
+
         if tokenization_kwargs is None:
             tokenization_kwargs = {}
         truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
@@ -551,6 +559,58 @@ class AsyncLLM(EngineClient):
         if self.log_requests:
             logger.info("Aborted request(s) %s.", ",".join(request_ids))
 
+    async def pause_generation(
+        self,
+        *,
+        wait_for_inflight_requests: bool = False,
+        clear_cache: bool = True,
+    ) -> None:
+        """
+        Pause generation to allow model weight updates.
+
+        New generation/encoding requests are blocked until resume.
+
+        Args:
+            wait_for_inflight_requests: When ``True`` waits for in-flight
+                requests to finish before pausing. When ``False`` (default),
+                immediately aborts any in-flight requests.
+            clear_cache: Whether to clear KV cache and prefix cache after
+                draining. Set to ``False`` to preserve cache for faster resume.
+                Default is ``True`` (clear caches).
+        """
+
+        async with self._pause_cond:
+            if self._paused:
+                return
+            self._paused = True
+
+        if not wait_for_inflight_requests:
+            request_ids = list(self.output_processor.request_states.keys())
+            if request_ids:
+                await self.abort(request_ids)
+
+        # Wait for running requests to drain before clearing cache.
+        if self.output_processor.has_unfinished_requests():
+            await self.output_processor.wait_for_requests_to_drain()
+
+        # Clear cache
+        if clear_cache:
+            await self.reset_prefix_cache()
+            await self.reset_mm_cache()
+
+    async def resume_generation(self) -> None:
+        """Resume generation after :meth:`pause_generation`."""
+
+        async with self._pause_cond:
+            self._paused = False
+            self._pause_cond.notify_all()  # Wake up all waiting requests
+
+    async def is_paused(self) -> bool:
+        """Return whether the engine is currently paused."""
+
+        async with self._pause_cond:
+            return self._paused
+
     async def encode(
         self,
         prompt: PromptType,
@@ -582,6 +642,10 @@ class AsyncLLM(EngineClient):
         # to handle startup failure gracefully in the OpenAI server.
         self._run_output_handler()
 
+        # Respect pause state before accepting new requests.
+        async with self._pause_cond:
+            await self._pause_cond.wait_for(lambda: not self._paused)
+
         if tokenization_kwargs is None:
             tokenization_kwargs = {}
         _validate_truncation_size(
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index bdbbfe2595f81..0453c4a77f0cd 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -350,6 +350,8 @@ class OutputProcessor:
         self.parent_requests: dict[str, ParentRequest] = {}
         self.lora_states = LoRARequestStates(log_stats)
         self.tracer: Tracer | None = None
+        self._requests_drained = asyncio.Event()
+        self._requests_drained.set()
 
     def get_num_unfinished_requests(self):
         return len(self.request_states)
@@ -357,6 +359,11 @@
     def has_unfinished_requests(self) -> bool:
         return len(self.request_states) > 0
 
+    async def wait_for_requests_to_drain(self) -> None:
+        if not self.request_states:
+            return
+        await self._requests_drained.wait()
+
     def propagate_error(self, e: Exception):
         """Propagate error to all generate() tasks."""
 
@@ -396,6 +403,8 @@
                 child_reqs = self.abort_requests(child_reqs)
                 request_ids_to_abort.extend(child_reqs)
             self.parent_requests.pop(request_id, None)
+        if not self.request_states:
+            self._requests_drained.set()
         return request_ids_to_abort
 
     def add_request(
@@ -420,6 +429,8 @@
             log_stats=self.log_stats,
             stream_interval=self.stream_interval,
         )
+        if self._requests_drained.is_set():
+            self._requests_drained.clear()
         self.request_states[request_id] = req_state
         if parent_req:
             self.parent_requests[parent_req.request_id] = parent_req
@@ -511,6 +522,8 @@
             parent_req = req_state.parent_req
             if parent_req and not parent_req.child_requests:
                 self.parent_requests.pop(parent_req.request_id, None)
+        if not self.request_states:
+            self._requests_drained.set()
         if not engine_core_output.finished:
             # If req not finished in EngineCore, but Detokenizer
             # detected stop string, abort needed in EngineCore.
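
The sketch below is illustrative only and is not part of the diff: it drives the new /pause, /resume, and /is_paused routes from a client during a weight update. The base URL, the third-party requests dependency, and the elided weight-push step are assumptions made for the example.

"""Example client flow for the pause/resume endpoints added above.

A minimal sketch: assumes a vLLM OpenAI-compatible server is already running
at BASE_URL (assumed address) with these routes, and that `requests` is
installed; the weight-update step itself is out of scope here.
"""

import requests

BASE_URL = "http://localhost:8000"  # assumed server address


def update_weights_with_pause() -> None:
    # Drain in-flight requests before pausing; pass "false" (the default)
    # to abort them immediately instead. This call returns once draining
    # and cache clearing have finished.
    requests.post(
        f"{BASE_URL}/pause",
        params={"wait_for_inflight_requests": "true", "clear_cache": "true"},
    ).raise_for_status()

    # Confirm the paused state before touching weights.
    assert requests.get(f"{BASE_URL}/is_paused").json()["is_paused"]

    # ... push new model weights here (out of scope for this example) ...

    # New generation/encoding requests block behind the pause and proceed
    # once this returns.
    requests.post(f"{BASE_URL}/resume").raise_for_status()


if __name__ == "__main__":
    update_weights_with_pause()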