diff --git a/examples/online_serving/token_generation_client.py b/examples/online_serving/token_generation_client.py new file mode 100644 index 0000000000000..88ee43c5d9cdf --- /dev/null +++ b/examples/online_serving/token_generation_client.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import httpx +from transformers import AutoTokenizer + +GEN_ENDPOINT = "http://localhost:8000/inference/v1/generate" +DUMMY_API_KEY = "empty" +MODEL_NAME = "Qwen/Qwen3-0.6B" + +transport = httpx.HTTPTransport() +headers = {"Authorization": f"Bearer {DUMMY_API_KEY}"} +client = httpx.Client( + transport=transport, + base_url=GEN_ENDPOINT, + timeout=600, + headers=headers, +) +messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "How many countries are in the EU?"}, +] + + +def main(client): + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + token_ids = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + enable_thinking=False, + ) + payload = { + "model": MODEL_NAME, + "token_ids": token_ids, + "sampling_params": {"max_tokens": 24, "temperature": 0.2, "detokenize": False}, + "stream": False, + } + resp = client.post(GEN_ENDPOINT, json=payload) + resp.raise_for_status() + data = resp.json() + print(data) + print("-" * 50) + print("Token generation results:") + res = tokenizer.decode(data["choices"][0]["token_ids"]) + print(res) + print("-" * 50) + + +if __name__ == "__main__": + main(client) diff --git a/requirements/docs.txt b/requirements/docs.txt index 0fd6dbe22c512..32e004b2b64ba 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -10,3 +10,7 @@ mkdocs-minify-plugin regex ruff pydantic + +# For generating argparse docs. +# Adding requirements here should only be used as a last resort. 
+msgspec # Need for multiple inheritance involving msgspec.Struct \ No newline at end of file diff --git a/tests/entrypoints/openai/test_serving_tokens.py b/tests/entrypoints/openai/test_serving_tokens.py new file mode 100644 index 0000000000000..62d843e35b86f --- /dev/null +++ b/tests/entrypoints/openai/test_serving_tokens.py @@ -0,0 +1,262 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import httpx +import pytest +import pytest_asyncio +from transformers import AutoTokenizer + +from vllm.config import ModelConfig +from vllm.v1.engine.detokenizer import check_stop_strings + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "Qwen/Qwen3-0.6B" +GEN_ENDPOINT = "/inference/v1/generate" + + +def get_vocab_size(model_name): + config = ModelConfig( + model=model_name, + seed=0, + dtype="bfloat16", + ) + return config.get_vocab_size() + + +@pytest.fixture(scope="module") +def tokenizer(): + return AutoTokenizer.from_pretrained(MODEL_NAME) + + +@pytest.fixture(scope="module") +def messages(): + return [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "How many countries are in the EU?"}, + ] + + +@pytest.fixture(scope="module") +def server(request): + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "1024", + "--enforce-eager", + ] + + extra_args = getattr(request, "param", None) + if extra_args is not None: + args = args + ( + list(extra_args) + if isinstance(extra_args, (list, tuple)) + else [str(extra_args)] + ) + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server: RemoteOpenAIServer): + transport = httpx.AsyncHTTPTransport(uds=server.uds) if server.uds else None + headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"} + async with httpx.AsyncClient( + transport=transport, + base_url=server.url_root, + timeout=600, + headers=headers, + ) as c: + yield c + + +@pytest.mark.asyncio +async def test_generate_endpoint(client): + payload = { + "model": MODEL_NAME, + "token_ids": [1, 2, 3], + "sampling_params": {"max_tokens": 5}, + "stream": False, + } + resp = await client.post(GEN_ENDPOINT, json=payload) + resp.raise_for_status() + data = resp.json() + assert "choices" in data + + +@pytest.mark.asyncio +async def test_same_response_as_chat_completions(client, tokenizer, messages): + token_ids = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + enable_thinking=False, # default with Qwen3 + ) + for ignore_eos in [True, False]: + payload = { + "model": MODEL_NAME, + "token_ids": token_ids, + "sampling_params": { + "max_tokens": 24, + "temperature": 0.0, + # NOTE coordinator will set this to skip detokenization + "detokenize": False, + "ignore_eos": ignore_eos, + }, + "stream": False, + } + generate_resp = await client.post(GEN_ENDPOINT, json=payload) + generate_data = generate_resp.json() + generate_res = tokenizer.decode( + generate_data["choices"][0]["token_ids"], skip_special_tokens=True + ) + + payload = { + "model": MODEL_NAME, + "messages": messages, + "max_tokens": 24, + "temperature": 0.0, + "stream": False, + "ignore_eos": ignore_eos, + "chat_template_kwargs": dict(enable_thinking=False), + } + completions_resp = await client.post("/v1/chat/completions", json=payload) + completions_data = completions_resp.json() + completions_res = completions_data["choices"][0]["message"]["content"] + + assert generate_res == completions_res + + 
+@pytest.mark.asyncio +async def test_stop_string_workflow(client, tokenizer, messages): + token_ids = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + enable_thinking=False, # default with Qwen3 + ) + payload = { + "model": MODEL_NAME, + "token_ids": token_ids, + "sampling_params": { + "max_tokens": 24, + "temperature": 0.0, + "detokenize": False, + # stop strings are only supported when detokenize is True. + "stop": ["27 member"], + }, + # TODO stream test is much more interesting + "stream": False, + } + with pytest.raises(httpx.HTTPStatusError): + generate_resp = await client.post(GEN_ENDPOINT, json=payload) + generate_resp.raise_for_status() + + payload["sampling_params"]["stop"] = None + generate_resp = await client.post( + GEN_ENDPOINT, json=payload, headers={"X-Request-Id": "42"} + ) + generate_data = generate_resp.json() + generate_res = tokenizer.decode( + generate_data["choices"][0]["token_ids"], skip_special_tokens=True + ) + + # NOTE This is under the responsibility of the coordinator + # stop_checker = StopChecker( + # max_model_len=1024, get_tokenizer_for_seq=lambda _: tokenizer + # ) + stop_str, truncate_to = check_stop_strings( + generate_res, len(generate_res), ["27 member"], False + ) + assert stop_str == "27 member" + # abort request that hit stop string (requires tokens-only mode) + # res = await client.post("/abort_requests", json={"request_ids": ["generate-tokens-42"]}) # noqa: E501 + # res.raise_for_status() + generate_res = generate_res[:truncate_to] + + # Get stop_str response from chat completions + payload = { + "model": MODEL_NAME, + "messages": messages, + "max_tokens": 24, + "temperature": 0.0, + "stream": False, + "stop": ["27 member"], + "chat_template_kwargs": dict(enable_thinking=False), + } + completions_resp = await client.post("/v1/chat/completions", json=payload) + completions_data = completions_resp.json() + completions_res = completions_data["choices"][0]["message"]["content"] + assert generate_res == completions_res + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", + [ + [ + "--enable-lora", + "--lora-modules", + "Alice=charent/self_cognition_Alice", + "Bob=charent/self_cognition_Bob", + "--max-lora-rank", + "64", + "--max-cpu-loras", + "2", + ] + ], + indirect=True, +) +async def test_generate_with_lora_adapter(client, tokenizer, messages): + # Verify adapters are listed + models_resp = await client.get("/v1/models") + models_resp.raise_for_status() + models = {m["id"] for m in models_resp.json().get("data", [])} + assert {"Alice", "Bob"}.issubset(models) + + # Generate using a LoRA adapter by specifying its name as the model + payload = { + "model": "Alice", + "token_ids": [1, 2, 3], + "sampling_params": {"max_tokens": 5}, + "stream": False, + } + resp = await client.post(GEN_ENDPOINT, json=payload) + resp.raise_for_status() + data = resp.json() + assert "choices" in data + + token_ids = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + enable_thinking=False, # default with Qwen3 + ) + payload = { + "model": "Alice", + "token_ids": token_ids, + "sampling_params": { + "max_tokens": 24, + "temperature": 0.0, + "detokenize": False, + }, + "stream": False, + } + generate_resp = await client.post(GEN_ENDPOINT, json=payload) + generate_data = generate_resp.json() + generate_res = tokenizer.decode( + generate_data["choices"][0]["token_ids"], skip_special_tokens=True + ) + + payload = { + "model": "Alice", + "messages": messages, + "max_tokens": 24, + "temperature": 0.0, + "stream": False, 
+ "chat_template_kwargs": dict(enable_thinking=False), + } + completions_resp = await client.post("/v1/chat/completions", json=payload) + completions_data = completions_resp.json() + completions_res = completions_data["choices"][0]["message"]["content"] + + assert generate_res == completions_res diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index cacebc530b6ee..999ed780c20bf 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -566,6 +566,7 @@ class EngineArgs: kv_offloading_backend: KVOffloadingBackend | None = ( CacheConfig.kv_offloading_backend ) + tokens_only: bool = False def __post_init__(self): # support `EngineArgs(compilation_config={...})` @@ -1495,6 +1496,10 @@ class EngineArgs: else ParallelConfig.data_parallel_rpc_port ) + if self.tokens_only and not model_config.skip_tokenizer_init: + model_config.skip_tokenizer_init = True + logger.info("Skipping tokenizer initialization for tokens-only mode.") + # Forward the deprecated CLI args to the EPLB config. if self.num_redundant_experts is not None: self.eplb_config.num_redundant_experts = self.num_redundant_experts diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index f30c6ef2cd0a4..3e59af717d95c 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -65,6 +65,8 @@ from vllm.entrypoints.openai.protocol import ( EmbeddingResponse, ErrorInfo, ErrorResponse, + GenerateRequest, + GenerateResponse, IOProcessorResponse, PoolingBytesResponse, PoolingRequest, @@ -96,6 +98,7 @@ from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses from vllm.entrypoints.openai.serving_score import ServingScores from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization +from vllm.entrypoints.openai.serving_tokens import ServingTokens from vllm.entrypoints.openai.serving_transcription import ( OpenAIServingTranscription, OpenAIServingTranslation, @@ -357,6 +360,10 @@ def engine_client(request: Request) -> EngineClient: return request.app.state.engine_client +def generate_tokens(request: Request) -> ServingTokens | None: + return request.app.state.serving_tokens + + @router.get("/health", response_class=Response) async def health(raw_request: Request) -> Response: """Health check.""" @@ -1228,6 +1235,41 @@ INVOCATION_VALIDATORS = [ ] +@router.post( + "/inference/v1/generate", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) +@with_cancellation +@load_aware_call +async def generate(request: GenerateRequest, raw_request: Request): + handler = generate_tokens(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support generate tokens API" + ) + try: + generator = await handler.serve_tokens(request, raw_request) + except Exception as e: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e + if isinstance(generator, ErrorResponse): + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) + + elif isinstance(generator, GenerateResponse): + return JSONResponse(content=generator.model_dump()) + + return 
StreamingResponse(content=generator, media_type="text/event-stream") + + if envs.VLLM_TORCH_PROFILER_DIR: logger.warning_once( "Torch Profiler is enabled in the API server. This should ONLY be " @@ -1629,6 +1671,31 @@ def build_app(args: Namespace) -> FastAPI: ) app = sagemaker_standards.bootstrap(app) + # Optional endpoints + if args.tokens_only: + + @app.post("/abort_requests") + async def abort_requests(raw_request: Request): + """ + Abort one or more requests. To be used in a + Disaggregated Everything setup. + """ + try: + body = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=f"JSON decode error: {e}", + ) from e + request_ids = body.get("request_ids") + if request_ids is None: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail="Missing 'request_ids' in request body", + ) + # Abort requests in background + asyncio.create_task(engine_client(raw_request).abort(request_ids)) + return Response(status_code=200) return app @@ -1851,6 +1918,20 @@ async def init_app_state( if "generate" in supported_tasks else None ) + state.serving_tokens = ( + ServingTokens( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + log_error_stack=args.log_error_stack, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_log_outputs=args.enable_log_outputs, + force_no_detokenize=args.tokens_only, + ) + if "generate" in supported_tasks + else None + ) state.enable_server_load_tracking = args.enable_server_load_tracking state.server_load_metrics = 0 diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 476587c178237..946362ce2ef0a 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -189,6 +189,11 @@ class FrontendArgs: Helps mitigate header abuse. Default: 256.""" log_error_stack: bool = envs.VLLM_SERVER_DEV_MODE """If set to True, log the stack trace of error responses""" + tokens_only: bool = False + """ + If set to True, only enable the Tokens In<>Out endpoint. + This is intended for use in a Disaggregated Everything setup. + """ @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 45584df8b9e26..65bd15ba387b9 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -3220,3 +3220,80 @@ class TranslationResponseVerbose(OpenAIBaseModel): words: list[TranslationWord] | None = None """Extracted words and their corresponding timestamps.""" + + +####### Tokens IN <> Tokens OUT ####### +class GenerateRequest(BaseModel): + request_id: str = Field( + default_factory=lambda: f"{random_uuid()}", + description=( + "The request_id related to this request. If the caller does " + "not set it, a random_uuid will be generated. This id is used " + "through out the inference process and return in response." 
+        ),
+    )
+    token_ids: list[int]
+    """The token ids to generate text from."""
+
+    # features: MultiModalFeatureSpec
+    # TODO (NickLucche): implement once Renderer work is completed
+    features: str | None = None
+    """The processed MM inputs for the model."""
+
+    sampling_params: SamplingParams
+    """The sampling parameters for the model."""
+
+    model: str | None = None
+
+    stream: bool | None = False
+    stream_options: StreamOptions | None = None
+    cache_salt: str | None = Field(
+        default=None,
+        description=(
+            "If specified, the prefix cache will be salted with the provided "
+            "string to prevent an attacker from guessing prompts in multi-user "
+            "environments. The salt should be random, protected from "
+            "access by 3rd parties, and long enough to be "
+            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
+            "to 256 bit)."
+        ),
+    )
+    priority: int = Field(
+        default=0,
+        description=(
+            "The priority of the request (lower means earlier handling; "
+            "default: 0). Any priority other than 0 will raise an error "
+            "if the served model does not use priority scheduling."
+        ),
+    )
+    kv_transfer_params: dict[str, Any] | None = Field(
+        default=None,
+        description="KVTransfer parameters used for disaggregated serving.",
+    )
+
+
+class GenerateResponseChoice(BaseModel):
+    index: int
+    logprobs: ChatCompletionLogProbs | None = None
+    # per OpenAI spec this is the default
+    finish_reason: str | None = "stop"
+    token_ids: list[int] | None = None
+
+
+class GenerateResponse(BaseModel):
+    request_id: str = Field(
+        default_factory=lambda: f"{random_uuid()}",
+        description=(
+            "The request_id related to this request. If the caller does "
+            "not set it, a random_uuid will be generated. This id is used "
+            "throughout the inference process and returned in the response."
+ ), + ) + choices: list[GenerateResponseChoice] + + prompt_logprobs: list[dict[int, Logprob] | None] | None = None + + kv_transfer_params: dict[str, Any] | None = Field( + default=None, + description="KVTransfer parameters used for disaggregated serving.", + ) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 03f10e5a91e64..c50b0c4a23e17 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -58,6 +58,8 @@ from vllm.entrypoints.openai.protocol import ( ErrorResponse, FunctionCall, FunctionDefinition, + GenerateRequest, + GenerateResponse, IOProcessorRequest, PoolingResponse, RerankRequest, @@ -134,6 +136,7 @@ AnyRequest: TypeAlias = ( | SpeechToTextRequest | ResponsesRequest | IOProcessorRequest + | GenerateRequest ) AnyResponse: TypeAlias = ( @@ -145,6 +148,7 @@ AnyResponse: TypeAlias = ( | PoolingResponse | ClassificationResponse | ScoreResponse + | GenerateResponse ) diff --git a/vllm/entrypoints/openai/serving_tokens.py b/vllm/entrypoints/openai/serving_tokens.py new file mode 100644 index 0000000000000..69a526b9b70d2 --- /dev/null +++ b/vllm/entrypoints/openai/serving_tokens.py @@ -0,0 +1,269 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +import time +from collections.abc import AsyncGenerator +from collections.abc import Sequence as GenericSequence + +from fastapi import Request + +# yapf: disable +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import ( + ChatCompletionLogProb, + ChatCompletionLogProbs, + ChatCompletionLogProbsContent, + ErrorResponse, + GenerateRequest, + GenerateResponse, + GenerateResponseChoice, + PromptTokenUsageInfo, + RequestResponseMetadata, + UsageInfo, +) +from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.inputs.data import TokensPrompt as EngineTokensPrompt +from vllm.logger import init_logger +from vllm.logprobs import Logprob +from vllm.outputs import RequestOutput +from vllm.sampling_params import SamplingParams +from vllm.utils.collection_utils import as_list + +logger = init_logger(__name__) + + +class ServingTokens(OpenAIServing): + """Provides Tokens IN <> Tokens OUT functionality to vLLM API.""" + + def __init__( + self, + engine_client: EngineClient, + models: OpenAIServingModels, + *, + request_logger: RequestLogger | None, + force_no_detokenize: bool = False, + return_tokens_as_token_ids: bool = False, + log_error_stack: bool = False, + enable_prompt_tokens_details: bool = False, + enable_log_outputs: bool = False, + ): + super().__init__(engine_client=engine_client, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids, + log_error_stack=log_error_stack) + self.enable_prompt_tokens_details = enable_prompt_tokens_details + self.enable_log_outputs = enable_log_outputs + self.force_no_detokenize = force_no_detokenize + if force_no_detokenize: + logger.info("Tokens-only mode is enabled, skipping detokenization " + "step for incoming requests.") + + async def serve_tokens( + self, + request: GenerateRequest, + raw_request: Request | None = None + ) -> GenerateResponse | ErrorResponse: + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + logger.error("Error with model %s", 
error_check_ret) + return error_check_ret + + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). + if self.engine_client.errored: + raise self.engine_client.dead_error + + lora_request = None + lora_request = self._maybe_get_adapters(request, + supports_default_mm_loras=True) + + model_name = self.models.model_name(lora_request) + + request_id = "generate-tokens-" \ + f"{self._base_request_id(raw_request, request.request_id)}" + + request_metadata = RequestResponseMetadata(request_id=request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + + # TODO(NickLucche): Change to EngineCoreRequest once Renderer work is + # completed + engine_prompt = EngineTokensPrompt(prompt_token_ids=request.token_ids) + if request.features is not None: + engine_prompt["multi_modal_data"] = None + + if hasattr(request, "cache_salt") and request.cache_salt is not None: + engine_prompt["cache_salt"] = request.cache_salt + + # Schedule the request and get the result generator. + result_generator: AsyncGenerator[RequestOutput, None] | None = None + try: + sampling_params = request.sampling_params + if self.force_no_detokenize: + sampling_params.detokenize = False + + self._log_inputs(request_id, + request.token_ids, + params=sampling_params, + lora_request=lora_request) + + trace_headers = (None if raw_request is None else await + self._get_trace_headers(raw_request.headers)) + + result_generator = self.engine_client.generate( + engine_prompt, + sampling_params, + request_id, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + ) + + except ValueError as e: + return self.create_error_response(str(e)) + + # TODO(NickLucche): Implement streaming response + + try: + assert result_generator is not None + return await self.serve_tokens_full_generator( + request, result_generator, request_id, model_name, + request_metadata) + except ValueError as e: + return self.create_error_response(str(e)) + + async def serve_tokens_full_generator( + self, + request: GenerateRequest, + result_generator: AsyncGenerator[RequestOutput, None], + request_id: str, + model_name: str, + request_metadata: RequestResponseMetadata, + ) -> ErrorResponse | GenerateResponse: + + created_time = int(time.time()) + final_res: RequestOutput | None = None + sampling_params: SamplingParams = request.sampling_params + + try: + async for res in result_generator: + final_res = res + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + return self.create_error_response(str(e)) + + assert final_res is not None + + choices: list[GenerateResponseChoice] = [] + num_generated_tokens = 0 + for output in final_res.outputs: + token_ids = output.token_ids + out_logprobs = output.logprobs + + # This is top_logprobs in completions API + if sampling_params.logprobs: + assert out_logprobs is not None, "Did not output logprobs" + logprobs = self._create_tokens_logprobs( + token_ids=token_ids, + top_logprobs=out_logprobs, + num_output_top_logprobs=sampling_params.logprobs, + ) + else: + logprobs = None + + choice_data = GenerateResponseChoice( + index=output.index, + logprobs=logprobs, + finish_reason=output.finish_reason + if output.finish_reason else "stop", + token_ids=as_list(output.token_ids)) + + choices.append(choice_data) + num_generated_tokens += len(output.token_ids) + + assert final_res.prompt_token_ids 
is not None + num_prompt_tokens = len(final_res.prompt_token_ids) + if final_res.encoder_prompt_token_ids is not None: + num_prompt_tokens += len(final_res.encoder_prompt_token_ids) + + usage = UsageInfo(prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + + num_generated_tokens) + if self.enable_prompt_tokens_details and final_res.num_cached_tokens: + # This info is not available at the /coordinator level + usage.prompt_tokens_details = PromptTokenUsageInfo( + cached_tokens=final_res.num_cached_tokens) + + request_metadata.final_usage_info = usage + + response = GenerateResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs), + kv_transfer_params=final_res.kv_transfer_params, + ) + + # Log complete response if output logging is enabled + if self.enable_log_outputs and self.request_logger: + for choice in choices: + # Get the corresponding output token IDs + output_token_ids = None + if choice.index < len(final_res.outputs): + output_token_ids = final_res.outputs[ + choice.index].token_ids + + if output_token_ids: + # Log token_ids only. + self.request_logger.log_outputs( + request_id=request_id, + outputs="", + output_token_ids=output_token_ids, + finish_reason=choice.finish_reason, + is_streaming=False, + delta=False, + ) + + return response + + def _create_tokens_logprobs( + self, + token_ids: GenericSequence[int], + top_logprobs: GenericSequence[dict[int, Logprob] | None], + num_output_top_logprobs: int | None = None, + ) -> ChatCompletionLogProbs: + """Create OpenAI-style logprobs.""" + logprobs_content: list[ChatCompletionLogProbsContent] = [] + + for i, token_id in enumerate(token_ids): + token = f"token_id:{token_id}" + step_top_logprobs = top_logprobs[i] + if step_top_logprobs is None or step_top_logprobs.get( + token_id) is None: + logprobs_content.append( + ChatCompletionLogProbsContent(token=token, )) + else: + step_token = step_top_logprobs[token_id] + + logprobs_content.append( + ChatCompletionLogProbsContent( + token=token, + logprob=max(step_token.logprob, -9999.0), + top_logprobs=[ + ChatCompletionLogProb( + token=token, + logprob=max(p[1].logprob, -9999.0), + ) for i, p in enumerate(step_top_logprobs.items()) + if num_output_top_logprobs + and i < num_output_top_logprobs + ])) + + return ChatCompletionLogProbs(content=logprobs_content) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 4b2a3bc4dbaa6..dd820840410ed 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -15,6 +15,7 @@ from pydantic.dataclasses import dataclass from vllm.logger import init_logger from vllm.logits_process import LogitsProcessor from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.v1.serial_utils import PydanticMsgspecMixin logger = init_logger(__name__) @@ -122,6 +123,7 @@ class RequestOutputKind(Enum): class SamplingParams( + PydanticMsgspecMixin, msgspec.Struct, omit_defaults=True, # type: ignore[call-arg] # required for @cached_property. 
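Reviewer note: since `--tokens-only` forces `detokenize=False`, stop strings cannot be enforced inside the worker (the new test asserts such requests are rejected), so the caller has to detokenize and apply them itself. Below is a minimal sketch of that coordinator-side flow, assuming `check_stop_strings` returns a `(stop_str, truncate_to)` tuple on a hit and `None` otherwise, as the new test's usage suggests; the worker URL and helper names are invented for illustration:

```python
import httpx
from transformers import AutoTokenizer

from vllm.v1.engine.detokenizer import check_stop_strings

WORKER_URL = "http://localhost:8000"  # illustrative: a vLLM worker started with --tokens-only
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")


def truncate_at_stop(token_ids: list[int], stop: list[str]) -> str:
    """Detokenize worker output and apply stop strings on the coordinator side."""
    text = tokenizer.decode(token_ids, skip_special_tokens=True)
    # Last argument False: exclude the stop string from the returned text,
    # mirroring the call in test_stop_string_workflow.
    hit = check_stop_strings(text, len(text), stop, False)
    if hit is None:
        return text
    _stop_str, truncate_to = hit
    return text[:truncate_to]


def abort_on_worker(client: httpx.Client, request_id: str) -> None:
    """In a streaming setup, a coordinator that hits a stop string would also
    abort the still-running worker request via the new /abort_requests route."""
    resp = client.post(
        f"{WORKER_URL}/abort_requests",
        # serve_tokens prefixes engine request ids with "generate-tokens-";
        # request_id comes from the request body or the X-Request-Id header.
        json={"request_ids": [f"generate-tokens-{request_id}"]},
    )
    resp.raise_for_status()
```

The disabled abort call in the new test sketches the same idea for the non-streaming case.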
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 058a4bcaecb58..3f621d77c0241 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -15,6 +15,7 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import LogprobsLists, LogprobsTensors +from vllm.v1.serial_utils import UtilityResult # These are possible values of RequestOutput.finish_reason, # so form part of the external API. @@ -131,13 +132,6 @@ class EngineCoreOutput( return self.finish_reason is not None -class UtilityResult: - """Wrapper for special handling when serializing/deserializing.""" - - def __init__(self, r: Any = None): - self.result = r - - class UtilityOutput( msgspec.Struct, array_like=True, # type: ignore[call-arg] diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index cf0b1a41b50f8..0a6806390451d 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -8,7 +8,7 @@ from collections.abc import Callable, Sequence from functools import partial from inspect import isclass from types import FunctionType -from typing import Any, TypeAlias +from typing import Any, TypeAlias, get_type_hints import cloudpickle import msgspec @@ -16,6 +16,8 @@ import numpy as np import torch import zmq from msgspec import msgpack +from pydantic import GetCoreSchemaHandler +from pydantic_core import core_schema from vllm import envs from vllm.logger import init_logger @@ -32,7 +34,6 @@ from vllm.multimodal.inputs import ( NestedTensors, ) from vllm.utils.platform_utils import is_pin_memory_available -from vllm.v1.engine import UtilityResult from vllm.v1.utils import tensor_data logger = init_logger(__name__) @@ -104,6 +105,13 @@ def _decode_type_info_recursive( return convert_fn(type_info, data) +class UtilityResult: + """Wrapper for special handling when serializing/deserializing.""" + + def __init__(self, r: Any = None): + self.result = r + + class MsgpackEncoder: """Encoder with custom torch tensor and numpy array serialization. @@ -469,3 +477,56 @@ def run_method( else: func = partial(method, obj) # type: ignore return func(*args, **kwargs) + + +class PydanticMsgspecMixin: + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + """ + Make msgspec.Struct compatible with Pydantic, respecting defaults. + Handle JSON=>msgspec.Struct. Used when exposing msgspec.Struct to the + API as input or in `/docs`. Note this is cached by Pydantic and not + called on every validation. + """ + msgspec_fields = {f.name: f for f in msgspec.structs.fields(source_type)} + type_hints = get_type_hints(source_type) + + # Build the Pydantic typed_dict_field for each msgspec field + fields = {} + for name, hint in type_hints.items(): + msgspec_field = msgspec_fields[name] + + # typed_dict_field using the handler to get the schema + field_schema = handler(hint) + + # Add default value to the schema. 
+ if msgspec_field.default_factory is not msgspec.NODEFAULT: + wrapped_schema = core_schema.with_default_schema( + schema=field_schema, + default_factory=msgspec_field.default_factory, + ) + fields[name] = core_schema.typed_dict_field(wrapped_schema) + elif msgspec_field.default is not msgspec.NODEFAULT: + wrapped_schema = core_schema.with_default_schema( + schema=field_schema, + default=msgspec_field.default, + ) + fields[name] = core_schema.typed_dict_field(wrapped_schema) + else: + # No default, so Pydantic will treat it as required + fields[name] = core_schema.typed_dict_field(field_schema) + return core_schema.no_info_after_validator_function( + cls._validate_msgspec, + core_schema.typed_dict_schema(fields), + ) + + @classmethod + def _validate_msgspec(cls, value: Any) -> Any: + """Validate and convert input to msgspec.Struct instance.""" + if isinstance(value, cls): + return value + if isinstance(value, dict): + return cls(**value) + return msgspec.convert(value, type=cls)
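Reviewer note: a minimal usage sketch of the new mixin (the `MyParams`/`MyRequest` names are invented; the real use is `SamplingParams` inside `GenerateRequest`). A `msgspec.Struct` that also inherits `PydanticMsgspecMixin` can be declared as a field of a Pydantic model and validated straight from JSON, with Struct defaults preserved:

```python
import msgspec
from pydantic import BaseModel

# Assumes this PR's mixin is importable from vllm.v1.serial_utils.
from vllm.v1.serial_utils import PydanticMsgspecMixin


class MyParams(PydanticMsgspecMixin, msgspec.Struct, omit_defaults=True):
    # Toy stand-in for SamplingParams: plain msgspec fields with defaults.
    max_tokens: int = 16
    temperature: float = 1.0


class MyRequest(BaseModel):
    # Pydantic builds a schema for the Struct field through the mixin's
    # __get_pydantic_core_schema__ hook.
    token_ids: list[int]
    params: MyParams


req = MyRequest.model_validate({"token_ids": [1, 2, 3], "params": {"max_tokens": 8}})
assert isinstance(req.params, MyParams)
assert req.params.temperature == 1.0  # default taken from the Struct definition
```

This is what lets `GenerateRequest.sampling_params` accept a plain JSON object in the request body while still reaching the engine as a real `SamplingParams` struct, and it is why msgspec is now listed in requirements/docs.txt for the argparse/OpenAPI docs build.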