From f51f182d64fdecd06b0526bbb967ba0bfbec1053 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 23 Mar 2025 20:18:50 +0000 Subject: [PATCH] pre-commit Signed-off-by: rshaw@neuralmagic.com --- vllm/entrypoints/disaggregated/engine.py | 99 +++++++++++------------- 1 file changed, 44 insertions(+), 55 deletions(-) diff --git a/vllm/entrypoints/disaggregated/engine.py b/vllm/entrypoints/disaggregated/engine.py index 3081a62c4318e..7c047cb2ccdd2 100644 --- a/vllm/entrypoints/disaggregated/engine.py +++ b/vllm/entrypoints/disaggregated/engine.py @@ -1,15 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio -import msgspec import os -from collections.abc import AsyncGenerator -from typing import Dict, List, Mapping, Optional +from collections.abc import AsyncGenerator, Mapping +from typing import Optional +import msgspec import zmq import zmq.asyncio -from vllm import SamplingParams from vllm.config import DecodingConfig, ModelConfig from vllm.core.scheduler import SchedulerOutputs from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse @@ -18,8 +17,7 @@ from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.outputs import (PoolingRequestOutput, RequestOutput, - CompletionOutput) +from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import BeamSearchParams, SamplingParams @@ -31,6 +29,7 @@ logger = init_logger(__name__) DEFAULT_MAX_TOKENS = 32000 + class PDEngine: """ PDEngine: @@ -42,15 +41,10 @@ class PDEngine: * TODO: move under vllm/v1/engine one past prototype. """ - def __init__( - self, - prefill_addr: str, - decode_addr: str, - connector_addr: str, - model_name: str - ): + def __init__(self, prefill_addr: str, decode_addr: str, + connector_addr: str, model_name: str): # Request queues. - self.queues: Dict[str, asyncio.Queue] = {} + self.queues: dict[str, asyncio.Queue] = {} # Serialization encoder. self.encoder = msgspec.msgpack.Encoder() @@ -64,22 +58,20 @@ class PDEngine: self.connector_addr = connector_addr self.decode_addr = decode_addr self.prefill_addr = prefill_addr - + # Background loops (started on first generate()). self.output_handler: Optional[asyncio.Task] = None self.log_running: Optional[asyncio.Task] = None # Dummy: needed for EngineClient Protocol. # TODO: refactor EngineClient to avoid needing this. - self.model_config = ModelConfig( - model=model_name, - tokenizer=model_name, - tokenizer_mode="auto", - trust_remote_code=False, - dtype="auto", - task="generate", - seed=42 - ) + self.model_config = ModelConfig(model=model_name, + tokenizer=model_name, + tokenizer_mode="auto", + trust_remote_code=False, + dtype="auto", + task="generate", + seed=42) # Dummy: needed for EngineClient Protocol. # TODO: refactor EngineClient to avoid needing this. @@ -103,9 +95,7 @@ class PDEngine: if (task := self.output_handler) is not None: task.cancel() - ipc_paths = [ - self.connector_addr, self.decode_addr, self.prefill_addr - ] + ipc_paths = [self.connector_addr, self.decode_addr, self.prefill_addr] for path in ipc_paths: socket_path = path.replace("ipc://", "") if os.path.exists(socket_path): @@ -121,7 +111,7 @@ class PDEngine: distribute back to the generate() tasks. """ decoder = msgspec.msgpack.Decoder(PDResponse) - + socket: Optional[zmq.asyncio.Socket] = None try: socket = self.ctx.socket(zmq.constants.PULL) @@ -134,11 +124,11 @@ class PDEngine: self.queues[response.request_id].put_nowait(response) except: # TODO: actually handle failure and shutdown. - raise + raise finally: if socket is not None: socket.close(linger=0) - + async def _prefill( self, request: PDRequest, @@ -154,7 +144,7 @@ class PDEngine: if not response.success: # TODO: actual error handling and shutdown. raise Exception("Failed Prefill Request.") - + async def _decode( self, request: PDRequest, @@ -169,17 +159,16 @@ class PDEngine: finished = False while not finished: response = await q.get() - logger.debug(f"{response}") if not response.success: # TODO: actual error handling and shutdown. raise Exception("Failed Decode Request.") finished = response.finish_reason is not None yield response - + def _to_request_output( self, pd_response: PDResponse, - prompt_token_ids: List[int], + prompt_token_ids: list[int], ) -> RequestOutput: finished = pd_response.finish_reason is not None return RequestOutput( @@ -187,15 +176,15 @@ class PDEngine: prompt=None, prompt_token_ids=prompt_token_ids, prompt_logprobs=None, - outputs=[CompletionOutput( - index=0, - text=pd_response.text, - token_ids=pd_response.token_ids, - cumulative_logprob=None, - logprobs=None, - finish_reason=pd_response.finish_reason, - stop_reason=pd_response.stop_reason - )], + outputs=[ + CompletionOutput(index=0, + text=pd_response.text, + token_ids=pd_response.token_ids, + cumulative_logprob=None, + logprobs=None, + finish_reason=pd_response.finish_reason, + stop_reason=pd_response.stop_reason) + ], finished=finished, ) @@ -211,28 +200,29 @@ class PDEngine: ) -> AsyncGenerator[RequestOutput]: # Start loops on first request. if self.output_handler is None: - self.output_handler = asyncio.create_task(self._run_output_handler()) + self.output_handler = asyncio.create_task( + self._run_output_handler()) self.log_running = asyncio.create_task(self._run_log_running()) # TODO: Expand to support the full matrix. - if not "prompt_token_ids" in prompt: + if "prompt_token_ids" not in prompt: raise NotImplementedError( "We currently only support TokensPrompt for P/D!") if lora_request is not None: raise NotImplementedError( - "We currently do not suppport LoRA for P/D!") + "We currently do not support LoRA for P/D!") if trace_headers is not None: raise NotImplementedError( - "We currently do not suppport tracing for P/D!") + "We currently do not support tracing for P/D!") if prompt_adapter_request is not None: raise NotImplementedError( - "We currently do not suppport prompt adapter for P/D!") + "We currently do not support prompt adapter for P/D!") if priority != 0: raise NotImplementedError( "We currently do not support priority for P/D!") if request_id in self.queues: raise ValueError(f"Found duplicate request_id: {request_id}!") - + # Queue to gather output from output_handler. q: asyncio.Queue[PDResponse] = asyncio.Queue() self.queues[request_id] = q @@ -240,10 +230,9 @@ class PDEngine: # (1) Perform the Prefill. original_max_tokens = sampling_params.max_tokens prompt_token_ids = prompt["prompt_token_ids"] - request = PDRequest( - request_id=request_id, - prompt_token_ids=prompt_token_ids, - sampling_params=sampling_params) + request = PDRequest(request_id=request_id, + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params) request.sampling_params.max_tokens = 1 logger.debug("Sending Prefill: %s", request.request_id) pd_response = await self._prefill(request, q) @@ -301,7 +290,7 @@ class PDEngine: async def do_log_stats( self, scheduler_outputs: Optional[SchedulerOutputs] = None, - model_output: Optional[List[SamplerOutput]] = None, + model_output: Optional[list[SamplerOutput]] = None, ) -> None: pass @@ -325,7 +314,7 @@ class PDEngine: raise NotImplementedError async def is_sleeping(self) -> bool: - False + return False async def add_lora(self, lora_request: LoRARequest) -> None: raise NotImplementedError