mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 11:47:04 +08:00
pre-commit
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
This commit is contained in:
parent
79e465f557
commit
f51f182d64
@ -1,15 +1,14 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import asyncio
|
||||
import msgspec
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Dict, List, Mapping, Optional
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
from typing import Optional
|
||||
|
||||
import msgspec
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import DecodingConfig, ModelConfig
|
||||
from vllm.core.scheduler import SchedulerOutputs
|
||||
from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse
|
||||
@ -18,8 +17,7 @@ from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
|
||||
CompletionOutput)
|
||||
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
@ -31,6 +29,7 @@ logger = init_logger(__name__)
|
||||
|
||||
DEFAULT_MAX_TOKENS = 32000
|
||||
|
||||
|
||||
class PDEngine:
|
||||
"""
|
||||
PDEngine:
|
||||
@ -42,15 +41,10 @@ class PDEngine:
|
||||
* TODO: move under vllm/v1/engine one past prototype.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prefill_addr: str,
|
||||
decode_addr: str,
|
||||
connector_addr: str,
|
||||
model_name: str
|
||||
):
|
||||
def __init__(self, prefill_addr: str, decode_addr: str,
|
||||
connector_addr: str, model_name: str):
|
||||
# Request queues.
|
||||
self.queues: Dict[str, asyncio.Queue] = {}
|
||||
self.queues: dict[str, asyncio.Queue] = {}
|
||||
|
||||
# Serialization encoder.
|
||||
self.encoder = msgspec.msgpack.Encoder()
|
||||
@ -64,22 +58,20 @@ class PDEngine:
|
||||
self.connector_addr = connector_addr
|
||||
self.decode_addr = decode_addr
|
||||
self.prefill_addr = prefill_addr
|
||||
|
||||
|
||||
# Background loops (started on first generate()).
|
||||
self.output_handler: Optional[asyncio.Task] = None
|
||||
self.log_running: Optional[asyncio.Task] = None
|
||||
|
||||
# Dummy: needed for EngineClient Protocol.
|
||||
# TODO: refactor EngineClient to avoid needing this.
|
||||
self.model_config = ModelConfig(
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
dtype="auto",
|
||||
task="generate",
|
||||
seed=42
|
||||
)
|
||||
self.model_config = ModelConfig(model=model_name,
|
||||
tokenizer=model_name,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
dtype="auto",
|
||||
task="generate",
|
||||
seed=42)
|
||||
|
||||
# Dummy: needed for EngineClient Protocol.
|
||||
# TODO: refactor EngineClient to avoid needing this.
|
||||
@ -103,9 +95,7 @@ class PDEngine:
|
||||
if (task := self.output_handler) is not None:
|
||||
task.cancel()
|
||||
|
||||
ipc_paths = [
|
||||
self.connector_addr, self.decode_addr, self.prefill_addr
|
||||
]
|
||||
ipc_paths = [self.connector_addr, self.decode_addr, self.prefill_addr]
|
||||
for path in ipc_paths:
|
||||
socket_path = path.replace("ipc://", "")
|
||||
if os.path.exists(socket_path):
|
||||
@ -121,7 +111,7 @@ class PDEngine:
|
||||
distribute back to the generate() tasks.
|
||||
"""
|
||||
decoder = msgspec.msgpack.Decoder(PDResponse)
|
||||
|
||||
|
||||
socket: Optional[zmq.asyncio.Socket] = None
|
||||
try:
|
||||
socket = self.ctx.socket(zmq.constants.PULL)
|
||||
@ -134,11 +124,11 @@ class PDEngine:
|
||||
self.queues[response.request_id].put_nowait(response)
|
||||
except:
|
||||
# TODO: actually handle failure and shutdown.
|
||||
raise
|
||||
raise
|
||||
finally:
|
||||
if socket is not None:
|
||||
socket.close(linger=0)
|
||||
|
||||
|
||||
async def _prefill(
|
||||
self,
|
||||
request: PDRequest,
|
||||
@ -154,7 +144,7 @@ class PDEngine:
|
||||
if not response.success:
|
||||
# TODO: actual error handling and shutdown.
|
||||
raise Exception("Failed Prefill Request.")
|
||||
|
||||
|
||||
async def _decode(
|
||||
self,
|
||||
request: PDRequest,
|
||||
@ -169,17 +159,16 @@ class PDEngine:
|
||||
finished = False
|
||||
while not finished:
|
||||
response = await q.get()
|
||||
logger.debug(f"{response}")
|
||||
if not response.success:
|
||||
# TODO: actual error handling and shutdown.
|
||||
raise Exception("Failed Decode Request.")
|
||||
finished = response.finish_reason is not None
|
||||
yield response
|
||||
|
||||
|
||||
def _to_request_output(
|
||||
self,
|
||||
pd_response: PDResponse,
|
||||
prompt_token_ids: List[int],
|
||||
prompt_token_ids: list[int],
|
||||
) -> RequestOutput:
|
||||
finished = pd_response.finish_reason is not None
|
||||
return RequestOutput(
|
||||
@ -187,15 +176,15 @@ class PDEngine:
|
||||
prompt=None,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
prompt_logprobs=None,
|
||||
outputs=[CompletionOutput(
|
||||
index=0,
|
||||
text=pd_response.text,
|
||||
token_ids=pd_response.token_ids,
|
||||
cumulative_logprob=None,
|
||||
logprobs=None,
|
||||
finish_reason=pd_response.finish_reason,
|
||||
stop_reason=pd_response.stop_reason
|
||||
)],
|
||||
outputs=[
|
||||
CompletionOutput(index=0,
|
||||
text=pd_response.text,
|
||||
token_ids=pd_response.token_ids,
|
||||
cumulative_logprob=None,
|
||||
logprobs=None,
|
||||
finish_reason=pd_response.finish_reason,
|
||||
stop_reason=pd_response.stop_reason)
|
||||
],
|
||||
finished=finished,
|
||||
)
|
||||
|
||||
@ -211,28 +200,29 @@ class PDEngine:
|
||||
) -> AsyncGenerator[RequestOutput]:
|
||||
# Start loops on first request.
|
||||
if self.output_handler is None:
|
||||
self.output_handler = asyncio.create_task(self._run_output_handler())
|
||||
self.output_handler = asyncio.create_task(
|
||||
self._run_output_handler())
|
||||
self.log_running = asyncio.create_task(self._run_log_running())
|
||||
|
||||
# TODO: Expand to support the full matrix.
|
||||
if not "prompt_token_ids" in prompt:
|
||||
if "prompt_token_ids" not in prompt:
|
||||
raise NotImplementedError(
|
||||
"We currently only support TokensPrompt for P/D!")
|
||||
if lora_request is not None:
|
||||
raise NotImplementedError(
|
||||
"We currently do not suppport LoRA for P/D!")
|
||||
"We currently do not support LoRA for P/D!")
|
||||
if trace_headers is not None:
|
||||
raise NotImplementedError(
|
||||
"We currently do not suppport tracing for P/D!")
|
||||
"We currently do not support tracing for P/D!")
|
||||
if prompt_adapter_request is not None:
|
||||
raise NotImplementedError(
|
||||
"We currently do not suppport prompt adapter for P/D!")
|
||||
"We currently do not support prompt adapter for P/D!")
|
||||
if priority != 0:
|
||||
raise NotImplementedError(
|
||||
"We currently do not support priority for P/D!")
|
||||
if request_id in self.queues:
|
||||
raise ValueError(f"Found duplicate request_id: {request_id}!")
|
||||
|
||||
|
||||
# Queue to gather output from output_handler.
|
||||
q: asyncio.Queue[PDResponse] = asyncio.Queue()
|
||||
self.queues[request_id] = q
|
||||
@ -240,10 +230,9 @@ class PDEngine:
|
||||
# (1) Perform the Prefill.
|
||||
original_max_tokens = sampling_params.max_tokens
|
||||
prompt_token_ids = prompt["prompt_token_ids"]
|
||||
request = PDRequest(
|
||||
request_id=request_id,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
sampling_params=sampling_params)
|
||||
request = PDRequest(request_id=request_id,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
sampling_params=sampling_params)
|
||||
request.sampling_params.max_tokens = 1
|
||||
logger.debug("Sending Prefill: %s", request.request_id)
|
||||
pd_response = await self._prefill(request, q)
|
||||
@ -301,7 +290,7 @@ class PDEngine:
|
||||
async def do_log_stats(
|
||||
self,
|
||||
scheduler_outputs: Optional[SchedulerOutputs] = None,
|
||||
model_output: Optional[List[SamplerOutput]] = None,
|
||||
model_output: Optional[list[SamplerOutput]] = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@ -325,7 +314,7 @@ class PDEngine:
|
||||
raise NotImplementedError
|
||||
|
||||
async def is_sleeping(self) -> bool:
|
||||
False
|
||||
return False
|
||||
|
||||
async def add_lora(self, lora_request: LoRARequest) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user