mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-01 07:57:04 +08:00
[RL] Add Pause and Resume Generation for Asynchronous RL Training (#28037)
Signed-off-by: SamitHuang <285365963@qq.com> Signed-off-by: Samit <285365963@qq.com> Signed-off-by: samithuang <285365963@qq.com> Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
parent
c9e093116c
commit
371b1d4c61
@ -149,6 +149,33 @@ class EngineClient(ABC):
|
|||||||
"""Load a new LoRA adapter into the engine for future requests."""
|
"""Load a new LoRA adapter into the engine for future requests."""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@abstractmethod
async def pause_generation(
    self,
    *,
    wait_for_inflight_requests: bool = False,
    clear_cache: bool = True,
) -> None:
    """Pause new generation/encoding requests.

    Both options are keyword-only.

    Args:
        wait_for_inflight_requests: When ``True``, wait for in-flight
            requests to finish before pausing. When ``False`` (default),
            abort in-flight requests immediately.
        clear_cache: Whether to clear KV and prefix caches after draining.
    """
    ...
|
||||||
|
|
||||||
|
@abstractmethod
async def resume_generation(self) -> None:
    """Resume accepting generation/encoding requests."""
    ...
|
||||||
|
|
||||||
|
@abstractmethod
async def is_paused(self) -> bool:
    """Return whether the engine is currently paused."""
    ...
|
||||||
|
|
||||||
async def scale_elastic_ep(
|
async def scale_elastic_ep(
|
||||||
self, new_data_parallel_size: int, drain_timeout: int = 300
|
self, new_data_parallel_size: int, drain_timeout: int = 300
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|||||||
@ -394,6 +394,84 @@ async def get_server_load_metrics(request: Request):
|
|||||||
return JSONResponse(content={"server_load": request.app.state.server_load_metrics})
|
return JSONResponse(content={"server_load": request.app.state.server_load_metrics})
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/pause")
async def pause_generation(
    raw_request: Request,
    wait_for_inflight_requests: bool = Query(False),
    clear_cache: bool = Query(True),
) -> JSONResponse:
    """Pause generation requests to allow weight updates.

    Args:
        raw_request: Incoming request, used to look up the engine client.
        wait_for_inflight_requests: When ``True`` waits for in-flight
            requests to finish before pausing. When ``False`` (default),
            aborts any in-flight requests immediately.
        clear_cache: Whether to clear KV/prefix caches after draining.

    Returns:
        ``{"status": "paused"}`` with 200 on success; ``{"error": ...}``
        with 400 when the engine raises ``ValueError``, or with 500 on any
        other exception.
    """

    engine = engine_client(raw_request)

    # Only the engine call is guarded; the success response is built after
    # the ``try``, mirroring the ``/is_paused`` handler.
    try:
        await engine.pause_generation(
            wait_for_inflight_requests=wait_for_inflight_requests,
            clear_cache=clear_cache,
        )
    except ValueError as err:
        return JSONResponse(
            content={"error": str(err)},
            status_code=HTTPStatus.BAD_REQUEST.value,
        )
    except Exception as err:  # pragma: no cover - defensive
        logger.exception("Failed to pause generation")
        return JSONResponse(
            content={"error": f"Failed to pause generation: {err}"},
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
        )

    return JSONResponse(
        content={"status": "paused"},
        status_code=HTTPStatus.OK.value,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/resume")
async def resume_generation(raw_request: Request) -> JSONResponse:
    """Resume generation after a pause.

    Args:
        raw_request: Incoming request, used to look up the engine client.

    Returns:
        ``{"status": "resumed"}`` with 200 on success, or ``{"error": ...}``
        with 500 when the engine raises.
    """

    engine = engine_client(raw_request)

    # Only the engine call is guarded; the success response is built after
    # the ``try``, mirroring the ``/is_paused`` handler.
    try:
        await engine.resume_generation()
    except Exception as err:  # pragma: no cover - defensive
        logger.exception("Failed to resume generation")
        return JSONResponse(
            content={"error": f"Failed to resume generation: {err}"},
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
        )

    return JSONResponse(
        content={"status": "resumed"},
        status_code=HTTPStatus.OK.value,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/is_paused")
async def is_paused(raw_request: Request) -> JSONResponse:
    """Return the current pause status.

    Args:
        raw_request: Incoming request, used to look up the engine client.

    Returns:
        ``{"is_paused": <bool>}`` on success, or ``{"error": ...}`` with
        500 when the engine raises while reporting its state.
    """

    engine = engine_client(raw_request)

    try:
        paused = await engine.is_paused()
    except Exception as err:  # pragma: no cover - defensive
        logger.exception("Failed to fetch pause status")
        return JSONResponse(
            content={"error": f"Failed to fetch pause status: {err}"},
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
        )

    return JSONResponse(content={"is_paused": paused})
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/tokenize",
|
"/tokenize",
|
||||||
dependencies=[Depends(validate_json_request)],
|
dependencies=[Depends(validate_json_request)],
|
||||||
|
|||||||
@ -152,6 +152,10 @@ class AsyncLLM(EngineClient):
|
|||||||
)
|
)
|
||||||
self.logger_manager.log_engine_initialized()
|
self.logger_manager.log_engine_initialized()
|
||||||
|
|
||||||
|
# Pause / resume state for async RL workflows.
|
||||||
|
self._pause_cond = asyncio.Condition()
|
||||||
|
self._paused = False
|
||||||
|
|
||||||
self.output_handler: asyncio.Task | None = None
|
self.output_handler: asyncio.Task | None = None
|
||||||
try:
|
try:
|
||||||
# Start output handler eagerly if we are in the asyncio eventloop.
|
# Start output handler eagerly if we are in the asyncio eventloop.
|
||||||
@ -404,6 +408,10 @@ class AsyncLLM(EngineClient):
|
|||||||
# to handle startup failure gracefully in the OpenAI server.
|
# to handle startup failure gracefully in the OpenAI server.
|
||||||
self._run_output_handler()
|
self._run_output_handler()
|
||||||
|
|
||||||
|
# Wait until generation is resumed if the engine is paused.
|
||||||
|
async with self._pause_cond:
|
||||||
|
await self._pause_cond.wait_for(lambda: not self._paused)
|
||||||
|
|
||||||
if tokenization_kwargs is None:
|
if tokenization_kwargs is None:
|
||||||
tokenization_kwargs = {}
|
tokenization_kwargs = {}
|
||||||
truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
|
truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
|
||||||
@ -551,6 +559,58 @@ class AsyncLLM(EngineClient):
|
|||||||
if self.log_requests:
|
if self.log_requests:
|
||||||
logger.info("Aborted request(s) %s.", ",".join(request_ids))
|
logger.info("Aborted request(s) %s.", ",".join(request_ids))
|
||||||
|
|
||||||
|
async def pause_generation(
    self,
    *,
    wait_for_inflight_requests: bool = False,
    clear_cache: bool = True,
) -> None:
    """Pause generation to allow model weight updates.

    New generation/encoding requests are blocked until resume.

    Calling this while already paused is a no-op: it returns immediately
    without aborting, draining, or clearing caches.

    Args:
        wait_for_inflight_requests: When ``True`` waits for in-flight
            requests to finish before pausing. When ``False`` (default),
            immediately aborts any in-flight requests.
        clear_cache: Whether to clear KV cache and prefix cache after
            draining. Set to ``False`` to preserve cache for faster resume.
            Default is ``True`` (clear caches).
    """

    # Flip the flag under the condition's lock so request entry points
    # waiting on the same condition observe a consistent state.
    async with self._pause_cond:
        if self._paused:
            return
        self._paused = True

    if not wait_for_inflight_requests:
        # Snapshot the ids first: aborting may mutate ``request_states``.
        request_ids = list(self.output_processor.request_states)
        if request_ids:
            await self.abort(request_ids)

    # Wait for running requests to drain before clearing cache.
    if self.output_processor.has_unfinished_requests():
        await self.output_processor.wait_for_requests_to_drain()

    # Clear cache
    if clear_cache:
        await self.reset_prefix_cache()
        await self.reset_mm_cache()
|
||||||
|
|
||||||
|
async def resume_generation(self) -> None:
    """Resume generation after :meth:`pause_generation`.

    Clears the paused flag and notifies every coroutine waiting on the
    pause condition.
    """

    async with self._pause_cond:
        self._paused = False
        self._pause_cond.notify_all()  # Wake up all waiting requests
|
||||||
|
|
||||||
|
async def is_paused(self) -> bool:
    """Return whether the engine is currently paused.

    The flag is read while holding the pause condition's lock.
    """

    async with self._pause_cond:
        return self._paused
|
||||||
|
|
||||||
async def encode(
|
async def encode(
|
||||||
self,
|
self,
|
||||||
prompt: PromptType,
|
prompt: PromptType,
|
||||||
@ -582,6 +642,10 @@ class AsyncLLM(EngineClient):
|
|||||||
# to handle startup failure gracefully in the OpenAI server.
|
# to handle startup failure gracefully in the OpenAI server.
|
||||||
self._run_output_handler()
|
self._run_output_handler()
|
||||||
|
|
||||||
|
# Respect pause state before accepting new requests.
|
||||||
|
async with self._pause_cond:
|
||||||
|
await self._pause_cond.wait_for(lambda: not self._paused)
|
||||||
|
|
||||||
if tokenization_kwargs is None:
|
if tokenization_kwargs is None:
|
||||||
tokenization_kwargs = {}
|
tokenization_kwargs = {}
|
||||||
_validate_truncation_size(
|
_validate_truncation_size(
|
||||||
|
|||||||
@ -350,6 +350,8 @@ class OutputProcessor:
|
|||||||
self.parent_requests: dict[str, ParentRequest] = {}
|
self.parent_requests: dict[str, ParentRequest] = {}
|
||||||
self.lora_states = LoRARequestStates(log_stats)
|
self.lora_states = LoRARequestStates(log_stats)
|
||||||
self.tracer: Tracer | None = None
|
self.tracer: Tracer | None = None
|
||||||
|
self._requests_drained = asyncio.Event()
|
||||||
|
self._requests_drained.set()
|
||||||
|
|
||||||
def get_num_unfinished_requests(self):
|
def get_num_unfinished_requests(self):
|
||||||
return len(self.request_states)
|
return len(self.request_states)
|
||||||
@ -357,6 +359,11 @@ class OutputProcessor:
|
|||||||
def has_unfinished_requests(self) -> bool:
|
def has_unfinished_requests(self) -> bool:
|
||||||
return len(self.request_states) > 0
|
return len(self.request_states) > 0
|
||||||
|
|
||||||
|
async def wait_for_requests_to_drain(self) -> None:
    """Block until no requests are being tracked.

    Returns immediately when ``request_states`` is already empty;
    otherwise waits for the ``_requests_drained`` event to be set.
    """
    if not self.request_states:
        return
    await self._requests_drained.wait()
|
||||||
|
|
||||||
def propagate_error(self, e: Exception):
|
def propagate_error(self, e: Exception):
|
||||||
"""Propagate error to all generate() tasks."""
|
"""Propagate error to all generate() tasks."""
|
||||||
|
|
||||||
@ -396,6 +403,8 @@ class OutputProcessor:
|
|||||||
child_reqs = self.abort_requests(child_reqs)
|
child_reqs = self.abort_requests(child_reqs)
|
||||||
request_ids_to_abort.extend(child_reqs)
|
request_ids_to_abort.extend(child_reqs)
|
||||||
self.parent_requests.pop(request_id, None)
|
self.parent_requests.pop(request_id, None)
|
||||||
|
if not self.request_states:
|
||||||
|
self._requests_drained.set()
|
||||||
return request_ids_to_abort
|
return request_ids_to_abort
|
||||||
|
|
||||||
def add_request(
|
def add_request(
|
||||||
@ -420,6 +429,8 @@ class OutputProcessor:
|
|||||||
log_stats=self.log_stats,
|
log_stats=self.log_stats,
|
||||||
stream_interval=self.stream_interval,
|
stream_interval=self.stream_interval,
|
||||||
)
|
)
|
||||||
|
if self._requests_drained.is_set():
|
||||||
|
self._requests_drained.clear()
|
||||||
self.request_states[request_id] = req_state
|
self.request_states[request_id] = req_state
|
||||||
if parent_req:
|
if parent_req:
|
||||||
self.parent_requests[parent_req.request_id] = parent_req
|
self.parent_requests[parent_req.request_id] = parent_req
|
||||||
@ -511,6 +522,8 @@ class OutputProcessor:
|
|||||||
parent_req = req_state.parent_req
|
parent_req = req_state.parent_req
|
||||||
if parent_req and not parent_req.child_requests:
|
if parent_req and not parent_req.child_requests:
|
||||||
self.parent_requests.pop(parent_req.request_id, None)
|
self.parent_requests.pop(parent_req.request_id, None)
|
||||||
|
if not self.request_states:
|
||||||
|
self._requests_drained.set()
|
||||||
if not engine_core_output.finished:
|
if not engine_core_output.finished:
|
||||||
# If req not finished in EngineCore, but Detokenizer
|
# If req not finished in EngineCore, but Detokenizer
|
||||||
# detected stop string, abort needed in EngineCore.
|
# detected stop string, abort needed in EngineCore.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user