mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-07 04:47:03 +08:00
58 lines
1.4 KiB
Python
58 lines
1.4 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
from typing import Optional
|
|
|
|
import msgspec
|
|
|
|
from vllm import SamplingParams
|
|
from vllm.outputs import RequestOutput
|
|
|
|
# NOTE FOR DEVELOPERS:
|
|
# DO NOT USE PICKLE FOR THESE CLASSES. IN A MULTI NODE
|
|
# SETUP WE WILL USE TCP. WE CANNOT USE PICKLE OTHERWISE
|
|
# WE RISK REMOTE CODE EXECUTION FROM UNSTRUSTED USERS.
|
|
|
|
|
|
class PDRequestType:
|
|
GENERATION = b"generation"
|
|
ABORT = b"abort"
|
|
|
|
|
|
class PDGenerationRequest(msgspec.Struct):
|
|
request_id: str
|
|
prompt_token_ids: list[int]
|
|
sampling_params: SamplingParams
|
|
# TODO: support multimodal inputs.
|
|
|
|
|
|
class PDAbortRequest(msgspec.Struct):
|
|
request_id: str
|
|
|
|
|
|
class PDResponseType:
|
|
GENERATION = b"generation"
|
|
FAILURE = b"failure"
|
|
|
|
|
|
class PDGenerationResponse(msgspec.Struct):
|
|
request_id: str
|
|
text: str
|
|
token_ids: list[int]
|
|
finish_reason: Optional[str] = None
|
|
stop_reason: Optional[str] = None
|
|
# TODO: support full protocol.
|
|
logprobs = None
|
|
|
|
@classmethod
|
|
def from_request_output(
|
|
self, request_output: RequestOutput) -> "PDGenerationResponse":
|
|
assert len(request_output.outputs) == 1, "Only support N=1 right now."
|
|
out = request_output.outputs[0]
|
|
return PDGenerationResponse(
|
|
request_id=request_output.request_id,
|
|
text=out.text,
|
|
token_ids=out.token_ids,
|
|
finish_reason=out.finish_reason,
|
|
stop_reason=out.stop_reason,
|
|
)
|