mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 14:27:19 +08:00
[DisaggEverything] Tokens in<>out /generate endpoint (#24261)
Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
d54a18a47e
commit
6f1e7f7226
49
examples/online_serving/token_generation_client.py
Normal file
49
examples/online_serving/token_generation_client.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import httpx
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
# Full URL of the tokens-in/tokens-out generation endpoint.
GEN_ENDPOINT = "http://localhost:8000/inference/v1/generate"
# Placeholder credential sent as the bearer token.
DUMMY_API_KEY = "empty"
# Model used for local tokenization and as the request's "model" field.
MODEL_NAME = "Qwen/Qwen3-0.6B"

transport = httpx.HTTPTransport()
headers = {"Authorization": f"Bearer {DUMMY_API_KEY}"}
client = httpx.Client(
    transport=transport, base_url=GEN_ENDPOINT, timeout=600, headers=headers
)

# Conversation that gets rendered to token ids by the chat template.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "How many countries are in the EU?"},
]
|
||||||
|
|
||||||
|
|
||||||
|
def main(client):
    """Send one tokens-in/tokens-out request and print the decoded reply.

    The module-level ``messages`` are rendered to token ids with the model's
    chat template, posted to ``GEN_ENDPOINT``, and the token ids of the first
    returned choice are decoded locally for display.

    Args:
        client: HTTP client exposing ``post(url, json=...)``.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    prompt_ids = tok.apply_chat_template(
        messages, add_generation_prompt=True, enable_thinking=False
    )

    sampling = {"max_tokens": 24, "temperature": 0.2, "detokenize": False}
    body = {
        "model": MODEL_NAME,
        "token_ids": prompt_ids,
        "sampling_params": sampling,
        "stream": False,
    }
    response = client.post(GEN_ENDPOINT, json=body)
    response.raise_for_status()
    result = response.json()

    separator = "-" * 50
    print(result)
    print(separator)
    print("Token generation results:")
    # The server was asked not to detokenize, so decode on the client side.
    print(tok.decode(result["choices"][0]["token_ids"]))
    print(separator)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Run inside the client's context manager so its connection pool is
    # closed once the request has completed.
    with client:
        main(client)
|
||||||
@ -10,3 +10,7 @@ mkdocs-minify-plugin
|
|||||||
regex
|
regex
|
||||||
ruff
|
ruff
|
||||||
pydantic
|
pydantic
|
||||||
|
|
||||||
|
# For generating argparse docs.
|
||||||
|
# Adding requirements here should only be used as a last resort.
|
||||||
|
msgspec # Needed for multiple inheritance involving msgspec.Struct
|
||||||
262
tests/entrypoints/openai/test_serving_tokens.py
Normal file
262
tests/entrypoints/openai/test_serving_tokens.py
Normal file
@ -0,0 +1,262 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
from vllm.v1.engine.detokenizer import check_stop_strings
|
||||||
|
|
||||||
|
from ...utils import RemoteOpenAIServer
|
||||||
|
|
||||||
|
MODEL_NAME = "Qwen/Qwen3-0.6B"
|
||||||
|
GEN_ENDPOINT = "/inference/v1/generate"
|
||||||
|
|
||||||
|
|
||||||
|
def get_vocab_size(model_name):
    """Return the vocabulary size reported by ``ModelConfig`` for *model_name*."""
    return ModelConfig(model=model_name, seed=0, dtype="bfloat16").get_vocab_size()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def tokenizer():
    """Module-scoped tokenizer loaded for ``MODEL_NAME``."""
    hf_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    return hf_tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def messages():
    """Module-scoped two-turn chat conversation shared by the tests."""
    system_turn = {"role": "system", "content": "You are a helpful assistant."}
    user_turn = {"role": "user", "content": "How many countries are in the EU?"}
    return [system_turn, user_turn]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def server(request):
    """Start a ``RemoteOpenAIServer`` for ``MODEL_NAME`` and yield it.

    Extra CLI arguments can be supplied through indirect parametrization
    (``request.param``): a list or tuple is appended element-wise, any other
    non-``None`` value is appended as a single stringified argument.
    """
    args = ["--dtype", "bfloat16", "--max-model-len", "1024", "--enforce-eager"]

    extra_args = getattr(request, "param", None)
    if isinstance(extra_args, (list, tuple)):
        args = args + list(extra_args)
    elif extra_args is not None:
        args = args + [str(extra_args)]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def client(server: RemoteOpenAIServer):
    """Yield an ``httpx.AsyncClient`` rooted at the test server's URL.

    A Unix-domain-socket transport is used when the server exposes ``uds``;
    otherwise httpx picks its default transport.
    """
    transport = None
    if server.uds:
        transport = httpx.AsyncHTTPTransport(uds=server.uds)
    auth_headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"}

    async with httpx.AsyncClient(
        transport=transport,
        base_url=server.url_root,
        timeout=600,
        headers=auth_headers,
    ) as c:
        yield c
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_generate_endpoint(client):
    """Smoke test: a minimal generate request succeeds and returns choices."""
    resp = await client.post(
        GEN_ENDPOINT,
        json={
            "model": MODEL_NAME,
            "token_ids": [1, 2, 3],
            "sampling_params": {"max_tokens": 5},
            "stream": False,
        },
    )
    resp.raise_for_status()
    assert "choices" in resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_same_response_as_chat_completions(client, tokenizer, messages):
    """Greedy generate output must match ``/v1/chat/completions``.

    The prompt is tokenized client-side, sent to the tokens endpoint with
    detokenization disabled, decoded locally, and compared with the text
    returned by chat completions for the same messages. Both values of
    ``ignore_eos`` are exercised.
    """
    token_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        enable_thinking=False,  # default with Qwen3
    )
    for ignore_eos in [True, False]:
        payload = {
            "model": MODEL_NAME,
            "token_ids": token_ids,
            "sampling_params": {
                "max_tokens": 24,
                "temperature": 0.0,
                # NOTE coordinator will set this to skip detokenization
                "detokenize": False,
                "ignore_eos": ignore_eos,
            },
            "stream": False,
        }
        generate_resp = await client.post(GEN_ENDPOINT, json=payload)
        # Surface server errors as HTTP failures rather than a KeyError below.
        generate_resp.raise_for_status()
        generate_data = generate_resp.json()
        generate_res = tokenizer.decode(
            generate_data["choices"][0]["token_ids"], skip_special_tokens=True
        )

        payload = {
            "model": MODEL_NAME,
            "messages": messages,
            "max_tokens": 24,
            "temperature": 0.0,
            "stream": False,
            "ignore_eos": ignore_eos,
            "chat_template_kwargs": dict(enable_thinking=False),
        }
        completions_resp = await client.post("/v1/chat/completions", json=payload)
        completions_resp.raise_for_status()
        completions_data = completions_resp.json()
        completions_res = completions_data["choices"][0]["message"]["content"]

        assert generate_res == completions_res
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_stop_string_workflow(client, tokenizer, messages):
    """Emulate client-side stop-string handling on top of the tokens endpoint.

    The endpoint must reject stop strings when ``detokenize`` is False. The
    test then generates without a stop string, applies ``check_stop_strings``
    to the locally decoded text, truncates it, and checks the result equals
    what chat completions returns with the same stop string.
    """
    token_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        enable_thinking=False,  # default with Qwen3
    )
    payload = {
        "model": MODEL_NAME,
        "token_ids": token_ids,
        "sampling_params": {
            "max_tokens": 24,
            "temperature": 0.0,
            "detokenize": False,
            # stop strings are only supported when detokenize is True.
            "stop": ["27 member"],
        },
        # TODO stream test is much more interesting
        "stream": False,
    }
    with pytest.raises(httpx.HTTPStatusError):
        generate_resp = await client.post(GEN_ENDPOINT, json=payload)
        generate_resp.raise_for_status()

    payload["sampling_params"]["stop"] = None
    generate_resp = await client.post(
        GEN_ENDPOINT, json=payload, headers={"X-Request-Id": "42"}
    )
    # Surface server errors as HTTP failures rather than a KeyError below.
    generate_resp.raise_for_status()
    generate_data = generate_resp.json()
    generate_res = tokenizer.decode(
        generate_data["choices"][0]["token_ids"], skip_special_tokens=True
    )

    # NOTE This is under the responsibility of the coordinator
    # stop_checker = StopChecker(
    #     max_model_len=1024, get_tokenizer_for_seq=lambda _: tokenizer
    # )
    stop_match = check_stop_strings(
        generate_res, len(generate_res), ["27 member"], False
    )
    # Check for a missing match before unpacking so the failure is a clear
    # assertion instead of a TypeError.
    assert stop_match is not None, f"stop string not found in {generate_res!r}"
    stop_str, truncate_to = stop_match
    assert stop_str == "27 member"
    # abort request that hit stop string (requires tokens-only mode)
    # res = await client.post("/abort_requests", json={"request_ids": ["generate-tokens-42"]})  # noqa: E501
    # res.raise_for_status()
    generate_res = generate_res[:truncate_to]

    # Get stop_str response from chat completions
    payload = {
        "model": MODEL_NAME,
        "messages": messages,
        "max_tokens": 24,
        "temperature": 0.0,
        "stream": False,
        "stop": ["27 member"],
        "chat_template_kwargs": dict(enable_thinking=False),
    }
    completions_resp = await client.post("/v1/chat/completions", json=payload)
    completions_resp.raise_for_status()
    completions_data = completions_resp.json()
    completions_res = completions_data["choices"][0]["message"]["content"]
    assert generate_res == completions_res
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "server",
    [
        [
            "--enable-lora",
            "--lora-modules",
            "Alice=charent/self_cognition_Alice",
            "Bob=charent/self_cognition_Bob",
            "--max-lora-rank",
            "64",
            "--max-cpu-loras",
            "2",
        ]
    ],
    indirect=True,
)
async def test_generate_with_lora_adapter(client, tokenizer, messages):
    """The tokens endpoint must route requests to a named LoRA adapter.

    Checks that both adapters are listed by ``/v1/models``, that a minimal
    generate request against the ``Alice`` adapter succeeds, and that greedy
    generate output matches chat completions for the same adapter.
    """
    # Verify adapters are listed
    models_resp = await client.get("/v1/models")
    models_resp.raise_for_status()
    models = {m["id"] for m in models_resp.json().get("data", [])}
    assert {"Alice", "Bob"}.issubset(models)

    # Generate using a LoRA adapter by specifying its name as the model
    payload = {
        "model": "Alice",
        "token_ids": [1, 2, 3],
        "sampling_params": {"max_tokens": 5},
        "stream": False,
    }
    resp = await client.post(GEN_ENDPOINT, json=payload)
    resp.raise_for_status()
    data = resp.json()
    assert "choices" in data

    token_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        enable_thinking=False,  # default with Qwen3
    )
    payload = {
        "model": "Alice",
        "token_ids": token_ids,
        "sampling_params": {
            "max_tokens": 24,
            "temperature": 0.0,
            "detokenize": False,
        },
        "stream": False,
    }
    generate_resp = await client.post(GEN_ENDPOINT, json=payload)
    # Check the status like the earlier requests in this test do.
    generate_resp.raise_for_status()
    generate_data = generate_resp.json()
    generate_res = tokenizer.decode(
        generate_data["choices"][0]["token_ids"], skip_special_tokens=True
    )

    payload = {
        "model": "Alice",
        "messages": messages,
        "max_tokens": 24,
        "temperature": 0.0,
        "stream": False,
        "chat_template_kwargs": dict(enable_thinking=False),
    }
    completions_resp = await client.post("/v1/chat/completions", json=payload)
    completions_resp.raise_for_status()
    completions_data = completions_resp.json()
    completions_res = completions_data["choices"][0]["message"]["content"]

    assert generate_res == completions_res
|
||||||
@ -566,6 +566,7 @@ class EngineArgs:
|
|||||||
kv_offloading_backend: KVOffloadingBackend | None = (
|
kv_offloading_backend: KVOffloadingBackend | None = (
|
||||||
CacheConfig.kv_offloading_backend
|
CacheConfig.kv_offloading_backend
|
||||||
)
|
)
|
||||||
|
tokens_only: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# support `EngineArgs(compilation_config={...})`
|
# support `EngineArgs(compilation_config={...})`
|
||||||
@ -1495,6 +1496,10 @@ class EngineArgs:
|
|||||||
else ParallelConfig.data_parallel_rpc_port
|
else ParallelConfig.data_parallel_rpc_port
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.tokens_only and not model_config.skip_tokenizer_init:
|
||||||
|
model_config.skip_tokenizer_init = True
|
||||||
|
logger.info("Skipping tokenizer initialization for tokens-only mode.")
|
||||||
|
|
||||||
# Forward the deprecated CLI args to the EPLB config.
|
# Forward the deprecated CLI args to the EPLB config.
|
||||||
if self.num_redundant_experts is not None:
|
if self.num_redundant_experts is not None:
|
||||||
self.eplb_config.num_redundant_experts = self.num_redundant_experts
|
self.eplb_config.num_redundant_experts = self.num_redundant_experts
|
||||||
|
|||||||
@ -65,6 +65,8 @@ from vllm.entrypoints.openai.protocol import (
|
|||||||
EmbeddingResponse,
|
EmbeddingResponse,
|
||||||
ErrorInfo,
|
ErrorInfo,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
|
GenerateRequest,
|
||||||
|
GenerateResponse,
|
||||||
IOProcessorResponse,
|
IOProcessorResponse,
|
||||||
PoolingBytesResponse,
|
PoolingBytesResponse,
|
||||||
PoolingRequest,
|
PoolingRequest,
|
||||||
@ -96,6 +98,7 @@ from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
|
|||||||
from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
|
from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
|
||||||
from vllm.entrypoints.openai.serving_score import ServingScores
|
from vllm.entrypoints.openai.serving_score import ServingScores
|
||||||
from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
|
from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
|
||||||
|
from vllm.entrypoints.openai.serving_tokens import ServingTokens
|
||||||
from vllm.entrypoints.openai.serving_transcription import (
|
from vllm.entrypoints.openai.serving_transcription import (
|
||||||
OpenAIServingTranscription,
|
OpenAIServingTranscription,
|
||||||
OpenAIServingTranslation,
|
OpenAIServingTranslation,
|
||||||
@ -357,6 +360,10 @@ def engine_client(request: Request) -> EngineClient:
|
|||||||
return request.app.state.engine_client
|
return request.app.state.engine_client
|
||||||
|
|
||||||
|
|
||||||
|
def generate_tokens(request: Request) -> ServingTokens | None:
    """Return the ``serving_tokens`` handler stored on the app state."""
    app_state = request.app.state
    return app_state.serving_tokens
|
||||||
|
|
||||||
|
|
||||||
@router.get("/health", response_class=Response)
|
@router.get("/health", response_class=Response)
|
||||||
async def health(raw_request: Request) -> Response:
|
async def health(raw_request: Request) -> Response:
|
||||||
"""Health check."""
|
"""Health check."""
|
||||||
@ -1228,6 +1235,41 @@ INVOCATION_VALIDATORS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/inference/v1/generate",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def generate(request: GenerateRequest, raw_request: Request):
    # Tokens-in/tokens-out endpoint: hands the request to the ServingTokens
    # handler and maps its result onto an HTTP response.
    handler = generate_tokens(raw_request)
    if handler is None:
        return base(raw_request).create_error_response(
            message="The model does not support generate tokens API"
        )

    try:
        result = await handler.serve_tokens(request, raw_request)
    except Exception as exc:
        # Anything raised by the handler is reported as a 500.
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(exc)
        ) from exc

    if isinstance(result, ErrorResponse):
        # Propagate the status code carried by the error payload.
        return JSONResponse(
            content=result.model_dump(), status_code=result.error.code
        )
    if isinstance(result, GenerateResponse):
        return JSONResponse(content=result.model_dump())

    # Any other result is treated as an event stream.
    return StreamingResponse(content=result, media_type="text/event-stream")
|
||||||
|
|
||||||
|
|
||||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"Torch Profiler is enabled in the API server. This should ONLY be "
|
"Torch Profiler is enabled in the API server. This should ONLY be "
|
||||||
@ -1629,6 +1671,31 @@ def build_app(args: Namespace) -> FastAPI:
|
|||||||
)
|
)
|
||||||
|
|
||||||
app = sagemaker_standards.bootstrap(app)
|
app = sagemaker_standards.bootstrap(app)
|
||||||
|
# Optional endpoints
|
||||||
|
if args.tokens_only:
    # The event loop only keeps weak references to tasks, so keep strong
    # references here until each background abort has finished.
    abort_tasks: set[asyncio.Task] = set()

    @app.post("/abort_requests")
    async def abort_requests(raw_request: Request):
        """
        Abort one or more requests. To be used in a
        Disaggregated Everything setup.
        """
        try:
            body = await raw_request.json()
        except json.JSONDecodeError as e:
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST.value,
                detail=f"JSON decode error: {e}",
            ) from e
        # Valid JSON is not necessarily an object (it may be a list, string,
        # ...); treat anything else as a missing field rather than a 500.
        request_ids = body.get("request_ids") if isinstance(body, dict) else None
        if request_ids is None:
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST.value,
                detail="Missing 'request_ids' in request body",
            )
        # Abort requests in background
        task = asyncio.create_task(engine_client(raw_request).abort(request_ids))
        abort_tasks.add(task)
        task.add_done_callback(abort_tasks.discard)
        return Response(status_code=200)
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
@ -1851,6 +1918,20 @@ async def init_app_state(
|
|||||||
if "generate" in supported_tasks
|
if "generate" in supported_tasks
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
state.serving_tokens = (
|
||||||
|
ServingTokens(
|
||||||
|
engine_client,
|
||||||
|
state.openai_serving_models,
|
||||||
|
request_logger=request_logger,
|
||||||
|
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||||
|
log_error_stack=args.log_error_stack,
|
||||||
|
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||||
|
enable_log_outputs=args.enable_log_outputs,
|
||||||
|
force_no_detokenize=args.tokens_only,
|
||||||
|
)
|
||||||
|
if "generate" in supported_tasks
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
state.enable_server_load_tracking = args.enable_server_load_tracking
|
state.enable_server_load_tracking = args.enable_server_load_tracking
|
||||||
state.server_load_metrics = 0
|
state.server_load_metrics = 0
|
||||||
|
|||||||
@ -189,6 +189,11 @@ class FrontendArgs:
|
|||||||
Helps mitigate header abuse. Default: 256."""
|
Helps mitigate header abuse. Default: 256."""
|
||||||
log_error_stack: bool = envs.VLLM_SERVER_DEV_MODE
|
log_error_stack: bool = envs.VLLM_SERVER_DEV_MODE
|
||||||
"""If set to True, log the stack trace of error responses"""
|
"""If set to True, log the stack trace of error responses"""
|
||||||
|
tokens_only: bool = False
|
||||||
|
"""
|
||||||
|
If set to True, only enable the Tokens In<>Out endpoint.
|
||||||
|
This is intended for use in a Disaggregated Everything setup.
|
||||||
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||||
|
|||||||
@ -3220,3 +3220,80 @@ class TranslationResponseVerbose(OpenAIBaseModel):
|
|||||||
|
|
||||||
words: list[TranslationWord] | None = None
|
words: list[TranslationWord] | None = None
|
||||||
"""Extracted words and their corresponding timestamps."""
|
"""Extracted words and their corresponding timestamps."""
|
||||||
|
|
||||||
|
|
||||||
|
####### Tokens IN <> Tokens OUT #######
|
||||||
|
class GenerateRequest(BaseModel):
    # Request body of the tokens-in/tokens-out generate endpoint.

    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )

    token_ids: list[int]
    """The token ids to generate text from."""

    # features: MultiModalFeatureSpec
    # TODO (NickLucche): implement once Renderer work is completed
    features: str | None = None
    """The processed MM inputs for the model."""

    sampling_params: SamplingParams
    """The sampling parameters for the model."""

    model: str | None = None

    stream: bool | None = False
    stream_options: StreamOptions | None = None

    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit)."
        ),
    )

    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )

    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateResponseChoice(BaseModel):
    # One generated sequence of a generate response.

    index: int
    logprobs: ChatCompletionLogProbs | None = None

    # per OpenAI spec this is the default
    finish_reason: str | None = "stop"

    # Generated token ids for this choice.
    token_ids: list[int] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateResponse(BaseModel):
    # Non-streaming response body of the tokens-in/tokens-out endpoint.

    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )

    choices: list[GenerateResponseChoice]

    prompt_logprobs: list[dict[int, Logprob] | None] | None = None

    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )
|
||||||
|
|||||||
@ -58,6 +58,8 @@ from vllm.entrypoints.openai.protocol import (
|
|||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
FunctionCall,
|
FunctionCall,
|
||||||
FunctionDefinition,
|
FunctionDefinition,
|
||||||
|
GenerateRequest,
|
||||||
|
GenerateResponse,
|
||||||
IOProcessorRequest,
|
IOProcessorRequest,
|
||||||
PoolingResponse,
|
PoolingResponse,
|
||||||
RerankRequest,
|
RerankRequest,
|
||||||
@ -134,6 +136,7 @@ AnyRequest: TypeAlias = (
|
|||||||
| SpeechToTextRequest
|
| SpeechToTextRequest
|
||||||
| ResponsesRequest
|
| ResponsesRequest
|
||||||
| IOProcessorRequest
|
| IOProcessorRequest
|
||||||
|
| GenerateRequest
|
||||||
)
|
)
|
||||||
|
|
||||||
AnyResponse: TypeAlias = (
|
AnyResponse: TypeAlias = (
|
||||||
@ -145,6 +148,7 @@ AnyResponse: TypeAlias = (
|
|||||||
| PoolingResponse
|
| PoolingResponse
|
||||||
| ClassificationResponse
|
| ClassificationResponse
|
||||||
| ScoreResponse
|
| ScoreResponse
|
||||||
|
| GenerateResponse
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
269
vllm/entrypoints/openai/serving_tokens.py
Normal file
269
vllm/entrypoints/openai/serving_tokens.py
Normal file
@ -0,0 +1,269 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from collections.abc import Sequence as GenericSequence
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
# yapf: disable
|
||||||
|
from vllm.engine.protocol import EngineClient
|
||||||
|
from vllm.entrypoints.logger import RequestLogger
|
||||||
|
from vllm.entrypoints.openai.protocol import (
|
||||||
|
ChatCompletionLogProb,
|
||||||
|
ChatCompletionLogProbs,
|
||||||
|
ChatCompletionLogProbsContent,
|
||||||
|
ErrorResponse,
|
||||||
|
GenerateRequest,
|
||||||
|
GenerateResponse,
|
||||||
|
GenerateResponseChoice,
|
||||||
|
PromptTokenUsageInfo,
|
||||||
|
RequestResponseMetadata,
|
||||||
|
UsageInfo,
|
||||||
|
)
|
||||||
|
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
|
||||||
|
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||||
|
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.logprobs import Logprob
|
||||||
|
from vllm.outputs import RequestOutput
|
||||||
|
from vllm.sampling_params import SamplingParams
|
||||||
|
from vllm.utils.collection_utils import as_list
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ServingTokens(OpenAIServing):
|
||||||
|
"""Provides Tokens IN <> Tokens OUT functionality to vLLM API."""
|
||||||
|
|
||||||
|
def __init__(
    self,
    engine_client: EngineClient,
    models: OpenAIServingModels,
    *,
    request_logger: RequestLogger | None,
    force_no_detokenize: bool = False,
    return_tokens_as_token_ids: bool = False,
    log_error_stack: bool = False,
    enable_prompt_tokens_details: bool = False,
    enable_log_outputs: bool = False,
):
    """Set up the tokens-in/tokens-out serving handler.

    ``engine_client``, ``models``, ``request_logger``,
    ``return_tokens_as_token_ids`` and ``log_error_stack`` are forwarded to
    ``OpenAIServing``; the remaining flags are stored on the instance.
    """
    super().__init__(
        engine_client=engine_client,
        models=models,
        request_logger=request_logger,
        return_tokens_as_token_ids=return_tokens_as_token_ids,
        log_error_stack=log_error_stack,
    )
    self.enable_prompt_tokens_details = enable_prompt_tokens_details
    self.enable_log_outputs = enable_log_outputs
    self.force_no_detokenize = force_no_detokenize

    if self.force_no_detokenize:
        logger.info(
            "Tokens-only mode is enabled, skipping detokenization "
            "step for incoming requests."
        )
|
||||||
|
|
||||||
|
async def serve_tokens(
    self,
    request: GenerateRequest,
    raw_request: Request | None = None,
) -> GenerateResponse | ErrorResponse:
    """Run a tokens-in/tokens-out generation request to completion.

    Args:
        request: Parsed request carrying the prompt token ids and the
            sampling parameters.
        raw_request: Underlying HTTP request, if any. Used to derive the
            request id and trace headers and to attach request metadata.

    Returns:
        The full ``GenerateResponse``, or an ``ErrorResponse`` when the model
        check fails or a ``ValueError`` is raised while scheduling or
        collecting the generation.

    Raises:
        The engine client's ``dead_error`` when the engine has errored.
    """
    error_check_ret = await self._check_model(request)
    if error_check_ret is not None:
        logger.error("Error with model %s", error_check_ret)
        return error_check_ret

    # If the engine is dead, raise the engine's DEAD_ERROR.
    # This is required for the streaming case, where we return a
    # success status before we actually start generating text :).
    if self.engine_client.errored:
        raise self.engine_client.dead_error

    lora_request = self._maybe_get_adapters(
        request, supports_default_mm_loras=True
    )
    model_name = self.models.model_name(lora_request)

    base_id = self._base_request_id(raw_request, request.request_id)
    request_id = f"generate-tokens-{base_id}"

    request_metadata = RequestResponseMetadata(request_id=request_id)
    if raw_request:
        raw_request.state.request_metadata = request_metadata

    # TODO(NickLucche): Change to EngineCoreRequest once Renderer work is
    # completed
    engine_prompt = EngineTokensPrompt(prompt_token_ids=request.token_ids)
    if request.features is not None:
        # The features payload is not forwarded here; only a placeholder
        # entry is set on the prompt.
        engine_prompt["multi_modal_data"] = None

    # `cache_salt` is a declared field of GenerateRequest, so it is always
    # present and only needs a None check.
    if request.cache_salt is not None:
        engine_prompt["cache_salt"] = request.cache_salt

    # Schedule the request and get the result generator.
    result_generator: AsyncGenerator[RequestOutput, None] | None = None
    try:
        sampling_params = request.sampling_params
        if self.force_no_detokenize:
            # Tokens-only mode overrides whatever the caller asked for.
            sampling_params.detokenize = False

        self._log_inputs(
            request_id,
            request.token_ids,
            params=sampling_params,
            lora_request=lora_request,
        )

        trace_headers = (
            None
            if raw_request is None
            else await self._get_trace_headers(raw_request.headers)
        )

        result_generator = self.engine_client.generate(
            engine_prompt,
            sampling_params,
            request_id,
            lora_request=lora_request,
            trace_headers=trace_headers,
            priority=request.priority,
        )
    except ValueError as e:
        return self.create_error_response(str(e))

    # TODO(NickLucche): Implement streaming response

    try:
        assert result_generator is not None
        return await self.serve_tokens_full_generator(
            request, result_generator, request_id, model_name, request_metadata
        )
    except ValueError as e:
        return self.create_error_response(str(e))
|
||||||
|
|
||||||
|
async def serve_tokens_full_generator(
    self,
    request: GenerateRequest,
    result_generator: AsyncGenerator[RequestOutput, None],
    request_id: str,
    model_name: str,
    request_metadata: RequestResponseMetadata,
) -> ErrorResponse | GenerateResponse:
    """Drain ``result_generator`` and build the non-streaming response.

    Only the last item yielded by the generator is used. Each of its
    outputs becomes one ``GenerateResponseChoice`` carrying token ids
    (and logprobs when ``request.sampling_params.logprobs`` is truthy).
    Token usage is computed and also stored on
    ``request_metadata.final_usage_info``.

    Returns:
        A ``GenerateResponse`` on success, or the result of
        ``self.create_error_response`` if iteration is cancelled or
        raises ``ValueError``.
    """
    created_time = int(time.time())
    sampling_params: SamplingParams = request.sampling_params

    # Consume the whole stream; every item overwrites the previous one.
    last_output: RequestOutput | None = None
    try:
        async for item in result_generator:
            last_output = item
    except asyncio.CancelledError:
        return self.create_error_response("Client disconnected")
    except ValueError as err:
        return self.create_error_response(str(err))

    assert last_output is not None

    choices: list[GenerateResponseChoice] = []
    completion_token_count = 0
    for completion in last_output.outputs:
        raw_logprobs = completion.logprobs

        # This is top_logprobs in completions API
        choice_logprobs = None
        if sampling_params.logprobs:
            assert raw_logprobs is not None, "Did not output logprobs"
            choice_logprobs = self._create_tokens_logprobs(
                token_ids=completion.token_ids,
                top_logprobs=raw_logprobs,
                num_output_top_logprobs=sampling_params.logprobs,
            )

        choices.append(
            GenerateResponseChoice(
                index=completion.index,
                logprobs=choice_logprobs,
                # A missing/empty finish reason is reported as "stop".
                finish_reason=completion.finish_reason or "stop",
                token_ids=as_list(completion.token_ids),
            )
        )
        completion_token_count += len(completion.token_ids)

    # Prompt tokens include the encoder prompt when one is present.
    assert last_output.prompt_token_ids is not None
    prompt_token_count = len(last_output.prompt_token_ids)
    encoder_ids = last_output.encoder_prompt_token_ids
    if encoder_ids is not None:
        prompt_token_count += len(encoder_ids)

    usage = UsageInfo(
        prompt_tokens=prompt_token_count,
        completion_tokens=completion_token_count,
        total_tokens=prompt_token_count + completion_token_count,
    )
    if self.enable_prompt_tokens_details and last_output.num_cached_tokens:
        # This info is not available at the /coordinator level
        usage.prompt_tokens_details = PromptTokenUsageInfo(
            cached_tokens=last_output.num_cached_tokens
        )

    request_metadata.final_usage_info = usage

    response = GenerateResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        choices=choices,
        usage=usage,
        prompt_logprobs=clamp_prompt_logprobs(last_output.prompt_logprobs),
        kv_transfer_params=last_output.kv_transfer_params,
    )

    # Log complete response if output logging is enabled
    if self.enable_log_outputs and self.request_logger:
        available = len(last_output.outputs)
        for choice in choices:
            # Skip choices with no matching output or no generated tokens.
            if choice.index >= available:
                continue
            logged_token_ids = last_output.outputs[choice.index].token_ids
            if not logged_token_ids:
                continue

            # Log token_ids only.
            self.request_logger.log_outputs(
                request_id=request_id,
                outputs="",
                output_token_ids=logged_token_ids,
                finish_reason=choice.finish_reason,
                is_streaming=False,
                delta=False,
            )

    return response
|
||||||
|
|
||||||
|
def _create_tokens_logprobs(
    self,
    token_ids: GenericSequence[int],
    top_logprobs: GenericSequence[dict[int, Logprob] | None],
    num_output_top_logprobs: int | None = None,
) -> ChatCompletionLogProbs:
    """Create OpenAI-style logprobs for a sequence of generated token ids.

    Tokens are never detokenized here; each one is rendered as the
    string ``"token_id:<id>"``.

    Args:
        token_ids: Generated token ids, one per step.
        top_logprobs: Per-step mapping of candidate token id to its
            ``Logprob`` (or ``None``), indexed in step with ``token_ids``.
        num_output_top_logprobs: Maximum number of candidates to report
            per step. ``None`` or ``0`` yields an empty candidate list.

    Returns:
        A ``ChatCompletionLogProbs`` with one content entry per token.
        Steps without a logprob for the sampled token get an entry that
        only carries the token string.
    """
    logprobs_content: list[ChatCompletionLogProbsContent] = []

    for position, token_id in enumerate(token_ids):
        token = f"token_id:{token_id}"
        step_top_logprobs = top_logprobs[position]
        if step_top_logprobs is None or step_top_logprobs.get(token_id) is None:
            logprobs_content.append(ChatCompletionLogProbsContent(token=token))
            continue

        sampled = step_top_logprobs[token_id]
        logprobs_content.append(
            ChatCompletionLogProbsContent(
                token=token,
                # Clamp so -inf never reaches the response payload.
                logprob=max(sampled.logprob, -9999.0),
                top_logprobs=[
                    ChatCompletionLogProb(
                        # Label every candidate with its OWN token id, not
                        # the id of the token that was sampled at this step.
                        token=f"token_id:{candidate_id}",
                        logprob=max(candidate.logprob, -9999.0),
                    )
                    for rank, (candidate_id, candidate) in enumerate(
                        step_top_logprobs.items()
                    )
                    if num_output_top_logprobs and rank < num_output_top_logprobs
                ],
            )
        )

    return ChatCompletionLogProbs(content=logprobs_content)
|
||||||
@ -15,6 +15,7 @@ from pydantic.dataclasses import dataclass
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.logits_process import LogitsProcessor
|
from vllm.logits_process import LogitsProcessor
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
|
from vllm.v1.serial_utils import PydanticMsgspecMixin
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -122,6 +123,7 @@ class RequestOutputKind(Enum):
|
|||||||
|
|
||||||
|
|
||||||
class SamplingParams(
|
class SamplingParams(
|
||||||
|
PydanticMsgspecMixin,
|
||||||
msgspec.Struct,
|
msgspec.Struct,
|
||||||
omit_defaults=True, # type: ignore[call-arg]
|
omit_defaults=True, # type: ignore[call-arg]
|
||||||
# required for @cached_property.
|
# required for @cached_property.
|
||||||
|
|||||||
@ -15,6 +15,7 @@ from vllm.pooling_params import PoolingParams
|
|||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.v1.metrics.stats import SchedulerStats
|
from vllm.v1.metrics.stats import SchedulerStats
|
||||||
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
|
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
|
||||||
|
from vllm.v1.serial_utils import UtilityResult
|
||||||
|
|
||||||
# These are possible values of RequestOutput.finish_reason,
|
# These are possible values of RequestOutput.finish_reason,
|
||||||
# so form part of the external API.
|
# so form part of the external API.
|
||||||
@ -131,13 +132,6 @@ class EngineCoreOutput(
|
|||||||
return self.finish_reason is not None
|
return self.finish_reason is not None
|
||||||
|
|
||||||
|
|
||||||
class UtilityResult:
    """Wrapper for special handling when serializing/deserializing.

    Holds an arbitrary value in ``result``; defaults to ``None``.
    """

    def __init__(self, r: Any = None) -> None:
        # The wrapped value is stored untouched.
        self.result = r
|
|
||||||
|
|
||||||
|
|
||||||
class UtilityOutput(
|
class UtilityOutput(
|
||||||
msgspec.Struct,
|
msgspec.Struct,
|
||||||
array_like=True, # type: ignore[call-arg]
|
array_like=True, # type: ignore[call-arg]
|
||||||
|
|||||||
@ -8,7 +8,7 @@ from collections.abc import Callable, Sequence
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from inspect import isclass
|
from inspect import isclass
|
||||||
from types import FunctionType
|
from types import FunctionType
|
||||||
from typing import Any, TypeAlias
|
from typing import Any, TypeAlias, get_type_hints
|
||||||
|
|
||||||
import cloudpickle
|
import cloudpickle
|
||||||
import msgspec
|
import msgspec
|
||||||
@ -16,6 +16,8 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import zmq
|
import zmq
|
||||||
from msgspec import msgpack
|
from msgspec import msgpack
|
||||||
|
from pydantic import GetCoreSchemaHandler
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -32,7 +34,6 @@ from vllm.multimodal.inputs import (
|
|||||||
NestedTensors,
|
NestedTensors,
|
||||||
)
|
)
|
||||||
from vllm.utils.platform_utils import is_pin_memory_available
|
from vllm.utils.platform_utils import is_pin_memory_available
|
||||||
from vllm.v1.engine import UtilityResult
|
|
||||||
from vllm.v1.utils import tensor_data
|
from vllm.v1.utils import tensor_data
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -104,6 +105,13 @@ def _decode_type_info_recursive(
|
|||||||
return convert_fn(type_info, data)
|
return convert_fn(type_info, data)
|
||||||
|
|
||||||
|
|
||||||
|
class UtilityResult:
    """Wrapper for special handling when serializing/deserializing.

    Holds an arbitrary value in ``result``; defaults to ``None``.
    """

    def __init__(self, r: Any = None) -> None:
        # The wrapped value is stored untouched.
        self.result = r
|
||||||
|
|
||||||
|
|
||||||
class MsgpackEncoder:
|
class MsgpackEncoder:
|
||||||
"""Encoder with custom torch tensor and numpy array serialization.
|
"""Encoder with custom torch tensor and numpy array serialization.
|
||||||
|
|
||||||
@ -469,3 +477,56 @@ def run_method(
|
|||||||
else:
|
else:
|
||||||
func = partial(method, obj) # type: ignore
|
func = partial(method, obj) # type: ignore
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class PydanticMsgspecMixin:
    """Mixin that exposes a ``msgspec.Struct`` subclass to Pydantic."""

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        """
        Make msgspec.Struct compatible with Pydantic, respecting defaults.
        Handle JSON=>msgspec.Struct. Used when exposing msgspec.Struct to the
        API as input or in `/docs`. Note this is cached by Pydantic and not
        called on every validation.

        Args:
            source_type: The struct type whose fields are described.
            handler: Pydantic callback producing the schema for a type hint.

        Returns:
            A typed-dict core schema over the struct fields, followed by a
            validator that converts the validated value into ``cls``.
        """
        msgspec_fields = {f.name: f for f in msgspec.structs.fields(source_type)}
        type_hints = get_type_hints(source_type)

        # Build the Pydantic typed_dict_field for each msgspec field
        fields = {}
        for name, hint in type_hints.items():
            msgspec_field = msgspec_fields.get(name)
            if msgspec_field is None:
                # get_type_hints() reports every class annotation, including
                # ones that are not struct fields (such as ClassVar
                # attributes); those have no place in the input schema.
                continue

            # typed_dict_field using the handler to get the schema
            field_schema = handler(hint)

            # Add default value to the schema.
            if msgspec_field.default_factory is not msgspec.NODEFAULT:
                field_schema = core_schema.with_default_schema(
                    schema=field_schema,
                    default_factory=msgspec_field.default_factory,
                )
            elif msgspec_field.default is not msgspec.NODEFAULT:
                field_schema = core_schema.with_default_schema(
                    schema=field_schema,
                    default=msgspec_field.default,
                )
            # With no default, Pydantic will treat the field as required.
            fields[name] = core_schema.typed_dict_field(field_schema)

        return core_schema.no_info_after_validator_function(
            cls._validate_msgspec,
            core_schema.typed_dict_schema(fields),
        )

    @classmethod
    def _validate_msgspec(cls, value: Any) -> Any:
        """Validate and convert input to msgspec.Struct instance."""
        if isinstance(value, cls):
            return value
        if isinstance(value, dict):
            return cls(**value)
        return msgspec.convert(value, type=cls)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user