commit f51f182d64
parent 79e465f557

pre-commit

Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
@@ -1,15 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import asyncio
-import msgspec
 import os
-from collections.abc import AsyncGenerator
-from typing import Dict, List, Mapping, Optional
+from collections.abc import AsyncGenerator, Mapping
+from typing import Optional
 
+import msgspec
 import zmq
 import zmq.asyncio
 
-from vllm import SamplingParams
 from vllm.config import DecodingConfig, ModelConfig
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse
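The reshuffle above is what the pre-commit import sorter produces: msgspec moves out of the stdlib group into the third-party group, and Mapping moves from typing to collections.abc, since the typing aliases Dict, List, and Mapping are deprecated in favor of built-in generics (PEP 585) and collections.abc on Python 3.9+. A minimal sketch of the modern style; both helpers are hypothetical:

from collections.abc import Mapping

def first_positions(tokens: list[int]) -> dict[int, int]:
    # Built-in generics replace typing.List / typing.Dict (PEP 585).
    out: dict[int, int] = {}
    for i, tok in enumerate(tokens):
        out.setdefault(tok, i)  # keep the first index seen for each token
    return out

def render(headers: Mapping[str, str]) -> str:
    # collections.abc.Mapping replaces the deprecated typing.Mapping.
    return "\n".join(f"{k}: {v}" for k, v in headers.items())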
@@ -18,8 +17,7 @@ from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import (PoolingRequestOutput, RequestOutput,
-                          CompletionOutput)
+from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import BeamSearchParams, SamplingParams
@@ -31,6 +29,7 @@ logger = init_logger(__name__)
 
 DEFAULT_MAX_TOKENS = 32000
 
+
 class PDEngine:
     """
     PDEngine:
@@ -42,15 +41,10 @@ class PDEngine:
     * TODO: move under vllm/v1/engine one past prototype.
     """
 
-    def __init__(
-        self,
-        prefill_addr: str,
-        decode_addr: str,
-        connector_addr: str,
-        model_name: str
-    ):
+    def __init__(self, prefill_addr: str, decode_addr: str,
+                 connector_addr: str, model_name: str):
         # Request queues.
-        self.queues: Dict[str, asyncio.Queue] = {}
+        self.queues: dict[str, asyncio.Queue] = {}
 
         # Serialization encoder.
         self.encoder = msgspec.msgpack.Encoder()
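self.queues maps each in-flight request id to its own asyncio.Queue, so one background reader can fan responses out to many concurrent generate() calls. A self-contained sketch of that pattern with hypothetical names (fan_out, consume), assuming a single task drains the socket while per-request coroutines drain their queues:

import asyncio

queues: dict[str, asyncio.Queue] = {}

async def fan_out(responses: list[tuple[str, str]]) -> None:
    # One reader routes each (request_id, payload) to its queue.
    for request_id, payload in responses:
        await queues[request_id].put(payload)

async def consume(request_id: str, n: int) -> list[str]:
    # Each request awaits only its own queue.
    q = queues.setdefault(request_id, asyncio.Queue())
    return [await q.get() for _ in range(n)]

async def main() -> None:
    queues["req-0"] = asyncio.Queue()
    await fan_out([("req-0", "hello"), ("req-0", "world")])
    print(await consume("req-0", 2))  # ['hello', 'world']

asyncio.run(main())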
@@ -71,15 +65,13 @@ class PDEngine:
 
         # Dummy: needed for EngineClient Protocol.
         # TODO: refactor EngineClient to avoid needing this.
-        self.model_config = ModelConfig(
-            model=model_name,
-            tokenizer=model_name,
-            tokenizer_mode="auto",
-            trust_remote_code=False,
-            dtype="auto",
-            task="generate",
-            seed=42
-        )
+        self.model_config = ModelConfig(model=model_name,
+                                        tokenizer=model_name,
+                                        tokenizer_mode="auto",
+                                        trust_remote_code=False,
+                                        dtype="auto",
+                                        task="generate",
+                                        seed=42)
 
         # Dummy: needed for EngineClient Protocol.
         # TODO: refactor EngineClient to avoid needing this.
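The placeholder ModelConfig exists only because the EngineClient protocol that vLLM's API servers program against expects the engine to report a model config, so even a proxy engine that never loads weights must return one. A hedged sketch of the shape of that requirement; the protocol body below is illustrative, not vLLM's actual definition:

from typing import Protocol

class EngineClientLike(Protocol):
    # Illustrative subset: the real EngineClient protocol has many more
    # members (generate, abort, do_log_stats, ...).
    async def get_model_config(self): ...

class ProxyEngine:
    def __init__(self, model_config) -> None:
        self.model_config = model_config  # dummy, never used for inference

    async def get_model_config(self):
        return self.model_config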
@@ -103,9 +95,7 @@ class PDEngine:
         if (task := self.output_handler) is not None:
             task.cancel()
 
-        ipc_paths = [
-            self.connector_addr, self.decode_addr, self.prefill_addr
-        ]
+        ipc_paths = [self.connector_addr, self.decode_addr, self.prefill_addr]
         for path in ipc_paths:
             socket_path = path.replace("ipc://", "")
             if os.path.exists(socket_path):
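The loop strips ZeroMQ's ipc:// scheme to recover the UNIX socket path and, presumably in the lines past the end of this hunk, unlinks any stale socket file left by a previous run so a restart can re-bind. A minimal sketch of that cleanup, assuming the removal is a plain os.remove; the path in the usage line is hypothetical:

import os

def cleanup_ipc_sockets(ipc_paths: list[str]) -> None:
    # ipc://... addresses name filesystem sockets; a crashed process can
    # leave them behind and block the next bind().
    for path in ipc_paths:
        socket_path = path.replace("ipc://", "")
        if os.path.exists(socket_path):
            os.remove(socket_path)

cleanup_ipc_sockets(["ipc:///tmp/example_prefill.sock"])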
@@ -169,7 +159,6 @@ class PDEngine:
         finished = False
         while not finished:
             response = await q.get()
-            logger.debug(f"{response}")
             if not response.success:
                 # TODO: actual error handling and shutdown.
                 raise Exception("Failed Decode Request.")
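Dropping logger.debug(f"{response}") also removes a logging anti-pattern: an f-string is rendered eagerly even when DEBUG is disabled, whereas %-style arguments are only formatted if the record is actually emitted (the style this file already uses for "Sending Prefill: %s"). For example:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("demo")

response = {"id": "req-0", "success": True}

# Eager: str(response) runs even though DEBUG is filtered out here.
logger.debug(f"{response}")

# Lazy: the argument is only formatted if a handler accepts DEBUG.
logger.debug("got response: %s", response)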
@@ -179,7 +168,7 @@ class PDEngine:
     def _to_request_output(
         self,
         pd_response: PDResponse,
-        prompt_token_ids: List[int],
+        prompt_token_ids: list[int],
     ) -> RequestOutput:
         finished = pd_response.finish_reason is not None
         return RequestOutput(
@@ -187,15 +176,15 @@ class PDEngine:
             prompt=None,
             prompt_token_ids=prompt_token_ids,
             prompt_logprobs=None,
-            outputs=[CompletionOutput(
-                index=0,
-                text=pd_response.text,
-                token_ids=pd_response.token_ids,
-                cumulative_logprob=None,
-                logprobs=None,
-                finish_reason=pd_response.finish_reason,
-                stop_reason=pd_response.stop_reason
-            )],
+            outputs=[
+                CompletionOutput(index=0,
+                                 text=pd_response.text,
+                                 token_ids=pd_response.token_ids,
+                                 cumulative_logprob=None,
+                                 logprobs=None,
+                                 finish_reason=pd_response.finish_reason,
+                                 stop_reason=pd_response.stop_reason)
+            ],
             finished=finished,
         )
 
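Each PDResponse is folded into a RequestOutput whose single CompletionOutput carries the text and tokens so far, with finished flipping once a finish_reason arrives. A hypothetical caller would drain the stream like this (collect is an invented helper; it assumes engine is a PDEngine and prompt/params are in scope at the call site):

async def collect(engine, prompt, params, request_id: str) -> str:
    # Drain the engine's async generator; in this simple sketch the last
    # output before `finished` carries the full text.
    final_text = ""
    async for output in engine.generate(prompt, params, request_id):
        completion = output.outputs[0]  # one CompletionOutput per response
        final_text = completion.text
        if output.finished:
            break
    return final_text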
@@ -211,22 +200,23 @@ class PDEngine:
     ) -> AsyncGenerator[RequestOutput]:
         # Start loops on first request.
         if self.output_handler is None:
-            self.output_handler = asyncio.create_task(self._run_output_handler())
+            self.output_handler = asyncio.create_task(
+                self._run_output_handler())
             self.log_running = asyncio.create_task(self._run_log_running())
 
         # TODO: Expand to support the full matrix.
-        if not "prompt_token_ids" in prompt:
+        if "prompt_token_ids" not in prompt:
             raise NotImplementedError(
                 "We currently only support TokensPrompt for P/D!")
         if lora_request is not None:
             raise NotImplementedError(
-                "We currently do not suppport LoRA for P/D!")
+                "We currently do not support LoRA for P/D!")
         if trace_headers is not None:
             raise NotImplementedError(
-                "We currently do not suppport tracing for P/D!")
+                "We currently do not support tracing for P/D!")
         if prompt_adapter_request is not None:
             raise NotImplementedError(
-                "We currently do not suppport prompt adapter for P/D!")
+                "We currently do not support prompt adapter for P/D!")
         if priority != 0:
             raise NotImplementedError(
                 "We currently do not support priority for P/D!")
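not "prompt_token_ids" in prompt parses as not ("prompt_token_ids" in prompt), so the rewrite is purely stylistic (pycodestyle flags the old form as E713), but the "not in" operator states the intent directly; the same hunk fixes the tripled-p "suppport" typos. For example:

prompt = {"prompt_token_ids": [1, 2, 3]}

# E713: discouraged, though semantically identical to the form below.
if not "prompt_token_ids" in prompt:
    raise NotImplementedError

# Preferred membership test.
if "prompt_token_ids" not in prompt:
    raise NotImplementedError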
@@ -240,10 +230,9 @@ class PDEngine:
         # (1) Perform the Prefill.
         original_max_tokens = sampling_params.max_tokens
         prompt_token_ids = prompt["prompt_token_ids"]
-        request = PDRequest(
-            request_id=request_id,
-            prompt_token_ids=prompt_token_ids,
-            sampling_params=sampling_params)
+        request = PDRequest(request_id=request_id,
+                            prompt_token_ids=prompt_token_ids,
+                            sampling_params=sampling_params)
         request.sampling_params.max_tokens = 1
         logger.debug("Sending Prefill: %s", request.request_id)
         pd_response = await self._prefill(request, q)
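The surrounding logic is the heart of P/D disaggregation: the original token budget is saved, max_tokens is clamped to 1 so the prefill instance only primes the KV cache and emits the first token, and, presumably past the end of this hunk, the budget is restored before the request goes to the decode instance. A sketch of that handoff with a hypothetical send helper; the restore step is an assumption, since it is not visible in the diff:

async def prefill_then_decode(request, send) -> None:
    # Save the caller's budget, run prefill for exactly one token...
    original_max_tokens = request.sampling_params.max_tokens
    request.sampling_params.max_tokens = 1
    await send("prefill", request)

    # ...then restore it so the decode instance finishes the generation
    # (assumed continuation, not shown in this hunk).
    request.sampling_params.max_tokens = original_max_tokens
    await send("decode", request)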
@@ -301,7 +290,7 @@ class PDEngine:
     async def do_log_stats(
         self,
         scheduler_outputs: Optional[SchedulerOutputs] = None,
-        model_output: Optional[List[SamplerOutput]] = None,
+        model_output: Optional[list[SamplerOutput]] = None,
     ) -> None:
         pass
 
@@ -325,7 +314,7 @@ class PDEngine:
         raise NotImplementedError
 
     async def is_sleeping(self) -> bool:
-        False
+        return False
 
     async def add_lora(self, lora_request: LoRARequest) -> None:
         raise NotImplementedError
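The is_sleeping change is the one real bug fix in the commit: a bare False is an expression statement that is evaluated and discarded, so the old coroutine implicitly returned None, which is falsy by accident rather than by contract and wrong the moment a caller compares against None or the bool annotation is enforced. A small demonstration:

import asyncio

async def is_sleeping_old() -> bool:
    False  # no-op expression statement; the coroutine returns None

async def is_sleeping_new() -> bool:
    return False

print(asyncio.run(is_sleeping_old()))  # None  (despite the bool annotation)
print(asyncio.run(is_sleeping_new()))  # False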