[Frontend] Represent tokens with identifiable strings (#6626)

This commit is contained in:
Evan Z. Liu 2024-07-24 18:51:00 -07:00 committed by GitHub
parent 740374d456
commit 5689e256ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 138 additions and 19 deletions

View File

@ -55,8 +55,9 @@ def zephyr_pa_files():
@pytest.fixture(scope="module")
def server(zephyr_lora_files, zephyr_lora_added_tokens_files, zephyr_pa_files):
args = [
def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
zephyr_pa_files):
return [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
@ -85,7 +86,10 @@ def server(zephyr_lora_files, zephyr_lora_added_tokens_files, zephyr_pa_files):
"128",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@pytest.fixture(scope="module")
def server(default_server_args):
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
yield remote_server

View File

@ -0,0 +1,83 @@
# Separate these tests out from test_completion and test_chat, because they
# require launching a second server with a different flag. Running both servers
# at the same time on a single node will OOM.
import pytest
from vllm.transformers_utils.tokenizer import get_tokenizer
from ...utils import RemoteOpenAIServer
from .test_completion import default_server_args # noqa: F401
from .test_completion import zephyr_lora_added_tokens_files # noqa: F401
from .test_completion import zephyr_lora_files # noqa: F401
from .test_completion import zephyr_pa_files # noqa: F401
from .test_completion import MODEL_NAME
@pytest.fixture(scope="module")
def server_with_return_tokens_as_token_ids_flag(
default_server_args): # noqa: F811
args_with_flag = default_server_args + ["--return-tokens-as-token-ids"]
with RemoteOpenAIServer(MODEL_NAME, args_with_flag) as remote_server:
yield remote_server
@pytest.mark.asyncio
async def test_completion_return_tokens_as_token_ids_completion(
server_with_return_tokens_as_token_ids_flag):
client = server_with_return_tokens_as_token_ids_flag.get_async_client()
completion = await client.completions.create(
model=MODEL_NAME,
# Include Unicode characters to test for dividing a single
# character across multiple tokens: 🎉 is [28705, 31862] for the
# Zephyr tokenizer
prompt="Say 'Hello, world! 🎉'",
echo=True,
temperature=0,
max_tokens=10,
logprobs=1)
text = completion.choices[0].text
token_strs = completion.choices[0].logprobs.tokens
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
# Check that the token representations are consistent between raw tokens
# and top_logprobs
# Slice off the first one, because there's no scoring associated with BOS
top_logprobs = completion.choices[0].logprobs.top_logprobs[1:]
top_logprob_keys = [
next(iter(logprob_by_tokens)) for logprob_by_tokens in top_logprobs
]
assert token_strs[1:] == top_logprob_keys
# Check that decoding the tokens gives the expected text
tokens = [int(token.removeprefix("token_id:")) for token in token_strs]
assert text == tokenizer.decode(tokens, skip_special_tokens=True)
@pytest.mark.asyncio
async def test_chat_return_tokens_as_token_ids_completion(
server_with_return_tokens_as_token_ids_flag):
client = server_with_return_tokens_as_token_ids_flag.get_async_client()
response = await client.chat.completions.create(
model=MODEL_NAME,
# Include Unicode characters to test for dividing a single
# character across multiple tokens: 🎉 is [28705, 31862] for the
# Zephyr tokenizer
messages=[{
"role": "system",
"content": "You like to respond in only emojis, like 🎉"
}, {
"role": "user",
"content": "Please write some emojis: 🐱🐶🎉"
}],
temperature=0,
max_tokens=8,
logprobs=True)
text = response.choices[0].message.content
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
token_ids = []
for logprob_content in response.choices[0].logprobs.content:
token_ids.append(int(logprob_content.token.removeprefix("token_id:")))
assert tokenizer.decode(token_ids, skip_special_tokens=True) == text

View File

@ -254,6 +254,7 @@ async def build_server(
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
chat_template=args.chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
openai_serving_completion = OpenAIServingCompletion(
engine,
@ -262,6 +263,7 @@ async def build_server(
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
openai_serving_embedding = OpenAIServingEmbedding(
engine,

View File

@ -128,6 +128,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"using @app.middleware('http'). "
"If a class is provided, vLLM will add it to the server "
"using app.add_middleware(). ")
parser.add_argument(
"--return-tokens-as-token-ids",
action="store_true",
help="When --max-logprobs is specified, represents single tokens as"
"strings of the form 'token_id:{token_id}' so that tokens that"
"are not JSON-encodable can be identified.")
parser = AsyncEngineArgs.add_cli_args(parser)

View File

@ -50,13 +50,15 @@ class OpenAIServingChat(OpenAIServing):
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
return_tokens_as_token_ids: bool = False,
):
super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules,
prompt_adapters=prompt_adapters,
request_logger=request_logger)
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids)
self.response_role = response_role
@ -522,11 +524,14 @@ class OpenAIServingChat(OpenAIServing):
self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
return [
ChatCompletionLogProb(
token=(token := self._get_decoded_token(p[1], p[0],
tokenizer)),
logprob=max(p[1].logprob, -9999.0),
bytes=list(token.encode("utf-8", errors="replace")))
ChatCompletionLogProb(token=(token := self._get_decoded_token(
p[1],
p[0],
tokenizer,
return_as_token_id=self.return_tokens_as_token_ids)),
logprob=max(p[1].logprob, -9999.0),
bytes=list(
token.encode("utf-8", errors="replace")))
for i, p in enumerate(logprobs.items())
if top_logprobs and i < top_logprobs
]
@ -546,6 +551,8 @@ class OpenAIServingChat(OpenAIServing):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = tokenizer.decode(token_id)
if self.return_tokens_as_token_ids:
token = f"token_id:{token_id}"
logprobs_content.append(
ChatCompletionLogProbsContent(
token=token,
@ -553,7 +560,9 @@ class OpenAIServingChat(OpenAIServing):
else:
logprobs_content.append(
ChatCompletionLogProbsContent(
token=step_top_logprobs[token_id].decoded_token,
token=self._get_decoded_token(
step_top_logprobs[token_id], token_id, tokenizer,
self.return_tokens_as_token_ids),
logprob=max(step_top_logprobs[token_id].logprob,
-9999.0),
bytes=list(

View File

@ -51,13 +51,15 @@ class OpenAIServingCompletion(OpenAIServing):
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
):
super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules,
prompt_adapters=prompt_adapters,
request_logger=request_logger)
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids)
async def create_completion(self, request: CompletionRequest,
raw_request: Request):
@ -430,12 +432,17 @@ class OpenAIServingCompletion(OpenAIServing):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = tokenizer.decode(token_id)
if self.return_tokens_as_token_ids:
token = f"token_id:{token_id}"
out_tokens.append(token)
out_token_logprobs.append(None)
out_top_logprobs.append(None)
else:
token = self._get_decoded_token(step_top_logprobs[token_id],
token_id, tokenizer)
token = self._get_decoded_token(
step_top_logprobs[token_id],
token_id,
tokenizer,
return_as_token_id=self.return_tokens_as_token_ids)
token_logprob = max(step_top_logprobs[token_id].logprob,
-9999.0)
out_tokens.append(token)
@ -448,7 +455,11 @@ class OpenAIServingCompletion(OpenAIServing):
out_top_logprobs.append({
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
self._get_decoded_token(top_lp[1], top_lp[0], tokenizer):
self._get_decoded_token(
top_lp[1],
top_lp[0],
tokenizer,
return_as_token_id=self.return_tokens_as_token_ids):
max(top_lp[1].logprob, -9999.0)
for i, top_lp in enumerate(step_top_logprobs.items())
if num_output_top_logprobs >= i

View File

@ -68,6 +68,7 @@ class OpenAIServing:
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
):
super().__init__()
@ -102,6 +103,7 @@ class OpenAIServing:
prompt_adapter_num_virtual_tokens=num_virtual_tokens))
self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids
async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model."""
@ -384,11 +386,13 @@ class OpenAIServing:
)
@staticmethod
def _get_decoded_token(
logprob: Logprob,
token_id: int,
tokenizer: AnyTokenizer,
) -> str:
def _get_decoded_token(logprob: Logprob,
token_id: int,
tokenizer: AnyTokenizer,
return_as_token_id: bool = False) -> str:
if return_as_token_id:
return f"token_id:{token_id}"
if logprob.decoded_token is not None:
return logprob.decoded_token
return tokenizer.decode(token_id)