[Frontend] Allow return_tokens_as_token_ids to be passed as a request param (#14066)
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
This commit is contained in: parent dae9ec464c, commit 32985bed7c
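For context: with this change a client can opt into token-id-formatted logprobs on a per-request basis instead of relying on the server-wide --return-tokens-as-token-ids flag. Below is a minimal client-side sketch (an editorial illustration, not part of the commit); it assumes a vLLM OpenAI-compatible server is already running, and the base_url, api_key, and model name are placeholders. extra_body is the openai-python hook for vendor-specific request fields, which is exactly how the updated tests pass the option.

from openai import OpenAI

# Placeholder endpoint and credentials; adjust to your deployment.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="my-model",  # placeholder model name
    prompt="Hello, world",
    temperature=0,
    max_tokens=10,
    logprobs=1,
    # vLLM-specific field forwarded in the request body; the server no longer
    # has to be launched with --return-tokens-as-token-ids for this request.
    extra_body={"return_tokens_as_token_ids": True},
)

# With the option active, logprob tokens come back as strings like "token_id:29871".
print(completion.choices[0].logprobs.tokens)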
@@ -17,18 +17,28 @@ from .test_completion import MODEL_NAME


 @pytest.fixture(scope="module")
-def server_with_return_tokens_as_token_ids_flag(
-        default_server_args):  # noqa: F811
-    args_with_flag = default_server_args + ["--return-tokens-as-token-ids"]
-    with RemoteOpenAIServer(MODEL_NAME, args_with_flag) as remote_server:
-        yield remote_server
+def server_fixture(request, default_server_args):  # noqa: F811
+    use_server_flag = request.param
+    if use_server_flag:
+        args_with_flag = default_server_args + ["--return-tokens-as-token-ids"]
+        with RemoteOpenAIServer(MODEL_NAME, args_with_flag) as remote_server:
+            yield (remote_server, True)
+    else:
+        with RemoteOpenAIServer(MODEL_NAME,
+                                default_server_args) as remote_server:
+            yield (remote_server, False)


 @pytest.mark.asyncio
+@pytest.mark.parametrize("server_fixture", [True, False], indirect=True)
 async def test_completion_return_tokens_as_token_ids_completion(
-        server_with_return_tokens_as_token_ids_flag):
-    async with server_with_return_tokens_as_token_ids_flag.get_async_client(
-    ) as client:
+        server_fixture):
+    server, use_server_flag = server_fixture
+    request_args = {}
+    if not use_server_flag:
+        request_args["return_tokens_as_token_ids"] = True
+
+    async with server.get_async_client() as client:
         completion = await client.completions.create(
             model=MODEL_NAME,
@@ -39,7 +49,8 @@ async def test_completion_return_tokens_as_token_ids_completion(
             echo=True,
             temperature=0,
             max_tokens=10,
-            logprobs=1)
+            logprobs=1,
+            extra_body=request_args)

         text = completion.choices[0].text
         token_strs = completion.choices[0].logprobs.tokens
@@ -60,10 +71,14 @@ async def test_completion_return_tokens_as_token_ids_completion(


 @pytest.mark.asyncio
-async def test_chat_return_tokens_as_token_ids_completion(
-        server_with_return_tokens_as_token_ids_flag):
-    async with server_with_return_tokens_as_token_ids_flag.get_async_client(
-    ) as client:
+@pytest.mark.parametrize("server_fixture", [True, False], indirect=True)
+async def test_chat_return_tokens_as_token_ids_completion(server_fixture):
+    server, use_server_flag = server_fixture
+    request_args = {}
+    if not use_server_flag:
+        request_args["return_tokens_as_token_ids"] = True
+
+    async with server.get_async_client() as client:
         response = await client.chat.completions.create(
             model=MODEL_NAME,
             # Include Unicode characters to test for dividing a single
@@ -78,7 +93,8 @@ async def test_chat_return_tokens_as_token_ids_completion(
             }],
             temperature=0,
             max_tokens=8,
-            logprobs=True)
+            logprobs=True,
+            extra_body=request_args)

         text = response.choices[0].message.content
         tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
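The rewritten fixture relies on pytest's indirect parametrization: the values listed in @pytest.mark.parametrize are delivered to the fixture as request.param, so each test body runs once against a server launched with the flag and once against a server that receives the option per request. A standalone sketch of the pattern, with illustrative names that are not from the repo:

import pytest


@pytest.fixture
def mode(request):
    # request.param carries the value from the parametrize() list below.
    return "server-flag" if request.param else "request-param"


@pytest.mark.parametrize("mode", [True, False], indirect=True)
def test_runs_in_both_modes(mode):
    assert mode in ("server-flag", "request-param")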
@@ -369,6 +369,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
             "arguments. For example: {'qualname': "
             "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
             "{'param': 'value'}}."))
+    return_tokens_as_token_ids: Optional[bool] = Field(
+        default=None,
+        description=(
+            "If specified with 'logprobs', tokens are represented "
+            " as strings of the form 'token_id:{token_id}' so that tokens "
+            "that are not JSON-encodable can be identified."))

     # doc: end-chat-completion-extra-params

@@ -739,6 +745,12 @@ class CompletionRequest(OpenAIBaseModel):
             "arguments. For example: {'qualname': "
             "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
             "{'param': 'value'}}."))
+    return_tokens_as_token_ids: Optional[bool] = Field(
+        default=None,
+        description=(
+            "If specified with 'logprobs', tokens are represented "
+            " as strings of the form 'token_id:{token_id}' so that tokens "
+            "that are not JSON-encodable can be identified."))

     # doc: end-completion-extra-params

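Because the field is now part of both ChatCompletionRequest and CompletionRequest, it can also be sent directly in the JSON body of the OpenAI-compatible endpoints rather than through a client SDK. A hedged sketch using the requests library (placeholder host and model name):

import requests

# Placeholder endpoint and model; any OpenAI-compatible vLLM deployment works.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "my-model",
        "messages": [{"role": "user", "content": "Say hi"}],
        "max_tokens": 8,
        "logprobs": True,
        "top_logprobs": 1,
        # The new per-request field defined in the request models above.
        "return_tokens_as_token_ids": True,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["logprobs"])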
@@ -450,6 +450,8 @@ class OpenAIServingChat(OpenAIServing):
                         top_logprobs=output.logprobs,
                         tokenizer=tokenizer,
                         num_output_top_logprobs=request.top_logprobs,
+                        return_as_token_id=request.
+                        return_tokens_as_token_ids,
                     )
                 else:
                     logprobs = None
@@ -705,6 +707,7 @@ class OpenAIServingChat(OpenAIServing):
                     top_logprobs=out_logprobs,
                     num_output_top_logprobs=request.top_logprobs,
                     tokenizer=tokenizer,
+                    return_as_token_id=request.return_tokens_as_token_ids,
                 )
             else:
                 logprobs = None
@@ -852,13 +855,14 @@ class OpenAIServingChat(OpenAIServing):

     def _get_top_logprobs(
             self, logprobs: dict[int, Logprob], top_logprobs: Optional[int],
-            tokenizer: AnyTokenizer) -> list[ChatCompletionLogProb]:
+            tokenizer: AnyTokenizer,
+            should_return_as_token_id: bool) -> list[ChatCompletionLogProb]:
         return [
             ChatCompletionLogProb(token=(token := self._get_decoded_token(
                 p[1],
                 p[0],
                 tokenizer,
-                return_as_token_id=self.return_tokens_as_token_ids)),
+                return_as_token_id=should_return_as_token_id)),
                                   logprob=max(p[1].logprob, -9999.0),
                                   bytes=list(
                                       token.encode("utf-8", errors="replace")))
@@ -872,15 +876,18 @@ class OpenAIServingChat(OpenAIServing):
         top_logprobs: GenericSequence[Optional[dict[int, Logprob]]],
         tokenizer: AnyTokenizer,
         num_output_top_logprobs: Optional[int] = None,
+        return_as_token_id: Optional[bool] = None,
     ) -> ChatCompletionLogProbs:
         """Create OpenAI-style logprobs."""
         logprobs_content: list[ChatCompletionLogProbsContent] = []

+        should_return_as_token_id = return_as_token_id if \
+            return_as_token_id is not None else self.return_tokens_as_token_ids
         for i, token_id in enumerate(token_ids):
             step_top_logprobs = top_logprobs[i]
             if step_top_logprobs is None:
                 token = tokenizer.decode(token_id)
-                if self.return_tokens_as_token_ids:
+                if should_return_as_token_id:
                     token = f"token_id:{token_id}"

                 logprobs_content.append(
@@ -898,16 +905,14 @@ class OpenAIServingChat(OpenAIServing):
                             step_token,
                             token_id,
                             tokenizer,
-                            self.return_tokens_as_token_ids,
+                            should_return_as_token_id,
                         ),
                         logprob=max(step_token.logprob, -9999.0),
                         bytes=None if step_decoded is None else list(
                             step_decoded.encode("utf-8", errors="replace")),
                         top_logprobs=self._get_top_logprobs(
-                            step_top_logprobs,
-                            num_output_top_logprobs,
-                            tokenizer,
-                        ),
+                            step_top_logprobs, num_output_top_logprobs,
+                            tokenizer, should_return_as_token_id),
                     ))

         return ChatCompletionLogProbs(content=logprobs_content)
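Both serving paths now resolve the effective setting the same way: the per-request value takes precedence when the client supplies one, and the serving object's return_tokens_as_token_ids attribute (the server-wide default) is used as the fallback. A standalone sketch of that rule with a hypothetical helper name, not code from the commit:

from typing import Optional


def resolve_return_as_token_id(request_value: Optional[bool],
                               server_flag: bool) -> bool:
    # Mirrors: should_return_as_token_id = return_as_token_id if
    #     return_as_token_id is not None else self.return_tokens_as_token_ids
    return request_value if request_value is not None else server_flag


assert resolve_return_as_token_id(None, True) is True    # fall back to the flag
assert resolve_return_as_token_id(False, True) is False  # request can disable it
assert resolve_return_as_token_id(True, False) is True   # request can enable it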
@@ -316,6 +316,8 @@ class OpenAIServingCompletion(OpenAIServing):
                         num_output_top_logprobs=request.logprobs,
                         tokenizer=tokenizer,
                         initial_text_offset=previous_text_lens[i],
+                        return_as_token_id=request.
+                        return_tokens_as_token_ids,
                     )
                 else:
                     logprobs = None
@@ -436,6 +438,7 @@ class OpenAIServingCompletion(OpenAIServing):
                     top_logprobs=out_logprobs,
                     tokenizer=tokenizer,
                     num_output_top_logprobs=request.logprobs,
+                    return_as_token_id=request.return_tokens_as_token_ids,
                 )
             else:
                 logprobs = None
@@ -477,6 +480,7 @@ class OpenAIServingCompletion(OpenAIServing):
         num_output_top_logprobs: int,
         tokenizer: AnyTokenizer,
         initial_text_offset: int = 0,
+        return_as_token_id: Optional[bool] = None,
     ) -> CompletionLogProbs:
         """Create logprobs for OpenAI Completion API."""
         out_text_offset: list[int] = []
@@ -486,11 +490,13 @@ class OpenAIServingCompletion(OpenAIServing):

         last_token_len = 0

+        should_return_as_token_id = return_as_token_id if \
+            return_as_token_id is not None else self.return_tokens_as_token_ids
         for i, token_id in enumerate(token_ids):
             step_top_logprobs = top_logprobs[i]
             if step_top_logprobs is None:
                 token = tokenizer.decode(token_id)
-                if self.return_tokens_as_token_ids:
+                if should_return_as_token_id:
                     token = f"token_id:{token_id}"

                 out_tokens.append(token)
@@ -503,7 +509,7 @@ class OpenAIServingCompletion(OpenAIServing):
                     step_token,
                     token_id,
                     tokenizer,
-                    return_as_token_id=self.return_tokens_as_token_ids,
+                    return_as_token_id=should_return_as_token_id,
                 )
                 token_logprob = max(step_token.logprob, -9999.0)

@@ -520,7 +526,7 @@ class OpenAIServingCompletion(OpenAIServing):
                     self._get_decoded_token(top_lp[1],
                                             top_lp[0],
                                             tokenizer,
-                                            return_as_token_id=self.return_tokens_as_token_ids):
+                                            return_as_token_id=should_return_as_token_id):
                     max(top_lp[1].logprob, -9999.0)
                     for i, top_lp in enumerate(step_top_logprobs.items())
                     if num_output_top_logprobs >= i
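On the client side, the token_id:{token_id} strings described in the field's docstring can be mapped back to integer ids with a trivial parse. An illustrative helper, not part of the commit:

def parse_token_id(token_str: str) -> int:
    # Tokens arrive as "token_id:<int>" when the option is active.
    prefix = "token_id:"
    if not token_str.startswith(prefix):
        raise ValueError(f"unexpected token format: {token_str!r}")
    return int(token_str[len(prefix):])


assert parse_token_id("token_id:29871") == 29871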