mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-03 02:17:56 +08:00
[RL][BugFix] Fix missing tokenizer error for token-in-token-out (#23904)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
parent
0dc9532065
commit
4d7fe40fc0
73
tests/entrypoints/openai/test_token_in_token_out.py
Normal file
73
tests/entrypoints/openai/test_token_in_token_out.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
|
download_weights_from_hf)
|
||||||
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
|
|
||||||
|
from ...utils import RemoteOpenAIServer
|
||||||
|
|
||||||
|
# Hugging Face model ID; also used by the test to load a reference tokenizer.
MODEL_NAME = "Qwen/Qwen3-0.6B"
|
||||||
|
# Starts as the download cache directory. The `server` fixture rebinds it
# (via `global`) to the value returned by `download_weights_from_hf`, and the
# test then passes it as the `model` name in its request.
MODEL_PATH = os.path.join(tempfile.gettempdir(), "qwen3_06b")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def server():
    """Launch a module-scoped server that runs without a tokenizer.

    Calls ``download_weights_from_hf`` for ``MODEL_NAME`` with tokenizer,
    vocab and safetensors files excluded, rebinds the module-level
    ``MODEL_PATH`` to the returned path, and starts a ``RemoteOpenAIServer``
    on that path with ``--skip-tokenizer-init`` and ``--load-format dummy``.

    Yields:
        The ``RemoteOpenAIServer`` instance; its context manager is exited
        once the tests in this module are done with the fixture.
    """
    global MODEL_PATH

    # Everything matching these patterns is left out of the download, so the
    # resulting directory contains neither tokenizer files nor real weights.
    excluded_files = ["tokenizer*", "vocab*", "*.safetensors"]
    MODEL_PATH = download_weights_from_hf(
        MODEL_NAME,
        cache_dir=MODEL_PATH,
        allow_patterns=["*"],
        ignore_patterns=excluded_files,
    )

    server_args = [
        "--max-model-len", "2048",
        "--max-num-seqs", "128",
        "--enforce-eager",
        "--skip-tokenizer-init",
        "--load-format", "dummy",
    ]

    with RemoteOpenAIServer(MODEL_PATH, server_args) as running_server:
        yield running_server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_token_in_token_out_and_logprobs(server):
    """Check a token-in-token-out completion with ``return_token_ids``.

    The prompt is tokenized locally with the ``MODEL_NAME`` tokenizer and
    sent to the server as a list of token IDs. The response must contain
    between 1 and ``max_tokens`` generated ``token_ids`` and a
    ``prompt_token_ids`` list that decodes back to the original text.

    Note: despite the test name, this request does not ask for logprobs and
    does not set ``return_tokens_as_token_ids``; only ``return_token_ids``
    is enabled.
    """
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
    text = "Hello, world! How are you today?"
    token_ids = tokenizer.encode(text)
    max_tokens = 20

    async with server.get_async_client() as client:
        # Send the prompt as token IDs and ask for token IDs in the response.
        completion = await client.completions.create(
            model=MODEL_PATH,
            prompt=token_ids,
            max_tokens=max_tokens,
            temperature=0,
            echo=True,
            extra_body={
                "return_token_ids": True,
            },
        )
        choice = completion.choices[0]

        # Generated token IDs must be present and bounded by max_tokens.
        assert choice.token_ids is not None
        assert 0 < len(choice.token_ids) <= max_tokens

        assert choice.prompt_token_ids is not None

        # Decode unconditionally: an empty prompt_token_ids list decodes to
        # something other than `text`, so it fails here instead of being
        # skipped by a truthiness guard.
        prompt_text = tokenizer.decode(choice.prompt_token_ids)
        # The decoded prompt must match the original prompt exactly.
        assert prompt_text == text
|
||||||
@ -127,7 +127,11 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
try:
|
try:
|
||||||
lora_request = self._maybe_get_adapters(request)
|
lora_request = self._maybe_get_adapters(request)
|
||||||
|
|
||||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
if self.model_config.skip_tokenizer_init:
|
||||||
|
tokenizer = None
|
||||||
|
else:
|
||||||
|
tokenizer = await self.engine_client.get_tokenizer(lora_request
|
||||||
|
)
|
||||||
|
|
||||||
request_prompts, engine_prompts = await self._preprocess_completion(
|
request_prompts, engine_prompts = await self._preprocess_completion(
|
||||||
request,
|
request,
|
||||||
|
|||||||
@ -526,8 +526,8 @@ class OpenAIServing:
|
|||||||
async def _normalize_prompt_text_to_input(
|
async def _normalize_prompt_text_to_input(
|
||||||
self,
|
self,
|
||||||
request: AnyRequest,
|
request: AnyRequest,
|
||||||
tokenizer: AnyTokenizer,
|
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
tokenizer: AnyTokenizer,
|
||||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]],
|
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]],
|
||||||
add_special_tokens: bool,
|
add_special_tokens: bool,
|
||||||
) -> TextTokensPrompt:
|
) -> TextTokensPrompt:
|
||||||
@ -563,12 +563,10 @@ class OpenAIServing:
|
|||||||
async def _normalize_prompt_tokens_to_input(
|
async def _normalize_prompt_tokens_to_input(
|
||||||
self,
|
self,
|
||||||
request: AnyRequest,
|
request: AnyRequest,
|
||||||
tokenizer: AnyTokenizer,
|
|
||||||
prompt_ids: list[int],
|
prompt_ids: list[int],
|
||||||
|
tokenizer: Optional[AnyTokenizer],
|
||||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
|
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
|
||||||
) -> TextTokensPrompt:
|
) -> TextTokensPrompt:
|
||||||
async_tokenizer = self._get_async_tokenizer(tokenizer)
|
|
||||||
|
|
||||||
if truncate_prompt_tokens is None:
|
if truncate_prompt_tokens is None:
|
||||||
input_ids = prompt_ids
|
input_ids = prompt_ids
|
||||||
elif truncate_prompt_tokens < 0:
|
elif truncate_prompt_tokens < 0:
|
||||||
@ -576,7 +574,11 @@ class OpenAIServing:
|
|||||||
else:
|
else:
|
||||||
input_ids = prompt_ids[-truncate_prompt_tokens:]
|
input_ids = prompt_ids[-truncate_prompt_tokens:]
|
||||||
|
|
||||||
input_text = await async_tokenizer.decode(input_ids)
|
if tokenizer is None:
|
||||||
|
input_text = ""
|
||||||
|
else:
|
||||||
|
async_tokenizer = self._get_async_tokenizer(tokenizer)
|
||||||
|
input_text = await async_tokenizer.decode(input_ids)
|
||||||
|
|
||||||
return self._validate_input(request, input_ids, input_text)
|
return self._validate_input(request, input_ids, input_text)
|
||||||
|
|
||||||
@ -681,27 +683,27 @@ class OpenAIServing:
|
|||||||
[`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
|
[`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
|
||||||
that assumes multiple inputs.
|
that assumes multiple inputs.
|
||||||
"""
|
"""
|
||||||
for text in prompt_inputs:
|
for prompt in prompt_inputs:
|
||||||
if isinstance(text, str):
|
if isinstance(prompt, str):
|
||||||
yield await self._normalize_prompt_text_to_input(
|
yield await self._normalize_prompt_text_to_input(
|
||||||
request,
|
request,
|
||||||
tokenizer,
|
prompt=prompt,
|
||||||
prompt=text,
|
tokenizer=tokenizer,
|
||||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||||
add_special_tokens=add_special_tokens,
|
add_special_tokens=add_special_tokens,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
yield await self._normalize_prompt_tokens_to_input(
|
yield await self._normalize_prompt_tokens_to_input(
|
||||||
request,
|
request,
|
||||||
tokenizer,
|
prompt_ids=prompt,
|
||||||
prompt_ids=text,
|
tokenizer=tokenizer,
|
||||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _tokenize_prompt_input_or_inputs_async(
|
async def _tokenize_prompt_input_or_inputs_async(
|
||||||
self,
|
self,
|
||||||
request: AnyRequest,
|
request: AnyRequest,
|
||||||
tokenizer: AnyTokenizer,
|
tokenizer: Optional[AnyTokenizer],
|
||||||
input_or_inputs: Optional[Union[str, list[str], list[int],
|
input_or_inputs: Optional[Union[str, list[str], list[int],
|
||||||
list[list[int]]]],
|
list[list[int]]]],
|
||||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
|
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
|
||||||
@ -740,17 +742,19 @@ class OpenAIServing:
|
|||||||
tasks = []
|
tasks = []
|
||||||
for prompt_input in batch_inputs:
|
for prompt_input in batch_inputs:
|
||||||
if prompt_input["is_tokens"] is False:
|
if prompt_input["is_tokens"] is False:
|
||||||
|
assert tokenizer is not None, \
|
||||||
|
"Tokenizer is required for text prompts"
|
||||||
task = self._normalize_prompt_text_to_input(
|
task = self._normalize_prompt_text_to_input(
|
||||||
request,
|
request,
|
||||||
tokenizer,
|
|
||||||
prompt_input["content"],
|
prompt_input["content"],
|
||||||
|
tokenizer=tokenizer,
|
||||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||||
add_special_tokens=add_special_tokens)
|
add_special_tokens=add_special_tokens)
|
||||||
else:
|
else:
|
||||||
task = self._normalize_prompt_tokens_to_input(
|
task = self._normalize_prompt_tokens_to_input(
|
||||||
request,
|
request,
|
||||||
tokenizer,
|
|
||||||
prompt_input["content"],
|
prompt_input["content"],
|
||||||
|
tokenizer=tokenizer,
|
||||||
truncate_prompt_tokens=truncate_prompt_tokens)
|
truncate_prompt_tokens=truncate_prompt_tokens)
|
||||||
tasks.append(task)
|
tasks.append(task)
|
||||||
|
|
||||||
@ -766,7 +770,7 @@ class OpenAIServing:
|
|||||||
request: Union[DetokenizeRequest, EmbeddingCompletionRequest,
|
request: Union[DetokenizeRequest, EmbeddingCompletionRequest,
|
||||||
RerankRequest, ClassificationRequest, ScoreRequest,
|
RerankRequest, ClassificationRequest, ScoreRequest,
|
||||||
TokenizeCompletionRequest],
|
TokenizeCompletionRequest],
|
||||||
tokenizer: AnyTokenizer,
|
tokenizer: Optional[AnyTokenizer],
|
||||||
input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
|
input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
|
||||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
|
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
|
||||||
add_special_tokens: bool = ...,
|
add_special_tokens: bool = ...,
|
||||||
@ -777,7 +781,7 @@ class OpenAIServing:
|
|||||||
async def _preprocess_completion(
|
async def _preprocess_completion(
|
||||||
self,
|
self,
|
||||||
request: CompletionRequest,
|
request: CompletionRequest,
|
||||||
tokenizer: AnyTokenizer,
|
tokenizer: Optional[AnyTokenizer],
|
||||||
input_or_inputs: Optional[Union[str, list[str], list[int],
|
input_or_inputs: Optional[Union[str, list[str], list[int],
|
||||||
list[list[int]]]],
|
list[list[int]]]],
|
||||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
|
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
|
||||||
@ -789,7 +793,7 @@ class OpenAIServing:
|
|||||||
async def _preprocess_completion(
|
async def _preprocess_completion(
|
||||||
self,
|
self,
|
||||||
request: CompletionLikeRequest,
|
request: CompletionLikeRequest,
|
||||||
tokenizer: AnyTokenizer,
|
tokenizer: Optional[AnyTokenizer],
|
||||||
input_or_inputs: Optional[Union[str, list[str], list[int],
|
input_or_inputs: Optional[Union[str, list[str], list[int],
|
||||||
list[list[int]]]],
|
list[list[int]]]],
|
||||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
|
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user