pre-commit

Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
This commit is contained in:
rshaw@neuralmagic.com 2025-03-23 20:18:50 +00:00
parent 79e465f557
commit f51f182d64

View File

@ -1,15 +1,14 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
import msgspec
import os import os
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator, Mapping
from typing import Dict, List, Mapping, Optional from typing import Optional
import msgspec
import zmq import zmq
import zmq.asyncio import zmq.asyncio
from vllm import SamplingParams
from vllm.config import DecodingConfig, ModelConfig from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse
@ -18,8 +17,7 @@ from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import (PoolingRequestOutput, RequestOutput, from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
CompletionOutput)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
@ -31,6 +29,7 @@ logger = init_logger(__name__)
DEFAULT_MAX_TOKENS = 32000 DEFAULT_MAX_TOKENS = 32000
class PDEngine: class PDEngine:
""" """
PDEngine: PDEngine:
@ -42,15 +41,10 @@ class PDEngine:
* TODO: move under vllm/v1/engine one past prototype. * TODO: move under vllm/v1/engine one past prototype.
""" """
def __init__( def __init__(self, prefill_addr: str, decode_addr: str,
self, connector_addr: str, model_name: str):
prefill_addr: str,
decode_addr: str,
connector_addr: str,
model_name: str
):
# Request queues. # Request queues.
self.queues: Dict[str, asyncio.Queue] = {} self.queues: dict[str, asyncio.Queue] = {}
# Serialization encoder. # Serialization encoder.
self.encoder = msgspec.msgpack.Encoder() self.encoder = msgspec.msgpack.Encoder()
@ -64,22 +58,20 @@ class PDEngine:
self.connector_addr = connector_addr self.connector_addr = connector_addr
self.decode_addr = decode_addr self.decode_addr = decode_addr
self.prefill_addr = prefill_addr self.prefill_addr = prefill_addr
# Background loops (started on first generate()). # Background loops (started on first generate()).
self.output_handler: Optional[asyncio.Task] = None self.output_handler: Optional[asyncio.Task] = None
self.log_running: Optional[asyncio.Task] = None self.log_running: Optional[asyncio.Task] = None
# Dummy: needed for EngineClient Protocol. # Dummy: needed for EngineClient Protocol.
# TODO: refactor EngineClient to avoid needing this. # TODO: refactor EngineClient to avoid needing this.
self.model_config = ModelConfig( self.model_config = ModelConfig(model=model_name,
model=model_name, tokenizer=model_name,
tokenizer=model_name, tokenizer_mode="auto",
tokenizer_mode="auto", trust_remote_code=False,
trust_remote_code=False, dtype="auto",
dtype="auto", task="generate",
task="generate", seed=42)
seed=42
)
# Dummy: needed for EngineClient Protocol. # Dummy: needed for EngineClient Protocol.
# TODO: refactor EngineClient to avoid needing this. # TODO: refactor EngineClient to avoid needing this.
@ -103,9 +95,7 @@ class PDEngine:
if (task := self.output_handler) is not None: if (task := self.output_handler) is not None:
task.cancel() task.cancel()
ipc_paths = [ ipc_paths = [self.connector_addr, self.decode_addr, self.prefill_addr]
self.connector_addr, self.decode_addr, self.prefill_addr
]
for path in ipc_paths: for path in ipc_paths:
socket_path = path.replace("ipc://", "") socket_path = path.replace("ipc://", "")
if os.path.exists(socket_path): if os.path.exists(socket_path):
@ -121,7 +111,7 @@ class PDEngine:
distribute back to the generate() tasks. distribute back to the generate() tasks.
""" """
decoder = msgspec.msgpack.Decoder(PDResponse) decoder = msgspec.msgpack.Decoder(PDResponse)
socket: Optional[zmq.asyncio.Socket] = None socket: Optional[zmq.asyncio.Socket] = None
try: try:
socket = self.ctx.socket(zmq.constants.PULL) socket = self.ctx.socket(zmq.constants.PULL)
@ -134,11 +124,11 @@ class PDEngine:
self.queues[response.request_id].put_nowait(response) self.queues[response.request_id].put_nowait(response)
except: except:
# TODO: actually handle failure and shutdown. # TODO: actually handle failure and shutdown.
raise raise
finally: finally:
if socket is not None: if socket is not None:
socket.close(linger=0) socket.close(linger=0)
async def _prefill( async def _prefill(
self, self,
request: PDRequest, request: PDRequest,
@ -154,7 +144,7 @@ class PDEngine:
if not response.success: if not response.success:
# TODO: actual error handling and shutdown. # TODO: actual error handling and shutdown.
raise Exception("Failed Prefill Request.") raise Exception("Failed Prefill Request.")
async def _decode( async def _decode(
self, self,
request: PDRequest, request: PDRequest,
@ -169,17 +159,16 @@ class PDEngine:
finished = False finished = False
while not finished: while not finished:
response = await q.get() response = await q.get()
logger.debug(f"{response}")
if not response.success: if not response.success:
# TODO: actual error handling and shutdown. # TODO: actual error handling and shutdown.
raise Exception("Failed Decode Request.") raise Exception("Failed Decode Request.")
finished = response.finish_reason is not None finished = response.finish_reason is not None
yield response yield response
def _to_request_output( def _to_request_output(
self, self,
pd_response: PDResponse, pd_response: PDResponse,
prompt_token_ids: List[int], prompt_token_ids: list[int],
) -> RequestOutput: ) -> RequestOutput:
finished = pd_response.finish_reason is not None finished = pd_response.finish_reason is not None
return RequestOutput( return RequestOutput(
@ -187,15 +176,15 @@ class PDEngine:
prompt=None, prompt=None,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
prompt_logprobs=None, prompt_logprobs=None,
outputs=[CompletionOutput( outputs=[
index=0, CompletionOutput(index=0,
text=pd_response.text, text=pd_response.text,
token_ids=pd_response.token_ids, token_ids=pd_response.token_ids,
cumulative_logprob=None, cumulative_logprob=None,
logprobs=None, logprobs=None,
finish_reason=pd_response.finish_reason, finish_reason=pd_response.finish_reason,
stop_reason=pd_response.stop_reason stop_reason=pd_response.stop_reason)
)], ],
finished=finished, finished=finished,
) )
@ -211,28 +200,29 @@ class PDEngine:
) -> AsyncGenerator[RequestOutput]: ) -> AsyncGenerator[RequestOutput]:
# Start loops on first request. # Start loops on first request.
if self.output_handler is None: if self.output_handler is None:
self.output_handler = asyncio.create_task(self._run_output_handler()) self.output_handler = asyncio.create_task(
self._run_output_handler())
self.log_running = asyncio.create_task(self._run_log_running()) self.log_running = asyncio.create_task(self._run_log_running())
# TODO: Expand to support the full matrix. # TODO: Expand to support the full matrix.
if not "prompt_token_ids" in prompt: if "prompt_token_ids" not in prompt:
raise NotImplementedError( raise NotImplementedError(
"We currently only support TokensPrompt for P/D!") "We currently only support TokensPrompt for P/D!")
if lora_request is not None: if lora_request is not None:
raise NotImplementedError( raise NotImplementedError(
"We currently do not suppport LoRA for P/D!") "We currently do not support LoRA for P/D!")
if trace_headers is not None: if trace_headers is not None:
raise NotImplementedError( raise NotImplementedError(
"We currently do not suppport tracing for P/D!") "We currently do not support tracing for P/D!")
if prompt_adapter_request is not None: if prompt_adapter_request is not None:
raise NotImplementedError( raise NotImplementedError(
"We currently do not suppport prompt adapter for P/D!") "We currently do not support prompt adapter for P/D!")
if priority != 0: if priority != 0:
raise NotImplementedError( raise NotImplementedError(
"We currently do not support priority for P/D!") "We currently do not support priority for P/D!")
if request_id in self.queues: if request_id in self.queues:
raise ValueError(f"Found duplicate request_id: {request_id}!") raise ValueError(f"Found duplicate request_id: {request_id}!")
# Queue to gather output from output_handler. # Queue to gather output from output_handler.
q: asyncio.Queue[PDResponse] = asyncio.Queue() q: asyncio.Queue[PDResponse] = asyncio.Queue()
self.queues[request_id] = q self.queues[request_id] = q
@ -240,10 +230,9 @@ class PDEngine:
# (1) Perform the Prefill. # (1) Perform the Prefill.
original_max_tokens = sampling_params.max_tokens original_max_tokens = sampling_params.max_tokens
prompt_token_ids = prompt["prompt_token_ids"] prompt_token_ids = prompt["prompt_token_ids"]
request = PDRequest( request = PDRequest(request_id=request_id,
request_id=request_id, prompt_token_ids=prompt_token_ids,
prompt_token_ids=prompt_token_ids, sampling_params=sampling_params)
sampling_params=sampling_params)
request.sampling_params.max_tokens = 1 request.sampling_params.max_tokens = 1
logger.debug("Sending Prefill: %s", request.request_id) logger.debug("Sending Prefill: %s", request.request_id)
pd_response = await self._prefill(request, q) pd_response = await self._prefill(request, q)
@ -301,7 +290,7 @@ class PDEngine:
async def do_log_stats( async def do_log_stats(
self, self,
scheduler_outputs: Optional[SchedulerOutputs] = None, scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None, model_output: Optional[list[SamplerOutput]] = None,
) -> None: ) -> None:
pass pass
@ -325,7 +314,7 @@ class PDEngine:
raise NotImplementedError raise NotImplementedError
async def is_sleeping(self) -> bool: async def is_sleeping(self) -> bool:
False return False
async def add_lora(self, lora_request: LoRARequest) -> None: async def add_lora(self, lora_request: LoRARequest) -> None:
raise NotImplementedError raise NotImplementedError