diff --git a/tests/entrypoints/openai/test_openai_schema.py b/tests/entrypoints/openai/test_openai_schema.py
index 246bd014aa690..11ed1c4a9ee4b 100644
--- a/tests/entrypoints/openai/test_openai_schema.py
+++ b/tests/entrypoints/openai/test_openai_schema.py
@@ -74,31 +74,44 @@ def before_generate_case(context: schemathesis.hooks.HookContext, strategy):
         -d '{"messages": [{"role": "assistant", "tool_calls": [{"custom": {"input": "", "name": ""}, "id": "", "type": "custom"}]}]}' \
         http://localhost:8000/v1/chat/completions
     """  # noqa: E501
-        if (hasattr(case, "body") and isinstance(case.body, dict)
-                and "messages" in case.body
-                and isinstance(case.body["messages"], list)
-                and len(case.body["messages"]) > 0):
+        if hasattr(case, "body") and isinstance(case.body, dict):
+            if ("messages" in case.body
+                    and isinstance(case.body["messages"], list)
+                    and len(case.body["messages"]) > 0):
+
-            for message in case.body["messages"]:
-                if not isinstance(message, dict):
-                    continue
+                for message in case.body["messages"]:
+                    if not isinstance(message, dict):
+                        continue
+
-                # Check for invalid file type in tokenize endpoint
-                if op.method.lower() == "post" and op.path == "/tokenize":
-                    content = message.get("content", [])
-                    if (isinstance(content, list) and len(content) > 0 and any(
-                            item.get("type") == "file" for item in content)):
-                        return False
+                    # Check for invalid file type in tokenize endpoint
+                    if op.method.lower() == "post" and op.path == "/tokenize":
+                        content = message.get("content", [])
+                        if (isinstance(content, list) and len(content) > 0
+                                and any(
+                                    item.get("type") == "file"
+                                    for item in content)):
+                            return False
+
+                    # Check for invalid tool_calls with non-function types
+                    tool_calls = message.get("tool_calls", [])
+                    if isinstance(tool_calls, list):
+                        for tool_call in tool_calls:
+                            if isinstance(tool_call, dict):
+                                if tool_call.get("type") != "function":
+                                    return False
+                                if "custom" in tool_call:
+                                    return False
+
+            # Sometimes guided_grammar is generated to be empty,
+            # causing a server error in EBNF grammar parsing
+            # https://github.com/vllm-project/vllm/pull/22587#issuecomment-3195253421
+            guided_grammar = case.body.get("guided_grammar")
+
+            if guided_grammar == '':
+                # Allow None (will be handled as no grammar)
+                # but skip empty strings
+                return False
 
-                # Check for invalid tool_calls with non-function types
-                tool_calls = message.get("tool_calls", [])
-                if isinstance(tool_calls, list):
-                    for tool_call in tool_calls:
-                        if isinstance(tool_call, dict):
-                            if tool_call.get("type") != "function":
-                                return False
-                            if "custom" in tool_call:
-                                return False
 
         return True
 
     return strategy.filter(no_invalid_types)
diff --git a/tests/entrypoints/openai/test_return_token_ids.py b/tests/entrypoints/openai/test_return_token_ids.py
new file mode 100644
index 0000000000000..6addcb41c4098
--- /dev/null
+++ b/tests/entrypoints/openai/test_return_token_ids.py
@@ -0,0 +1,374 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+from ...utils import RemoteOpenAIServer
+
+MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
+
+
+@pytest.fixture(scope="module")
+def server():
+    args = [
+        "--max-model-len",
+        "2048",
+        "--max-num-seqs",
+        "128",
+        "--enable-auto-tool-choice",
+        "--tool-call-parser",
+        "hermes",
+        "--enforce-eager",
+    ]
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest.mark.asyncio
+async def test_basic_completion_with_emoji(server):
+    """Test basic completion with emoji to verify token_ids field."""
+    async with server.get_async_client() as client:
+        # Test with return_token_ids enabled
+        completion = await client.completions.create(
+            model=MODEL_NAME,
+            prompt="Complete this sentence with emojis: I love coding 🚀",
+            max_tokens=10,
+            temperature=0,
+            logprobs=1,
+            extra_body={"return_token_ids": True},
+        )
+
+        # Check the raw response to see the structure
+        completion_dict = completion.model_dump()
+
+        # Verify prompt_token_ids field is present in the completion response
+        assert "prompt_token_ids" in completion_dict["choices"][0]
+        assert isinstance(completion.choices[0].prompt_token_ids, list)
+
+        # Check against the expected prompt token IDs
+        tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+        encoded_tokens = tokenizer.encode(
+            "Complete this sentence with emojis: I love coding 🚀")
+        # Check that encoded_tokens is a subsequence of prompt_token_ids
+        assert any(completion.choices[0].prompt_token_ids[i:i +
+                                                          len(encoded_tokens)]
+                   == encoded_tokens for i in range(
+                       len(completion.choices[0].prompt_token_ids) -
+                       len(encoded_tokens) + 1))
+
+        # Verify token_ids field is present in the choice
+        assert completion.choices[0].token_ids is not None
+        assert isinstance(completion.choices[0].token_ids, list)
+        assert len(completion.choices[0].token_ids) > 0
+
+        # Verify decoding works correctly
+        decoded_text = tokenizer.decode(completion.choices[0].token_ids)
+        # The decoded text should contain a <|im_end|> at the end
+        assert decoded_text.startswith(completion.choices[0].text)
+
+        # Test without return_token_ids (should be None)
+        completion_without = await client.completions.create(
+            model=MODEL_NAME,
+            prompt="Complete this sentence with emojis: I love coding 🚀",
+            max_tokens=10,
+            temperature=0,
+            logprobs=1,
+            extra_body={"return_token_ids": False},
+        )
+
+        completion_without_dict = completion_without.model_dump()
+        assert completion_without_dict["choices"][0].get("token_ids") is None
+        assert completion_without_dict.get("prompt_token_ids") is None
+
+
+@pytest.mark.asyncio
+async def test_chat_completion_with_tool_use(server):
+    """Test chat completion with tool use (get_weather function)."""
+    tools = [{
+        "type": "function",
+        "function": {
+            "name": "get_weather",
+            "description": "Get the current weather in a given location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type":
+                        "string",
+                        "description":
+                        "The city and state, e.g. San Francisco, CA",
+                    },
+                    "unit": {
+                        "type": "string",
+                        "enum": ["celsius", "fahrenheit"],
+                        "description": "The unit of temperature",
+                    },
+                },
+                "required": ["location"],
+            },
+        },
+    }]
+
+    async with server.get_async_client() as client:
+        # Test with return_token_ids enabled
+        response = await client.chat.completions.create(
+            model=MODEL_NAME,
+            messages=[
+                {
+                    "role": "system",
+                    "content": "You are a helpful assistant."
+                },
+                {
+                    "role": "user",
+                    "content": "What's the weather like in Paris?"
+                },
+            ],
+            tools=tools,
+            tool_choice="auto",
+            max_tokens=100,
+            temperature=0,
+            logprobs=True,
+            extra_body={"return_token_ids": True},
+        )
+
+        # Verify token_ids field is present in choices
+        assert response.choices[0].token_ids is not None
+        assert isinstance(response.choices[0].token_ids, list)
+
+        # Verify prompt_token_ids field is present
+        assert response.prompt_token_ids is not None
+        assert isinstance(response.prompt_token_ids, list)
+
+        # Verify the prompt texts and response texts
+        tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+        prompt_text = tokenizer.decode(response.prompt_token_ids)
+        assert prompt_text.startswith(
+            "<|im_start|>system\nYou are a helpful assistant.")
+        assert prompt_text.endswith(
+            "What's the weather like in Paris?<|im_end|>\n"
+            "<|im_start|>assistant\n")
+
+        response_text = tokenizer.decode(response.choices[0].token_ids)
+        assert response_text.startswith('\n{"name": "get_weather"')
+        assert response_text.endswith("<|im_end|>")
+
+        # If tool call was made, verify the response structure
+        if response.choices[0].message.tool_calls:
+            assert len(response.choices[0].message.tool_calls) > 0
+            tool_call = response.choices[0].message.tool_calls[0]
+            assert tool_call.function.name == "get_weather"
+
+        # Test without return_token_ids
+        response_without = await client.chat.completions.create(
+            model=MODEL_NAME,
+            messages=[
+                {
+                    "role": "system",
+                    "content": "You are a helpful assistant."
+                },
+                {
+                    "role": "user",
+                    "content": "What's the weather like in Paris?"
+                },
+            ],
+            tools=tools,
+            tool_choice="auto",
+            max_tokens=100,
+            temperature=0,
+            logprobs=True,
+            extra_body={"return_token_ids": False},
+        )
+
+        assert response_without.choices[0].token_ids is None
+        assert response_without.prompt_token_ids is None
+
+
+@pytest.mark.asyncio
+async def test_comparison_with_prompt_logprobs_and_logprobs(server):
+    """
+    Test that token_ids align with prompt_logprobs and
+    logprobs when return_tokens_as_token_ids is enabled.
+    """
+    async with server.get_async_client() as client:
+        # Test with both return_token_ids and return_tokens_as_token_ids enabled
+        completion = await client.completions.create(
+            model=MODEL_NAME,
+            prompt="Hello, world! How are you today?",
+            max_tokens=20,
+            temperature=0,
+            echo=True,
+            logprobs=1,
+            extra_body={
+                "return_token_ids": True,
+                "return_tokens_as_token_ids": True,
+                "prompt_logprobs": 1
+            },
+        )
+
+        # Verify all fields are present
+        assert completion.choices[0].token_ids is not None
+        assert completion.choices[0].prompt_token_ids is not None
+        assert completion.choices[0].prompt_logprobs is not None
+        assert completion.choices[0].logprobs is not None
+
+        # Extract token IDs from logprobs
+        # (when return_tokens_as_token_ids is True)
+        logprobs_token_ids = []
+        for token_str in completion.choices[0].logprobs.tokens:
+            # Token format is "token_id:12345" when
+            # return_tokens_as_token_ids is True
+            if token_str.startswith("token_id:"):
+                token_id = int(token_str.removeprefix("token_id:"))
+                logprobs_token_ids.append(token_id)
+
+        # When echo=True, the logprobs include both prompt and response tokens
+        # The token_ids field should match the suffix of the response portion
+        # The prompt_token_ids should match the prompt portion
+        assert len(completion.choices[0].token_ids) < len(logprobs_token_ids)
+        response_token_ids_length = len(completion.choices[0].token_ids)
+        assert logprobs_token_ids[-response_token_ids_length:] == \
+            completion.choices[0].token_ids
+
+        # Verify tokenizer consistency
+        tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+
+        # Decode prompt tokens
+        if completion.choices[0].prompt_token_ids:
+            prompt_text = tokenizer.decode(
+                completion.choices[0].prompt_token_ids)
+            # The decoded prompt should be close to the original prompt
+            assert "Hello, world" in prompt_text
+
+        # Decode response tokens
+        if completion.choices[0].token_ids:
+            response_text = tokenizer.decode(completion.choices[0].token_ids)
+            assert completion.choices[0].text.endswith(response_text)
+
+        # Test streaming mode
+        stream = await client.completions.create(
+            model=MODEL_NAME,
+            prompt="Tell me a short fact about Python:",
+            max_tokens=30,
+            temperature=0,
+            stream=True,
+            echo=False,
+            logprobs=1,
+            extra_body={
+                "return_token_ids": True,
+                "return_tokens_as_token_ids": True
+            },
+        )
+
+        # Collect streamed tokens
+        streamed_prompt_token_ids = []
+        streamed_token_ids = []
+        streamed_logprob_token_ids = []
+        first_chunk = True
+        async for chunk in stream:
+            for token_str in chunk.choices[0].logprobs.tokens:
+                # Token format is "token_id:12345" when
+                # return_tokens_as_token_ids is True
+                if token_str.startswith("token_id:"):
+                    token_id = int(token_str.removeprefix("token_id:"))
+                    streamed_logprob_token_ids.append(token_id)
+            if first_chunk:
+                streamed_prompt_token_ids = chunk.choices[0].prompt_token_ids
+                first_chunk = False
+            streamed_token_ids += chunk.choices[0].token_ids
+
+        # Verify we collected some tokens and first chunk had prompt_token_ids
+        assert len(streamed_prompt_token_ids) > 0
+        assert streamed_token_ids == streamed_logprob_token_ids
+
+
+@pytest.mark.asyncio
+async def test_chat_completion_with_emoji_and_token_ids(server):
+    """Test chat completion with emojis to verify token_ids handling."""
+    chat_messages = [
+        {
+            "role": "system",
+            "content": "You like to use emojis in your responses."
+        },
+        {
+            "role": "user",
+            "content": "Repeat after me: I love cats 🐱"
+        },
+    ]
+    async with server.get_async_client() as client:
+        response = await client.chat.completions.create(
+            model=MODEL_NAME,
+            messages=chat_messages,
+            max_tokens=50,
+            temperature=0,
+            logprobs=True,
+            extra_body={"return_token_ids": True},
+        )
+
+        # Verify token_ids are present
+        response_dict = response.model_dump()
+        assert response.choices[0].token_ids is not None
+        assert "prompt_token_ids" in response_dict
+
+        # Verify the response contains the expected fields
+        assert response.choices[0].message.content is not None
+
+        # Decode token_ids and verify consistency
+        tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+
+        decoded_prompt = tokenizer.decode(response.prompt_token_ids)
+        assert decoded_prompt.startswith(
+            "<|im_start|>system\nYou like to use emojis in your responses.")
+        assert decoded_prompt.endswith(
+            "I love cats 🐱<|im_end|>\n<|im_start|>assistant\n")
+
+        decoded_response = tokenizer.decode(response.choices[0].token_ids)
+        # The content should match the response text
+        # except the ending <|im_end|>
+        assert decoded_response == response.choices[
+            0].message.content + "<|im_end|>"
+
+        # Test with streaming
+        stream = await client.chat.completions.create(
+            model=MODEL_NAME,
+            messages=chat_messages,
+            max_tokens=50,
+            temperature=0,
+            stream=True,
+            extra_body={"return_token_ids": True},
+        )
+
+        collected_content = ""
+        collected_token_ids = []
+        first_chunk = True
+
+        async for chunk in stream:
+            if first_chunk:
+                assert chunk.prompt_token_ids is not None
+                assert isinstance(chunk.prompt_token_ids, list)
+                # Check the prompt_token_ids match the initial prompt
+                decoded_prompt_stream = tokenizer.decode(
+                    chunk.prompt_token_ids)
+                assert decoded_prompt_stream == decoded_prompt
+                first_chunk = False
+            else:
+                chunk_dump = chunk.model_dump()
+                assert "prompt_token_ids" not in chunk_dump, \
+                    "Subsequent chunks should not have prompt_token_ids"
+
+            if chunk.choices:
+                if chunk.choices[0].delta.content:
+                    collected_content += chunk.choices[0].delta.content
+                # token_ids may not be present in all chunks
+                choice_dump = chunk.choices[0].model_dump()
+                if "token_ids" in choice_dump:
+                    collected_token_ids.extend(chunk.choices[0].token_ids)
+
+        # Verify we got response and token_ids
+        assert len(collected_content) > 0
+        assert len(collected_token_ids) > 0
+
+        # Verify token_ids decode properly
+        decoded_response = tokenizer.decode(collected_token_ids)
+        assert decoded_response == collected_content + "<|im_end|>"
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 61f1a09d3ac1c..39facd4d53d32 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -576,6 +576,14 @@ class ChatCompletionRequest(OpenAIBaseModel):
             "If specified with 'logprobs', tokens are represented "
             " as strings of the form 'token_id:{token_id}' so that tokens "
             "that are not JSON-encodable can be identified."))
+    return_token_ids: Optional[bool] = Field(
+        default=None,
+        description=(
+            "If specified, the result will include token IDs alongside the "
+            "generated text. In streaming mode, prompt_token_ids is included "
+            "only in the first chunk, and token_ids contains the delta tokens "
+            "for each chunk. This is useful for debugging or when you "
+            "need to map generated text back to input tokens."))
     cache_salt: Optional[str] = Field(
         default=None,
         description=(
@@ -1062,6 +1070,14 @@ class CompletionRequest(OpenAIBaseModel):
             "If specified with 'logprobs', tokens are represented "
             " as strings of the form 'token_id:{token_id}' so that tokens "
             "that are not JSON-encodable can be identified."))
+    return_token_ids: Optional[bool] = Field(
+        default=None,
+        description=(
+            "If specified, the result will include token IDs alongside the "
+            "generated text. In streaming mode, prompt_token_ids is included "
+            "only in the first chunk, and token_ids contains the delta tokens "
+            "for each chunk. This is useful for debugging or when you "
+            "need to map generated text back to input tokens."))
     cache_salt: Optional[str] = Field(
         default=None,
@@ -1480,7 +1496,9 @@ class CompletionResponseChoice(OpenAIBaseModel):
         "to stop, None if the completion finished for some other reason "
         "including encountering the EOS token"),
     )
+    token_ids: Optional[list[int]] = None  # For response
     prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None
+    prompt_token_ids: Optional[list[int]] = None  # For prompt
 
 
 class CompletionResponse(OpenAIBaseModel):
@@ -1511,6 +1529,10 @@ class CompletionResponseStreamChoice(OpenAIBaseModel):
         "to stop, None if the completion finished for some other reason "
         "including encountering the EOS token"),
     )
+    # not part of the OpenAI spec but for tracing the tokens
+    # prompt tokens go in the choice to align with CompletionResponseChoice
+    prompt_token_ids: Optional[list[int]] = None
+    token_ids: Optional[list[int]] = None
 
 
 class CompletionStreamResponse(OpenAIBaseModel):
@@ -1680,6 +1702,9 @@ class ChatCompletionResponseChoice(OpenAIBaseModel):
     finish_reason: Optional[str] = "stop"
     # not part of the OpenAI spec but included in vLLM for legacy reasons
     stop_reason: Optional[Union[int, str]] = None
+    # not part of the OpenAI spec but is useful for tracing the tokens
+    # in agent scenarios
+    token_ids: Optional[list[int]] = None
 
 
 class ChatCompletionResponse(OpenAIBaseModel):
@@ -1695,6 +1720,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
 
     # vLLM-specific fields that are not in OpenAI spec
     prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None
+    prompt_token_ids: Optional[list[int]] = None
     kv_transfer_params: Optional[dict[str, Any]] = Field(
         default=None, description="KVTransfer parameters.")
 
@@ -1712,6 +1738,8 @@ class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
     logprobs: Optional[ChatCompletionLogProbs] = None
     finish_reason: Optional[str] = None
     stop_reason: Optional[Union[int, str]] = None
+    # not part of the OpenAI spec but for tracing the tokens
+    token_ids: Optional[list[int]] = None
 
 
 class ChatCompletionStreamResponse(OpenAIBaseModel):
@@ -1721,6 +1749,8 @@ class ChatCompletionStreamResponse(OpenAIBaseModel):
     model: str
    choices: list[ChatCompletionResponseStreamChoice]
     usage: Optional[UsageInfo] = Field(default=None)
+    # not part of the OpenAI spec but for tracing the tokens
+    prompt_token_ids: Optional[list[int]] = None
 
 
 class TranscriptionResponseStreamChoice(OpenAIBaseModel):
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 12349234c320f..1789521afc84c 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -568,12 +568,17 @@ class OpenAIServingChat(OpenAIServing):
                         ),
                         logprobs=None,
                         finish_reason=None)
+
+                    # return prompt_token_ids only in the first chunk
                     chunk = ChatCompletionStreamResponse(
                         id=request_id,
                         object=chunk_object_type,
                         created=created_time,
                         choices=[choice_data],
-                        model=model_name)
+                        model=model_name,
+                        prompt_token_ids=(res.prompt_token_ids
+                                          if request.return_token_ids else
+                                          None))
 
                     # if continuous usage stats are requested, add it
                     if include_continuous_usage:
@@ -912,7 +917,9 @@ class OpenAIServingChat(OpenAIServing):
                         index=i,
                         delta=delta_message,
                         logprobs=logprobs,
-                        finish_reason=None)
+                        finish_reason=None,
+                        token_ids=(as_list(output.token_ids)
+                                   if request.return_token_ids else None))
 
                 # if the model is finished generating
                 else:
@@ -973,7 +980,9 @@ class OpenAIServingChat(OpenAIServing):
                         logprobs=logprobs,
                         finish_reason=output.finish_reason
                         if not auto_tools_called else "tool_calls",
-                        stop_reason=output.stop_reason)
+                        stop_reason=output.stop_reason,
+                        token_ids=(as_list(output.token_ids)
+                                   if request.return_token_ids else None))
 
                     finish_reason_sent[i] = True
 
@@ -1260,7 +1269,10 @@ class OpenAIServingChat(OpenAIServing):
                 logprobs=logprobs,
                 finish_reason="tool_calls" if auto_tools_called else
                 output.finish_reason if output.finish_reason else "stop",
-                stop_reason=output.stop_reason)
+                stop_reason=output.stop_reason,
+                token_ids=(as_list(output.token_ids)
+                           if request.return_token_ids else None),
+            )
 
             choices.append(choice_data)
 
@@ -1301,6 +1313,8 @@ class OpenAIServingChat(OpenAIServing):
             choices=choices,
             usage=usage,
             prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
+            prompt_token_ids=(final_res.prompt_token_ids
+                              if request.return_token_ids else None),
             kv_transfer_params=final_res.kv_transfer_params,
         )
 
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 22c6b6250394c..a0ce654094039 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -42,7 +42,7 @@ from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import merge_async_iterators
+from vllm.utils import as_list, merge_async_iterators
 
 logger = init_logger(__name__)
 
@@ -365,6 +365,11 @@ class OpenAIServingCompletion(OpenAIServing):
                 for output in res.outputs:
                     i = output.index + prompt_idx * num_choices
 
+                    # Useful when request.return_token_ids is True.
+                    # Returning prompt token IDs shares the same logic
+                    # with the echo implementation.
+                    prompt_token_ids_to_return: Optional[list[int]] = None
+
                     assert request.max_tokens is not None
                     if request.echo and not has_echoed[i]:
                         assert prompt_token_ids is not None
@@ -385,6 +390,7 @@ class OpenAIServingCompletion(OpenAIServing):
                                 *(prompt_logprobs or []),
                                 *(output.logprobs or []),
                             ]
+                        prompt_token_ids_to_return = prompt_token_ids
                         has_echoed[i] = True
                     else:
                         # return just the delta
@@ -392,6 +398,12 @@ class OpenAIServingCompletion(OpenAIServing):
                         delta_token_ids = output.token_ids
                         out_logprobs = output.logprobs
 
+                        # has_echoed[i] is reused here to indicate whether
+                        # we have already returned the prompt token IDs.
+                        if not has_echoed[i]:
+                            prompt_token_ids_to_return = prompt_token_ids
+                            has_echoed[i] = True
+
                         if (not delta_text and not delta_token_ids
                                 and not previous_num_tokens[i]):
                             # Chunked prefill case, don't return empty chunks
@@ -428,6 +440,9 @@ class OpenAIServingCompletion(OpenAIServing):
                                 logprobs=logprobs,
                                 finish_reason=finish_reason,
                                 stop_reason=stop_reason,
+                                prompt_token_ids=prompt_token_ids_to_return,
+                                token_ids=(as_list(output.token_ids) if
+                                           request.return_token_ids else None),
                             )
                         ],
                     )
@@ -548,6 +563,10 @@ class OpenAIServingCompletion(OpenAIServing):
                 finish_reason=output.finish_reason,
                 stop_reason=output.stop_reason,
                 prompt_logprobs=final_res.prompt_logprobs,
+                prompt_token_ids=(prompt_token_ids
+                                  if request.return_token_ids else None),
+                token_ids=(as_list(output.token_ids)
+                           if request.return_token_ids else None),
             )
 
             choices.append(choice_data)
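
Usage sketch (not part of the patch above): the snippet below shows how a client might request the new token ID fields through the OpenAI Python SDK, mirroring what the tests in this diff do. It is a minimal sketch under stated assumptions: a vLLM OpenAI-compatible server already running at http://localhost:8000 and serving Qwen/Qwen2.5-1.5B-Instruct; the base URL, API key, and model name are placeholders to adjust for your deployment. Because return_token_ids is a vLLM extension rather than part of the OpenAI spec, it is passed via extra_body, and the extra response fields are read from model_dump().

# Illustrative sketch only; assumes a running vLLM server as described above.
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    completion = await client.completions.create(
        model="Qwen/Qwen2.5-1.5B-Instruct",
        prompt="Hello, world!",
        max_tokens=8,
        # vLLM-specific field, so it goes through extra_body
        extra_body={"return_token_ids": True},
    )
    choice = completion.model_dump()["choices"][0]
    # prompt_token_ids carries the tokenized prompt; token_ids carries the
    # tokens generated for this choice.
    print(choice["prompt_token_ids"])
    print(choice["token_ids"])


asyncio.run(main())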