pre-commit

Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
This commit is contained in:
rshaw@neuralmagic.com 2025-03-23 20:18:50 +00:00
parent 79e465f557
commit f51f182d64

View File

@ -1,15 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import msgspec
import os
from collections.abc import AsyncGenerator
from typing import Dict, List, Mapping, Optional
from collections.abc import AsyncGenerator, Mapping
from typing import Optional
import msgspec
import zmq
import zmq.asyncio
from vllm import SamplingParams
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse
@ -18,8 +17,7 @@ from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
CompletionOutput)
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
@ -31,6 +29,7 @@ logger = init_logger(__name__)
DEFAULT_MAX_TOKENS = 32000
class PDEngine:
"""
PDEngine:
@ -42,15 +41,10 @@ class PDEngine:
* TODO: move under vllm/v1/engine one past prototype.
"""
def __init__(
self,
prefill_addr: str,
decode_addr: str,
connector_addr: str,
model_name: str
):
def __init__(self, prefill_addr: str, decode_addr: str,
connector_addr: str, model_name: str):
# Request queues.
self.queues: Dict[str, asyncio.Queue] = {}
self.queues: dict[str, asyncio.Queue] = {}
# Serialization encoder.
self.encoder = msgspec.msgpack.Encoder()
@ -64,22 +58,20 @@ class PDEngine:
self.connector_addr = connector_addr
self.decode_addr = decode_addr
self.prefill_addr = prefill_addr
# Background loops (started on first generate()).
self.output_handler: Optional[asyncio.Task] = None
self.log_running: Optional[asyncio.Task] = None
# Dummy: needed for EngineClient Protocol.
# TODO: refactor EngineClient to avoid needing this.
self.model_config = ModelConfig(
model=model_name,
tokenizer=model_name,
tokenizer_mode="auto",
trust_remote_code=False,
dtype="auto",
task="generate",
seed=42
)
self.model_config = ModelConfig(model=model_name,
tokenizer=model_name,
tokenizer_mode="auto",
trust_remote_code=False,
dtype="auto",
task="generate",
seed=42)
# Dummy: needed for EngineClient Protocol.
# TODO: refactor EngineClient to avoid needing this.
@ -103,9 +95,7 @@ class PDEngine:
if (task := self.output_handler) is not None:
task.cancel()
ipc_paths = [
self.connector_addr, self.decode_addr, self.prefill_addr
]
ipc_paths = [self.connector_addr, self.decode_addr, self.prefill_addr]
for path in ipc_paths:
socket_path = path.replace("ipc://", "")
if os.path.exists(socket_path):
@ -121,7 +111,7 @@ class PDEngine:
distribute back to the generate() tasks.
"""
decoder = msgspec.msgpack.Decoder(PDResponse)
socket: Optional[zmq.asyncio.Socket] = None
try:
socket = self.ctx.socket(zmq.constants.PULL)
@ -134,11 +124,11 @@ class PDEngine:
self.queues[response.request_id].put_nowait(response)
except:
# TODO: actually handle failure and shutdown.
raise
raise
finally:
if socket is not None:
socket.close(linger=0)
async def _prefill(
self,
request: PDRequest,
@ -154,7 +144,7 @@ class PDEngine:
if not response.success:
# TODO: actual error handling and shutdown.
raise Exception("Failed Prefill Request.")
async def _decode(
self,
request: PDRequest,
@ -169,17 +159,16 @@ class PDEngine:
finished = False
while not finished:
response = await q.get()
logger.debug(f"{response}")
if not response.success:
# TODO: actual error handling and shutdown.
raise Exception("Failed Decode Request.")
finished = response.finish_reason is not None
yield response
def _to_request_output(
self,
pd_response: PDResponse,
prompt_token_ids: List[int],
prompt_token_ids: list[int],
) -> RequestOutput:
finished = pd_response.finish_reason is not None
return RequestOutput(
@ -187,15 +176,15 @@ class PDEngine:
prompt=None,
prompt_token_ids=prompt_token_ids,
prompt_logprobs=None,
outputs=[CompletionOutput(
index=0,
text=pd_response.text,
token_ids=pd_response.token_ids,
cumulative_logprob=None,
logprobs=None,
finish_reason=pd_response.finish_reason,
stop_reason=pd_response.stop_reason
)],
outputs=[
CompletionOutput(index=0,
text=pd_response.text,
token_ids=pd_response.token_ids,
cumulative_logprob=None,
logprobs=None,
finish_reason=pd_response.finish_reason,
stop_reason=pd_response.stop_reason)
],
finished=finished,
)
@ -211,28 +200,29 @@ class PDEngine:
) -> AsyncGenerator[RequestOutput]:
# Start loops on first request.
if self.output_handler is None:
self.output_handler = asyncio.create_task(self._run_output_handler())
self.output_handler = asyncio.create_task(
self._run_output_handler())
self.log_running = asyncio.create_task(self._run_log_running())
# TODO: Expand to support the full matrix.
if not "prompt_token_ids" in prompt:
if "prompt_token_ids" not in prompt:
raise NotImplementedError(
"We currently only support TokensPrompt for P/D!")
if lora_request is not None:
raise NotImplementedError(
"We currently do not suppport LoRA for P/D!")
"We currently do not support LoRA for P/D!")
if trace_headers is not None:
raise NotImplementedError(
"We currently do not suppport tracing for P/D!")
"We currently do not support tracing for P/D!")
if prompt_adapter_request is not None:
raise NotImplementedError(
"We currently do not suppport prompt adapter for P/D!")
"We currently do not support prompt adapter for P/D!")
if priority != 0:
raise NotImplementedError(
"We currently do not support priority for P/D!")
if request_id in self.queues:
raise ValueError(f"Found duplicate request_id: {request_id}!")
# Queue to gather output from output_handler.
q: asyncio.Queue[PDResponse] = asyncio.Queue()
self.queues[request_id] = q
@ -240,10 +230,9 @@ class PDEngine:
# (1) Perform the Prefill.
original_max_tokens = sampling_params.max_tokens
prompt_token_ids = prompt["prompt_token_ids"]
request = PDRequest(
request_id=request_id,
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params)
request = PDRequest(request_id=request_id,
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params)
request.sampling_params.max_tokens = 1
logger.debug("Sending Prefill: %s", request.request_id)
pd_response = await self._prefill(request, q)
@ -301,7 +290,7 @@ class PDEngine:
async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None,
model_output: Optional[list[SamplerOutput]] = None,
) -> None:
pass
@ -325,7 +314,7 @@ class PDEngine:
raise NotImplementedError
async def is_sleeping(self) -> bool:
False
return False
async def add_lora(self, lora_request: LoRARequest) -> None:
raise NotImplementedError