[Frontend] Allow return_tokens_as_token_ids to be passed as a request param (#14066)

Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Benjamin Chislett, 2025-03-05 01:30:40 -05:00 (committed by GitHub)
parent dae9ec464c
commit 32985bed7c
4 changed files with 64 additions and 25 deletions
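
Summary: the `--return-tokens-as-token-ids` behaviour, previously available only as a server-wide CLI flag, can now also be enabled per request. A minimal client-side sketch mirroring the updated tests below; the base URL, API key, and model name are placeholder assumptions, not values from this change:

```python
# Hedged sketch of the per-request opt-in, based on the test changes below.
# base_url, api_key, and model are placeholders for a locally running server.
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    completion = await client.completions.create(
        model="my-model",
        prompt="Hello, world",
        echo=True,
        temperature=0,
        max_tokens=10,
        logprobs=1,
        # The new request parameter is not part of the OpenAI schema, so it
        # is forwarded through extra_body, exactly as the tests do.
        extra_body={"return_tokens_as_token_ids": True},
    )
    # When active, each logprobs token is the string "token_id:<int>" rather
    # than the decoded token text.
    print(completion.choices[0].logprobs.tokens)


if __name__ == "__main__":
    asyncio.run(main())
```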

View File

@@ -17,18 +17,28 @@ from .test_completion import MODEL_NAME

 @pytest.fixture(scope="module")
-def server_with_return_tokens_as_token_ids_flag(
-        default_server_args):  # noqa: F811
-    args_with_flag = default_server_args + ["--return-tokens-as-token-ids"]
-    with RemoteOpenAIServer(MODEL_NAME, args_with_flag) as remote_server:
-        yield remote_server
+def server_fixture(request, default_server_args):  # noqa: F811
+    use_server_flag = request.param
+    if use_server_flag:
+        args_with_flag = default_server_args + ["--return-tokens-as-token-ids"]
+        with RemoteOpenAIServer(MODEL_NAME, args_with_flag) as remote_server:
+            yield (remote_server, True)
+    else:
+        with RemoteOpenAIServer(MODEL_NAME,
+                                default_server_args) as remote_server:
+            yield (remote_server, False)


 @pytest.mark.asyncio
+@pytest.mark.parametrize("server_fixture", [True, False], indirect=True)
 async def test_completion_return_tokens_as_token_ids_completion(
-        server_with_return_tokens_as_token_ids_flag):
-    async with server_with_return_tokens_as_token_ids_flag.get_async_client(
-    ) as client:
+        server_fixture):
+    server, use_server_flag = server_fixture
+    request_args = {}
+    if not use_server_flag:
+        request_args["return_tokens_as_token_ids"] = True
+    async with server.get_async_client() as client:
         completion = await client.completions.create(
             model=MODEL_NAME,
@@ -39,7 +49,8 @@ async def test_completion_return_tokens_as_token_ids_completion(
             echo=True,
             temperature=0,
             max_tokens=10,
-            logprobs=1)
+            logprobs=1,
+            extra_body=request_args)

         text = completion.choices[0].text
         token_strs = completion.choices[0].logprobs.tokens
@@ -60,10 +71,14 @@ async def test_completion_return_tokens_as_token_ids_completion(

 @pytest.mark.asyncio
-async def test_chat_return_tokens_as_token_ids_completion(
-        server_with_return_tokens_as_token_ids_flag):
-    async with server_with_return_tokens_as_token_ids_flag.get_async_client(
-    ) as client:
+@pytest.mark.parametrize("server_fixture", [True, False], indirect=True)
+async def test_chat_return_tokens_as_token_ids_completion(server_fixture):
+    server, use_server_flag = server_fixture
+    request_args = {}
+    if not use_server_flag:
+        request_args["return_tokens_as_token_ids"] = True
+    async with server.get_async_client() as client:
         response = await client.chat.completions.create(
             model=MODEL_NAME,
             # Include Unicode characters to test for dividing a single
@@ -78,7 +93,8 @@ async def test_chat_return_tokens_as_token_ids_completion(
             }],
             temperature=0,
             max_tokens=8,
-            logprobs=True)
+            logprobs=True,
+            extra_body=request_args)

         text = response.choices[0].message.content
         tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)

View File

@@ -369,6 +369,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
             "arguments. For example: {'qualname': "
             "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
             "{'param': 'value'}}."))
+    return_tokens_as_token_ids: Optional[bool] = Field(
+        default=None,
+        description=(
+            "If specified with 'logprobs', tokens are represented "
+            " as strings of the form 'token_id:{token_id}' so that tokens "
+            "that are not JSON-encodable can be identified."))

     # doc: end-chat-completion-extra-params

@@ -739,6 +745,12 @@ class CompletionRequest(OpenAIBaseModel):
             "arguments. For example: {'qualname': "
             "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
             "{'param': 'value'}}."))
+    return_tokens_as_token_ids: Optional[bool] = Field(
+        default=None,
+        description=(
+            "If specified with 'logprobs', tokens are represented "
+            " as strings of the form 'token_id:{token_id}' so that tokens "
+            "that are not JSON-encodable can be identified."))

     # doc: end-completion-extra-params
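
The two new `Field` entries above are what let the option ride along in the JSON body of either endpoint. A hedged sketch of a raw request against the completions endpoint; the server URL and model name are assumptions:

```python
# Hedged sketch: the new key is sent as a plain JSON field next to logprobs.
# The server URL and model name are placeholder assumptions.
import requests

response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "my-model",
        "prompt": "Hello",
        "max_tokens": 5,
        "logprobs": 1,
        # Per-request override added to CompletionRequest/ChatCompletionRequest.
        "return_tokens_as_token_ids": True,
    },
    timeout=30,
)
# Tokens come back as "token_id:<int>" strings inside the logprobs block.
print(response.json()["choices"][0]["logprobs"]["tokens"])
```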

View File

@@ -450,6 +450,8 @@ class OpenAIServingChat(OpenAIServing):
                         top_logprobs=output.logprobs,
                         tokenizer=tokenizer,
                         num_output_top_logprobs=request.top_logprobs,
+                        return_as_token_id=request.
+                        return_tokens_as_token_ids,
                     )
                 else:
                     logprobs = None
@@ -705,6 +707,7 @@
                 top_logprobs=out_logprobs,
                 num_output_top_logprobs=request.top_logprobs,
                 tokenizer=tokenizer,
+                return_as_token_id=request.return_tokens_as_token_ids,
             )
         else:
             logprobs = None
@@ -852,13 +855,14 @@

     def _get_top_logprobs(
             self, logprobs: dict[int, Logprob], top_logprobs: Optional[int],
-            tokenizer: AnyTokenizer) -> list[ChatCompletionLogProb]:
+            tokenizer: AnyTokenizer,
+            should_return_as_token_id: bool) -> list[ChatCompletionLogProb]:
         return [
             ChatCompletionLogProb(token=(token := self._get_decoded_token(
                 p[1],
                 p[0],
                 tokenizer,
-                return_as_token_id=self.return_tokens_as_token_ids)),
+                return_as_token_id=should_return_as_token_id)),
                                   logprob=max(p[1].logprob, -9999.0),
                                   bytes=list(
                                       token.encode("utf-8", errors="replace")))
@@ -872,15 +876,18 @@
         top_logprobs: GenericSequence[Optional[dict[int, Logprob]]],
         tokenizer: AnyTokenizer,
         num_output_top_logprobs: Optional[int] = None,
+        return_as_token_id: Optional[bool] = None,
     ) -> ChatCompletionLogProbs:
         """Create OpenAI-style logprobs."""
         logprobs_content: list[ChatCompletionLogProbsContent] = []

+        should_return_as_token_id = return_as_token_id if \
+            return_as_token_id is not None else self.return_tokens_as_token_ids
         for i, token_id in enumerate(token_ids):
             step_top_logprobs = top_logprobs[i]
             if step_top_logprobs is None:
                 token = tokenizer.decode(token_id)
-                if self.return_tokens_as_token_ids:
+                if should_return_as_token_id:
                     token = f"token_id:{token_id}"

                 logprobs_content.append(
@@ -898,16 +905,14 @@
                             step_token,
                             token_id,
                             tokenizer,
-                            self.return_tokens_as_token_ids,
+                            should_return_as_token_id,
                         ),
                         logprob=max(step_token.logprob, -9999.0),
                         bytes=None if step_decoded is None else list(
                             step_decoded.encode("utf-8", errors="replace")),
                         top_logprobs=self._get_top_logprobs(
-                            step_top_logprobs,
-                            num_output_top_logprobs,
-                            tokenizer,
-                        ),
+                            step_top_logprobs, num_output_top_logprobs,
+                            tokenizer, should_return_as_token_id),
                     ))

         return ChatCompletionLogProbs(content=logprobs_content)
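
The `should_return_as_token_id` resolution above gives the request-level value precedence and falls back to the server-wide flag only when the request leaves it unset (`None`). A small standalone illustration of that rule (not vLLM code):

```python
# Standalone illustration of the precedence rule used above (not vLLM code):
# an explicit per-request value wins; None defers to the server-wide flag.
from typing import Optional


def resolve(request_value: Optional[bool], server_flag: bool) -> bool:
    return request_value if request_value is not None else server_flag


assert resolve(None, True) is True    # only the server flag is set
assert resolve(True, False) is True   # the request opts in on its own
assert resolve(False, True) is False  # the request can also opt back out
```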

View File

@@ -316,6 +316,8 @@ class OpenAIServingCompletion(OpenAIServing):
                             num_output_top_logprobs=request.logprobs,
                             tokenizer=tokenizer,
                             initial_text_offset=previous_text_lens[i],
+                            return_as_token_id=request.
+                            return_tokens_as_token_ids,
                         )
                     else:
                         logprobs = None
@@ -436,6 +438,7 @@
                     top_logprobs=out_logprobs,
                     tokenizer=tokenizer,
                     num_output_top_logprobs=request.logprobs,
+                    return_as_token_id=request.return_tokens_as_token_ids,
                 )
             else:
                 logprobs = None
@@ -477,6 +480,7 @@
         num_output_top_logprobs: int,
         tokenizer: AnyTokenizer,
         initial_text_offset: int = 0,
+        return_as_token_id: Optional[bool] = None,
     ) -> CompletionLogProbs:
         """Create logprobs for OpenAI Completion API."""
         out_text_offset: list[int] = []
@@ -486,11 +490,13 @@
         last_token_len = 0

+        should_return_as_token_id = return_as_token_id if \
+            return_as_token_id is not None else self.return_tokens_as_token_ids
         for i, token_id in enumerate(token_ids):
             step_top_logprobs = top_logprobs[i]
             if step_top_logprobs is None:
                 token = tokenizer.decode(token_id)
-                if self.return_tokens_as_token_ids:
+                if should_return_as_token_id:
                     token = f"token_id:{token_id}"

                 out_tokens.append(token)
@@ -503,7 +509,7 @@
                     step_token,
                     token_id,
                     tokenizer,
-                    return_as_token_id=self.return_tokens_as_token_ids,
+                    return_as_token_id=should_return_as_token_id,
                 )
                 token_logprob = max(step_token.logprob, -9999.0)
@@ -520,7 +526,7 @@
                     self._get_decoded_token(top_lp[1],
                                             top_lp[0],
                                             tokenizer,
-                                            return_as_token_id=self.return_tokens_as_token_ids):
+                                            return_as_token_id=should_return_as_token_id):
                     max(top_lp[1].logprob, -9999.0)
                     for i, top_lp in enumerate(step_top_logprobs.items())
                     if num_output_top_logprobs >= i
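
For completeness, a hedged sketch of consuming the resulting `token_id:{token_id}` strings on the client side, along the lines of what the updated tests verify; the use of a Hugging Face tokenizer and its name are placeholder assumptions:

```python
# Hedged sketch: turning the "token_id:<int>" strings emitted above back into
# text. The tokenizer choice is a placeholder assumption.
from transformers import AutoTokenizer


def decode_token_id_strings(token_strs: list[str], tokenizer_name: str) -> str:
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    token_ids = []
    for token_str in token_strs:
        # Every entry carries the "token_id:" prefix when the option is active.
        assert token_str.startswith("token_id:")
        token_ids.append(int(token_str.removeprefix("token_id:")))
    # Recover the original text from the numeric ids.
    return tokenizer.decode(token_ids)
```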