pre-commit

Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
This commit is contained in:
rshaw@neuralmagic.com 2025-03-23 20:18:50 +00:00
parent 79e465f557
commit f51f182d64

View File

@ -1,15 +1,14 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
import msgspec
import os import os
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator, Mapping
from typing import Dict, List, Mapping, Optional from typing import Optional
import msgspec
import zmq import zmq
import zmq.asyncio import zmq.asyncio
from vllm import SamplingParams
from vllm.config import DecodingConfig, ModelConfig from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse
@ -18,8 +17,7 @@ from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import (PoolingRequestOutput, RequestOutput, from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
CompletionOutput)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
@ -31,6 +29,7 @@ logger = init_logger(__name__)
DEFAULT_MAX_TOKENS = 32000 DEFAULT_MAX_TOKENS = 32000
class PDEngine: class PDEngine:
""" """
PDEngine: PDEngine:
@ -42,15 +41,10 @@ class PDEngine:
* TODO: move under vllm/v1/engine one past prototype. * TODO: move under vllm/v1/engine one past prototype.
""" """
def __init__( def __init__(self, prefill_addr: str, decode_addr: str,
self, connector_addr: str, model_name: str):
prefill_addr: str,
decode_addr: str,
connector_addr: str,
model_name: str
):
# Request queues. # Request queues.
self.queues: Dict[str, asyncio.Queue] = {} self.queues: dict[str, asyncio.Queue] = {}
# Serialization encoder. # Serialization encoder.
self.encoder = msgspec.msgpack.Encoder() self.encoder = msgspec.msgpack.Encoder()
@ -71,15 +65,13 @@ class PDEngine:
# Dummy: needed for EngineClient Protocol. # Dummy: needed for EngineClient Protocol.
# TODO: refactor EngineClient to avoid needing this. # TODO: refactor EngineClient to avoid needing this.
self.model_config = ModelConfig( self.model_config = ModelConfig(model=model_name,
model=model_name, tokenizer=model_name,
tokenizer=model_name, tokenizer_mode="auto",
tokenizer_mode="auto", trust_remote_code=False,
trust_remote_code=False, dtype="auto",
dtype="auto", task="generate",
task="generate", seed=42)
seed=42
)
# Dummy: needed for EngineClient Protocol. # Dummy: needed for EngineClient Protocol.
# TODO: refactor EngineClient to avoid needing this. # TODO: refactor EngineClient to avoid needing this.
@ -103,9 +95,7 @@ class PDEngine:
if (task := self.output_handler) is not None: if (task := self.output_handler) is not None:
task.cancel() task.cancel()
ipc_paths = [ ipc_paths = [self.connector_addr, self.decode_addr, self.prefill_addr]
self.connector_addr, self.decode_addr, self.prefill_addr
]
for path in ipc_paths: for path in ipc_paths:
socket_path = path.replace("ipc://", "") socket_path = path.replace("ipc://", "")
if os.path.exists(socket_path): if os.path.exists(socket_path):
@ -169,7 +159,6 @@ class PDEngine:
finished = False finished = False
while not finished: while not finished:
response = await q.get() response = await q.get()
logger.debug(f"{response}")
if not response.success: if not response.success:
# TODO: actual error handling and shutdown. # TODO: actual error handling and shutdown.
raise Exception("Failed Decode Request.") raise Exception("Failed Decode Request.")
@ -179,7 +168,7 @@ class PDEngine:
def _to_request_output( def _to_request_output(
self, self,
pd_response: PDResponse, pd_response: PDResponse,
prompt_token_ids: List[int], prompt_token_ids: list[int],
) -> RequestOutput: ) -> RequestOutput:
finished = pd_response.finish_reason is not None finished = pd_response.finish_reason is not None
return RequestOutput( return RequestOutput(
@ -187,15 +176,15 @@ class PDEngine:
prompt=None, prompt=None,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
prompt_logprobs=None, prompt_logprobs=None,
outputs=[CompletionOutput( outputs=[
index=0, CompletionOutput(index=0,
text=pd_response.text, text=pd_response.text,
token_ids=pd_response.token_ids, token_ids=pd_response.token_ids,
cumulative_logprob=None, cumulative_logprob=None,
logprobs=None, logprobs=None,
finish_reason=pd_response.finish_reason, finish_reason=pd_response.finish_reason,
stop_reason=pd_response.stop_reason stop_reason=pd_response.stop_reason)
)], ],
finished=finished, finished=finished,
) )
@ -211,22 +200,23 @@ class PDEngine:
) -> AsyncGenerator[RequestOutput]: ) -> AsyncGenerator[RequestOutput]:
# Start loops on first request. # Start loops on first request.
if self.output_handler is None: if self.output_handler is None:
self.output_handler = asyncio.create_task(self._run_output_handler()) self.output_handler = asyncio.create_task(
self._run_output_handler())
self.log_running = asyncio.create_task(self._run_log_running()) self.log_running = asyncio.create_task(self._run_log_running())
# TODO: Expand to support the full matrix. # TODO: Expand to support the full matrix.
if not "prompt_token_ids" in prompt: if "prompt_token_ids" not in prompt:
raise NotImplementedError( raise NotImplementedError(
"We currently only support TokensPrompt for P/D!") "We currently only support TokensPrompt for P/D!")
if lora_request is not None: if lora_request is not None:
raise NotImplementedError( raise NotImplementedError(
"We currently do not suppport LoRA for P/D!") "We currently do not support LoRA for P/D!")
if trace_headers is not None: if trace_headers is not None:
raise NotImplementedError( raise NotImplementedError(
"We currently do not suppport tracing for P/D!") "We currently do not support tracing for P/D!")
if prompt_adapter_request is not None: if prompt_adapter_request is not None:
raise NotImplementedError( raise NotImplementedError(
"We currently do not suppport prompt adapter for P/D!") "We currently do not support prompt adapter for P/D!")
if priority != 0: if priority != 0:
raise NotImplementedError( raise NotImplementedError(
"We currently do not support priority for P/D!") "We currently do not support priority for P/D!")
@ -240,10 +230,9 @@ class PDEngine:
# (1) Perform the Prefill. # (1) Perform the Prefill.
original_max_tokens = sampling_params.max_tokens original_max_tokens = sampling_params.max_tokens
prompt_token_ids = prompt["prompt_token_ids"] prompt_token_ids = prompt["prompt_token_ids"]
request = PDRequest( request = PDRequest(request_id=request_id,
request_id=request_id, prompt_token_ids=prompt_token_ids,
prompt_token_ids=prompt_token_ids, sampling_params=sampling_params)
sampling_params=sampling_params)
request.sampling_params.max_tokens = 1 request.sampling_params.max_tokens = 1
logger.debug("Sending Prefill: %s", request.request_id) logger.debug("Sending Prefill: %s", request.request_id)
pd_response = await self._prefill(request, q) pd_response = await self._prefill(request, q)
@ -301,7 +290,7 @@ class PDEngine:
async def do_log_stats( async def do_log_stats(
self, self,
scheduler_outputs: Optional[SchedulerOutputs] = None, scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None, model_output: Optional[list[SamplerOutput]] = None,
) -> None: ) -> None:
pass pass
@ -325,7 +314,7 @@ class PDEngine:
raise NotImplementedError raise NotImplementedError
async def is_sleeping(self) -> bool: async def is_sleeping(self) -> bool:
False return False
async def add_lora(self, lora_request: LoRARequest) -> None: async def add_lora(self, lora_request: LoRARequest) -> None:
raise NotImplementedError raise NotImplementedError