[DisaggEverything] Tokens in<>out /generate endpoint (#24261)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Nicolò Lucchesi 2025-11-14 17:58:01 +01:00 committed by GitHub
parent d54a18a47e
commit 6f1e7f7226
12 changed files with 822 additions and 9 deletions

View File

@@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import httpx
from transformers import AutoTokenizer
GEN_ENDPOINT = "http://localhost:8000/inference/v1/generate"
DUMMY_API_KEY = "empty"
MODEL_NAME = "Qwen/Qwen3-0.6B"
transport = httpx.HTTPTransport()
headers = {"Authorization": f"Bearer {DUMMY_API_KEY}"}
client = httpx.Client(
transport=transport,
base_url=GEN_ENDPOINT,
timeout=600,
headers=headers,
)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "How many countries are in the EU?"},
]
def main(client):
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
token_ids = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
enable_thinking=False,
)
payload = {
"model": MODEL_NAME,
"token_ids": token_ids,
"sampling_params": {"max_tokens": 24, "temperature": 0.2, "detokenize": False},
"stream": False,
}
resp = client.post(GEN_ENDPOINT, json=payload)
resp.raise_for_status()
data = resp.json()
print(data)
print("-" * 50)
print("Token generation results:")
res = tokenizer.decode(data["choices"][0]["token_ids"])
print(res)
print("-" * 50)
if __name__ == "__main__":
main(client)
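Because the response carries raw token ids when "detokenize" is disabled, stop-string handling has to happen on the caller/coordinator side after decoding. Below is a minimal sketch of that step, reusing the same check_stop_strings helper the tests later in this commit import; apply_stop_strings is a hypothetical name and the stop strings are illustrative, not part of the example above.

# Sketch: coordinator-side stop-string handling for raw token output (assumed helper).
from vllm.v1.engine.detokenizer import check_stop_strings

def apply_stop_strings(decoded_text: str, stop: list[str]) -> str:
    # check_stop_strings returns (matched_stop_string, truncate_to) when a stop
    # string is found, else None; called positionally as in the tests below.
    match = check_stop_strings(decoded_text, len(decoded_text), stop, False)
    if match is not None:
        _stop_str, truncate_to = match
        decoded_text = decoded_text[:truncate_to]
    return decoded_text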

View File

@@ -10,3 +10,7 @@ mkdocs-minify-plugin
regex
ruff
pydantic
# For generating argparse docs.
# Adding requirements here should only be used as a last resort.
msgspec # Needed for multiple inheritance involving msgspec.Struct

View File

@@ -0,0 +1,262 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import httpx
import pytest
import pytest_asyncio
from transformers import AutoTokenizer
from vllm.config import ModelConfig
from vllm.v1.engine.detokenizer import check_stop_strings
from ...utils import RemoteOpenAIServer
MODEL_NAME = "Qwen/Qwen3-0.6B"
GEN_ENDPOINT = "/inference/v1/generate"
def get_vocab_size(model_name):
config = ModelConfig(
model=model_name,
seed=0,
dtype="bfloat16",
)
return config.get_vocab_size()
@pytest.fixture(scope="module")
def tokenizer():
return AutoTokenizer.from_pretrained(MODEL_NAME)
@pytest.fixture(scope="module")
def messages():
return [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "How many countries are in the EU?"},
]
@pytest.fixture(scope="module")
def server(request):
args = [
"--dtype",
"bfloat16",
"--max-model-len",
"1024",
"--enforce-eager",
]
extra_args = getattr(request, "param", None)
if extra_args is not None:
args = args + (
list(extra_args)
if isinstance(extra_args, (list, tuple))
else [str(extra_args)]
)
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client(server: RemoteOpenAIServer):
transport = httpx.AsyncHTTPTransport(uds=server.uds) if server.uds else None
headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"}
async with httpx.AsyncClient(
transport=transport,
base_url=server.url_root,
timeout=600,
headers=headers,
) as c:
yield c
@pytest.mark.asyncio
async def test_generate_endpoint(client):
payload = {
"model": MODEL_NAME,
"token_ids": [1, 2, 3],
"sampling_params": {"max_tokens": 5},
"stream": False,
}
resp = await client.post(GEN_ENDPOINT, json=payload)
resp.raise_for_status()
data = resp.json()
assert "choices" in data
@pytest.mark.asyncio
async def test_same_response_as_chat_completions(client, tokenizer, messages):
token_ids = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
enable_thinking=False, # default with Qwen3
)
for ignore_eos in [True, False]:
payload = {
"model": MODEL_NAME,
"token_ids": token_ids,
"sampling_params": {
"max_tokens": 24,
"temperature": 0.0,
# NOTE coordinator will set this to skip detokenization
"detokenize": False,
"ignore_eos": ignore_eos,
},
"stream": False,
}
generate_resp = await client.post(GEN_ENDPOINT, json=payload)
generate_data = generate_resp.json()
generate_res = tokenizer.decode(
generate_data["choices"][0]["token_ids"], skip_special_tokens=True
)
payload = {
"model": MODEL_NAME,
"messages": messages,
"max_tokens": 24,
"temperature": 0.0,
"stream": False,
"ignore_eos": ignore_eos,
"chat_template_kwargs": dict(enable_thinking=False),
}
completions_resp = await client.post("/v1/chat/completions", json=payload)
completions_data = completions_resp.json()
completions_res = completions_data["choices"][0]["message"]["content"]
assert generate_res == completions_res
@pytest.mark.asyncio
async def test_stop_string_workflow(client, tokenizer, messages):
token_ids = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
enable_thinking=False, # default with Qwen3
)
payload = {
"model": MODEL_NAME,
"token_ids": token_ids,
"sampling_params": {
"max_tokens": 24,
"temperature": 0.0,
"detokenize": False,
# stop strings are only supported when detokenize is True.
"stop": ["27 member"],
},
# TODO stream test is much more interesting
"stream": False,
}
with pytest.raises(httpx.HTTPStatusError):
generate_resp = await client.post(GEN_ENDPOINT, json=payload)
generate_resp.raise_for_status()
payload["sampling_params"]["stop"] = None
generate_resp = await client.post(
GEN_ENDPOINT, json=payload, headers={"X-Request-Id": "42"}
)
generate_data = generate_resp.json()
generate_res = tokenizer.decode(
generate_data["choices"][0]["token_ids"], skip_special_tokens=True
)
# NOTE This is the responsibility of the coordinator
# stop_checker = StopChecker(
# max_model_len=1024, get_tokenizer_for_seq=lambda _: tokenizer
# )
stop_str, truncate_to = check_stop_strings(
generate_res, len(generate_res), ["27 member"], False
)
assert stop_str == "27 member"
# abort request that hit stop string (requires tokens-only mode)
# res = await client.post("/abort_requests", json={"request_ids": ["generate-tokens-42"]}) # noqa: E501
# res.raise_for_status()
generate_res = generate_res[:truncate_to]
# Get stop_str response from chat completions
payload = {
"model": MODEL_NAME,
"messages": messages,
"max_tokens": 24,
"temperature": 0.0,
"stream": False,
"stop": ["27 member"],
"chat_template_kwargs": dict(enable_thinking=False),
}
completions_resp = await client.post("/v1/chat/completions", json=payload)
completions_data = completions_resp.json()
completions_res = completions_data["choices"][0]["message"]["content"]
assert generate_res == completions_res
@pytest.mark.asyncio
@pytest.mark.parametrize(
"server",
[
[
"--enable-lora",
"--lora-modules",
"Alice=charent/self_cognition_Alice",
"Bob=charent/self_cognition_Bob",
"--max-lora-rank",
"64",
"--max-cpu-loras",
"2",
]
],
indirect=True,
)
async def test_generate_with_lora_adapter(client, tokenizer, messages):
# Verify adapters are listed
models_resp = await client.get("/v1/models")
models_resp.raise_for_status()
models = {m["id"] for m in models_resp.json().get("data", [])}
assert {"Alice", "Bob"}.issubset(models)
# Generate using a LoRA adapter by specifying its name as the model
payload = {
"model": "Alice",
"token_ids": [1, 2, 3],
"sampling_params": {"max_tokens": 5},
"stream": False,
}
resp = await client.post(GEN_ENDPOINT, json=payload)
resp.raise_for_status()
data = resp.json()
assert "choices" in data
token_ids = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
enable_thinking=False, # default with Qwen3
)
payload = {
"model": "Alice",
"token_ids": token_ids,
"sampling_params": {
"max_tokens": 24,
"temperature": 0.0,
"detokenize": False,
},
"stream": False,
}
generate_resp = await client.post(GEN_ENDPOINT, json=payload)
generate_data = generate_resp.json()
generate_res = tokenizer.decode(
generate_data["choices"][0]["token_ids"], skip_special_tokens=True
)
payload = {
"model": "Alice",
"messages": messages,
"max_tokens": 24,
"temperature": 0.0,
"stream": False,
"chat_template_kwargs": dict(enable_thinking=False),
}
completions_resp = await client.post("/v1/chat/completions", json=payload)
completions_data = completions_resp.json()
completions_res = completions_data["choices"][0]["message"]["content"]
assert generate_res == completions_res

View File

@@ -566,6 +566,7 @@ class EngineArgs:
kv_offloading_backend: KVOffloadingBackend | None = (
CacheConfig.kv_offloading_backend
)
tokens_only: bool = False
def __post_init__(self):
# support `EngineArgs(compilation_config={...})`
@@ -1495,6 +1496,10 @@ class EngineArgs:
else ParallelConfig.data_parallel_rpc_port
)
if self.tokens_only and not model_config.skip_tokenizer_init:
model_config.skip_tokenizer_init = True
logger.info("Skipping tokenizer initialization for tokens-only mode.")
# Forward the deprecated CLI args to the EPLB config.
if self.num_redundant_experts is not None:
self.eplb_config.num_redundant_experts = self.num_redundant_experts

View File

@@ -65,6 +65,8 @@ from vllm.entrypoints.openai.protocol import (
EmbeddingResponse,
ErrorInfo,
ErrorResponse,
GenerateRequest,
GenerateResponse,
IOProcessorResponse,
PoolingBytesResponse,
PoolingRequest,
@@ -96,6 +98,7 @@ from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
from vllm.entrypoints.openai.serving_score import ServingScores
from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
from vllm.entrypoints.openai.serving_tokens import ServingTokens
from vllm.entrypoints.openai.serving_transcription import (
OpenAIServingTranscription,
OpenAIServingTranslation,
@@ -357,6 +360,10 @@ def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
def generate_tokens(request: Request) -> ServingTokens | None:
return request.app.state.serving_tokens
@router.get("/health", response_class=Response)
async def health(raw_request: Request) -> Response:
"""Health check."""
@@ -1228,6 +1235,41 @@ INVOCATION_VALIDATORS = [
]
@router.post(
"/inference/v1/generate",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
@with_cancellation
@load_aware_call
async def generate(request: GenerateRequest, raw_request: Request):
handler = generate_tokens(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support generate tokens API"
)
try:
generator = await handler.serve_tokens(request, raw_request)
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, GenerateResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
if envs.VLLM_TORCH_PROFILER_DIR:
logger.warning_once(
"Torch Profiler is enabled in the API server. This should ONLY be "
@@ -1629,6 +1671,31 @@ def build_app(args: Namespace) -> FastAPI:
)
app = sagemaker_standards.bootstrap(app)
# Optional endpoints
if args.tokens_only:
@app.post("/abort_requests")
async def abort_requests(raw_request: Request):
"""
Abort one or more requests. To be used in a
Disaggregated Everything setup.
"""
try:
body = await raw_request.json()
except json.JSONDecodeError as e:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
detail=f"JSON decode error: {e}",
) from e
request_ids = body.get("request_ids")
if request_ids is None:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
detail="Missing 'request_ids' in request body",
)
# Abort requests in background
asyncio.create_task(engine_client(raw_request).abort(request_ids))
return Response(status_code=200)
return app
@@ -1851,6 +1918,20 @@ async def init_app_state(
if "generate" in supported_tasks
else None
)
state.serving_tokens = (
ServingTokens(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
log_error_stack=args.log_error_stack,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_log_outputs=args.enable_log_outputs,
force_no_detokenize=args.tokens_only,
)
if "generate" in supported_tasks
else None
)
state.enable_server_load_tracking = args.enable_server_load_tracking
state.server_load_metrics = 0
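For reference, the optional /abort_requests endpoint registered above takes only a JSON body with a "request_ids" list and returns 200 immediately while the abort runs in the background. A hedged client-side sketch follows; the abort helper name is illustrative, and the "generate-tokens-<id>" prefix mirrors the request id format used by ServingTokens later in this commit.

# Sketch: aborting an in-flight tokens-only request from a coordinator
# (assumes a server started with --tokens-only; helper name is hypothetical).
import httpx

def abort_generate_request(client: httpx.Client, request_id: str) -> None:
    # The server expects {"request_ids": [...]} and schedules the abort asynchronously.
    resp = client.post(
        "/abort_requests",
        json={"request_ids": [f"generate-tokens-{request_id}"]},
    )
    resp.raise_for_status()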

View File

@@ -189,6 +189,11 @@ class FrontendArgs:
Helps mitigate header abuse. Default: 256."""
log_error_stack: bool = envs.VLLM_SERVER_DEV_MODE
"""If set to True, log the stack trace of error responses"""
tokens_only: bool = False
"""
If set to True, only enable the Tokens In<>Out endpoint.
This is intended for use in a Disaggregated Everything setup.
"""
@staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:

View File

@@ -3220,3 +3220,80 @@ class TranslationResponseVerbose(OpenAIBaseModel):
words: list[TranslationWord] | None = None
"""Extracted words and their corresponding timestamps."""
####### Tokens IN <> Tokens OUT #######
class GenerateRequest(BaseModel):
request_id: str = Field(
default_factory=lambda: f"{random_uuid()}",
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)
token_ids: list[int]
"""The token ids to generate text from."""
# features: MultiModalFeatureSpec
# TODO (NickLucche): implement once Renderer work is completed
features: str | None = None
"""The processed MM inputs for the model."""
sampling_params: SamplingParams
"""The sampling parameters for the model."""
model: str | None = None
stream: bool | None = False
stream_options: StreamOptions | None = None
cache_salt: str | None = Field(
default=None,
description=(
"If specified, the prefix cache will be salted with the provided "
"string to prevent an attacker to guess prompts in multi-user "
"environments. The salt should be random, protected from "
"access by 3rd parties, and long enough to be "
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
"to 256 bit)."
),
)
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
),
)
kv_transfer_params: dict[str, Any] | None = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.",
)
class GenerateResponseChoice(BaseModel):
index: int
logprobs: ChatCompletionLogProbs | None = None
# per OpenAI spec this is the default
finish_reason: str | None = "stop"
token_ids: list[int] | None = None
class GenerateResponse(BaseModel):
request_id: str = Field(
default_factory=lambda: f"{random_uuid()}",
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)
choices: list[GenerateResponseChoice]
prompt_logprobs: list[dict[int, Logprob] | None] | None = None
kv_transfer_params: dict[str, Any] | None = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.",
)
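Put together, a request/response pair for this schema looks roughly like the following. All values are illustrative; "sampling_params" accepts the same fields as SamplingParams, and only the fields declared on GenerateResponse are shown.

# Illustrative payload and response shapes for /inference/v1/generate (non-streaming).
request_body = {
    "model": "Qwen/Qwen3-0.6B",
    "token_ids": [151644, 872, 198],  # prompt token ids (example values)
    "sampling_params": {"max_tokens": 8, "temperature": 0.0, "detokenize": False},
    "stream": False,
}
response_body = {
    "request_id": "generate-tokens-42",  # example id
    "choices": [
        {"index": 0, "finish_reason": "stop", "token_ids": [785, 7513, 702], "logprobs": None}
    ],
    "prompt_logprobs": None,
    "kv_transfer_params": None,
}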

View File

@@ -58,6 +58,8 @@ from vllm.entrypoints.openai.protocol import (
ErrorResponse,
FunctionCall,
FunctionDefinition,
GenerateRequest,
GenerateResponse,
IOProcessorRequest,
PoolingResponse,
RerankRequest,
@@ -134,6 +136,7 @@ AnyRequest: TypeAlias = (
| SpeechToTextRequest
| ResponsesRequest
| IOProcessorRequest
| GenerateRequest
)
AnyResponse: TypeAlias = (
@@ -145,6 +148,7 @@ AnyResponse: TypeAlias = (
| PoolingResponse
| ClassificationResponse
| ScoreResponse
| GenerateResponse
)

View File

@@ -0,0 +1,269 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
from collections.abc import AsyncGenerator
from collections.abc import Sequence as GenericSequence
from fastapi import Request
# yapf: disable
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProb,
ChatCompletionLogProbs,
ChatCompletionLogProbsContent,
ErrorResponse,
GenerateRequest,
GenerateResponse,
GenerateResponseChoice,
PromptTokenUsageInfo,
RequestResponseMetadata,
UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.utils.collection_utils import as_list
logger = init_logger(__name__)
class ServingTokens(OpenAIServing):
"""Provides Tokens IN <> Tokens OUT functionality to vLLM API."""
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
force_no_detokenize: bool = False,
return_tokens_as_token_ids: bool = False,
log_error_stack: bool = False,
enable_prompt_tokens_details: bool = False,
enable_log_outputs: bool = False,
):
super().__init__(engine_client=engine_client,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
log_error_stack=log_error_stack)
self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.enable_log_outputs = enable_log_outputs
self.force_no_detokenize = force_no_detokenize
if force_no_detokenize:
logger.info("Tokens-only mode is enabled, skipping detokenization "
"step for incoming requests.")
async def serve_tokens(
self,
request: GenerateRequest,
raw_request: Request | None = None
) -> GenerateResponse | ErrorResponse:
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
logger.error("Error with model %s", error_check_ret)
return error_check_ret
# If the engine is dead, raise the engine's DEAD_ERROR.
# This is required for the streaming case, where we return a
# success status before we actually start generating text :).
if self.engine_client.errored:
raise self.engine_client.dead_error
lora_request = None
lora_request = self._maybe_get_adapters(request,
supports_default_mm_loras=True)
model_name = self.models.model_name(lora_request)
request_id = "generate-tokens-" \
f"{self._base_request_id(raw_request, request.request_id)}"
request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata
# TODO(NickLucche): Change to EngineCoreRequest once Renderer work is
# completed
engine_prompt = EngineTokensPrompt(prompt_token_ids=request.token_ids)
if request.features is not None:
engine_prompt["multi_modal_data"] = None
if hasattr(request, "cache_salt") and request.cache_salt is not None:
engine_prompt["cache_salt"] = request.cache_salt
# Schedule the request and get the result generator.
result_generator: AsyncGenerator[RequestOutput, None] | None = None
try:
sampling_params = request.sampling_params
if self.force_no_detokenize:
sampling_params.detokenize = False
self._log_inputs(request_id,
request.token_ids,
params=sampling_params,
lora_request=lora_request)
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
result_generator = self.engine_client.generate(
engine_prompt,
sampling_params,
request_id,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
except ValueError as e:
return self.create_error_response(str(e))
# TODO(NickLucche): Implement streaming response
try:
assert result_generator is not None
return await self.serve_tokens_full_generator(
request, result_generator, request_id, model_name,
request_metadata)
except ValueError as e:
return self.create_error_response(str(e))
async def serve_tokens_full_generator(
self,
request: GenerateRequest,
result_generator: AsyncGenerator[RequestOutput, None],
request_id: str,
model_name: str,
request_metadata: RequestResponseMetadata,
) -> ErrorResponse | GenerateResponse:
created_time = int(time.time())
final_res: RequestOutput | None = None
sampling_params: SamplingParams = request.sampling_params
try:
async for res in result_generator:
final_res = res
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
return self.create_error_response(str(e))
assert final_res is not None
choices: list[GenerateResponseChoice] = []
num_generated_tokens = 0
for output in final_res.outputs:
token_ids = output.token_ids
out_logprobs = output.logprobs
# This is top_logprobs in completions API
if sampling_params.logprobs:
assert out_logprobs is not None, "Did not output logprobs"
logprobs = self._create_tokens_logprobs(
token_ids=token_ids,
top_logprobs=out_logprobs,
num_output_top_logprobs=sampling_params.logprobs,
)
else:
logprobs = None
choice_data = GenerateResponseChoice(
index=output.index,
logprobs=logprobs,
finish_reason=output.finish_reason
if output.finish_reason else "stop",
token_ids=as_list(output.token_ids))
choices.append(choice_data)
num_generated_tokens += len(output.token_ids)
assert final_res.prompt_token_ids is not None
num_prompt_tokens = len(final_res.prompt_token_ids)
if final_res.encoder_prompt_token_ids is not None:
num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
usage = UsageInfo(prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens +
num_generated_tokens)
if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
# This info is not available at the /coordinator level
usage.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=final_res.num_cached_tokens)
request_metadata.final_usage_info = usage
response = GenerateResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
kv_transfer_params=final_res.kv_transfer_params,
)
# Log complete response if output logging is enabled
if self.enable_log_outputs and self.request_logger:
for choice in choices:
# Get the corresponding output token IDs
output_token_ids = None
if choice.index < len(final_res.outputs):
output_token_ids = final_res.outputs[
choice.index].token_ids
if output_token_ids:
# Log token_ids only.
self.request_logger.log_outputs(
request_id=request_id,
outputs="",
output_token_ids=output_token_ids,
finish_reason=choice.finish_reason,
is_streaming=False,
delta=False,
)
return response
def _create_tokens_logprobs(
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[dict[int, Logprob] | None],
num_output_top_logprobs: int | None = None,
) -> ChatCompletionLogProbs:
"""Create OpenAI-style logprobs."""
logprobs_content: list[ChatCompletionLogProbsContent] = []
for i, token_id in enumerate(token_ids):
token = f"token_id:{token_id}"
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None or step_top_logprobs.get(
token_id) is None:
logprobs_content.append(
ChatCompletionLogProbsContent(token=token, ))
else:
step_token = step_top_logprobs[token_id]
logprobs_content.append(
ChatCompletionLogProbsContent(
token=token,
logprob=max(step_token.logprob, -9999.0),
top_logprobs=[
ChatCompletionLogProb(
token=token,
logprob=max(p[1].logprob, -9999.0),
) for i, p in enumerate(step_top_logprobs.items())
if num_output_top_logprobs
and i < num_output_top_logprobs
]))
return ChatCompletionLogProbs(content=logprobs_content)

View File

@@ -15,6 +15,7 @@ from pydantic.dataclasses import dataclass
from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.serial_utils import PydanticMsgspecMixin
logger = init_logger(__name__)
@@ -122,6 +123,7 @@ class RequestOutputKind(Enum):
class SamplingParams(
PydanticMsgspecMixin,
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
# required for @cached_property.

View File

@@ -15,6 +15,7 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
from vllm.v1.serial_utils import UtilityResult
# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
@@ -131,13 +132,6 @@ class EngineCoreOutput(
return self.finish_reason is not None
class UtilityResult:
"""Wrapper for special handling when serializing/deserializing."""
def __init__(self, r: Any = None):
self.result = r
class UtilityOutput(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]

View File

@@ -8,7 +8,7 @@ from collections.abc import Callable, Sequence
from functools import partial
from inspect import isclass
from types import FunctionType
from typing import Any, TypeAlias, get_type_hints
import cloudpickle
import msgspec
@@ -16,6 +16,8 @@ import numpy as np
import torch
import zmq
from msgspec import msgpack
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema
from vllm import envs
from vllm.logger import init_logger
@@ -32,7 +34,6 @@ from vllm.multimodal.inputs import (
NestedTensors,
)
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.engine import UtilityResult
from vllm.v1.utils import tensor_data
logger = init_logger(__name__)
@@ -104,6 +105,13 @@ def _decode_type_info_recursive(
return convert_fn(type_info, data)
class UtilityResult:
"""Wrapper for special handling when serializing/deserializing."""
def __init__(self, r: Any = None):
self.result = r
class MsgpackEncoder:
"""Encoder with custom torch tensor and numpy array serialization.
@@ -469,3 +477,56 @@ def run_method(
else:
func = partial(method, obj) # type: ignore
return func(*args, **kwargs)
class PydanticMsgspecMixin:
@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
"""
Make msgspec.Struct compatible with Pydantic, respecting defaults.
Handle JSON=>msgspec.Struct. Used when exposing msgspec.Struct to the
API as input or in `/docs`. Note this is cached by Pydantic and not
called on every validation.
"""
msgspec_fields = {f.name: f for f in msgspec.structs.fields(source_type)}
type_hints = get_type_hints(source_type)
# Build the Pydantic typed_dict_field for each msgspec field
fields = {}
for name, hint in type_hints.items():
msgspec_field = msgspec_fields[name]
# typed_dict_field using the handler to get the schema
field_schema = handler(hint)
# Add default value to the schema.
if msgspec_field.default_factory is not msgspec.NODEFAULT:
wrapped_schema = core_schema.with_default_schema(
schema=field_schema,
default_factory=msgspec_field.default_factory,
)
fields[name] = core_schema.typed_dict_field(wrapped_schema)
elif msgspec_field.default is not msgspec.NODEFAULT:
wrapped_schema = core_schema.with_default_schema(
schema=field_schema,
default=msgspec_field.default,
)
fields[name] = core_schema.typed_dict_field(wrapped_schema)
else:
# No default, so Pydantic will treat it as required
fields[name] = core_schema.typed_dict_field(field_schema)
return core_schema.no_info_after_validator_function(
cls._validate_msgspec,
core_schema.typed_dict_schema(fields),
)
@classmethod
def _validate_msgspec(cls, value: Any) -> Any:
"""Validate and convert input to msgspec.Struct instance."""
if isinstance(value, cls):
return value
if isinstance(value, dict):
return cls(**value)
return msgspec.convert(value, type=cls)
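To see what the mixin buys, here is a small usage sketch: with SamplingParams now inheriting PydanticMsgspecMixin, a Pydantic model (like GenerateRequest above) can validate a plain dict into the msgspec struct while msgspec-level defaults are preserved. TinyRequest and the field values are illustrative, not part of this commit.

# Sketch: Pydantic validating a msgspec.Struct field via PydanticMsgspecMixin.
from pydantic import BaseModel
from vllm.sampling_params import SamplingParams

class TinyRequest(BaseModel):
    token_ids: list[int]
    sampling_params: SamplingParams  # msgspec.Struct, accepted thanks to the mixin

req = TinyRequest.model_validate(
    {"token_ids": [1, 2, 3], "sampling_params": {"max_tokens": 5, "detokenize": False}}
)
assert isinstance(req.sampling_params, SamplingParams)
assert req.sampling_params.max_tokens == 5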