[Frontend] Represent tokens with identifiable strings (#6626)
parent 740374d456
commit 5689e256ba
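Before the diff itself, a brief orientation with a minimal sketch (standalone illustration, not code from this commit): with the new --return-tokens-as-token-ids server flag, logprobs output represents each token as the string "token_id:{token_id}" rather than its decoded text. The motivation is that a single character can be split across several tokens (🎉 is [28705, 31862] for the Zephyr tokenizer, per the new test below), and such fragments do not decode to identifiable text on their own.

# Standalone motivation sketch (the byte split below is illustrative, not the
# exact Zephyr token boundaries): a fragment of a multi-byte character cannot
# be decoded into meaningful text, while "token_id:<id>" is always a plain,
# JSON-safe, identifiable string.
emoji_bytes = "🎉".encode("utf-8")          # b'\xf0\x9f\x8e\x89'
fragment = emoji_bytes[:1]                  # a piece of the character
print(fragment.decode("utf-8", errors="replace"))   # '�' -- ambiguous
print(f"token_id:{28705}", f"token_id:{31862}")     # unambiguous per-token IDs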
@@ -55,8 +55,9 @@ def zephyr_pa_files():


 @pytest.fixture(scope="module")
-def server(zephyr_lora_files, zephyr_lora_added_tokens_files, zephyr_pa_files):
-    args = [
+def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
+                        zephyr_pa_files):
+    return [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
         "bfloat16",
@@ -85,7 +86,10 @@ def server(zephyr_lora_files, zephyr_lora_added_tokens_files, zephyr_pa_files):
         "128",
     ]

-    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+
+@pytest.fixture(scope="module")
+def server(default_server_args):
+    with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
         yield remote_server

tests/entrypoints/openai/test_return_tokens_as_ids.py (new file, 83 lines)
@@ -0,0 +1,83 @@
+# Separate these tests out from test_completion and test_chat, because they
+# require launching a second server with a different flag. Running both servers
+# at the same time on a single node will OOM.
+
+import pytest
+
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+from ...utils import RemoteOpenAIServer
+from .test_completion import default_server_args  # noqa: F401
+from .test_completion import zephyr_lora_added_tokens_files  # noqa: F401
+from .test_completion import zephyr_lora_files  # noqa: F401
+from .test_completion import zephyr_pa_files  # noqa: F401
+from .test_completion import MODEL_NAME
+
+
+@pytest.fixture(scope="module")
+def server_with_return_tokens_as_token_ids_flag(
+        default_server_args):  # noqa: F811
+    args_with_flag = default_server_args + ["--return-tokens-as-token-ids"]
+    with RemoteOpenAIServer(MODEL_NAME, args_with_flag) as remote_server:
+        yield remote_server
+
+
+@pytest.mark.asyncio
+async def test_completion_return_tokens_as_token_ids_completion(
+        server_with_return_tokens_as_token_ids_flag):
+    client = server_with_return_tokens_as_token_ids_flag.get_async_client()
+
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        # Include Unicode characters to test for dividing a single
+        # character across multiple tokens: 🎉 is [28705, 31862] for the
+        # Zephyr tokenizer
+        prompt="Say 'Hello, world! 🎉'",
+        echo=True,
+        temperature=0,
+        max_tokens=10,
+        logprobs=1)
+
+    text = completion.choices[0].text
+    token_strs = completion.choices[0].logprobs.tokens
+    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+    # Check that the token representations are consistent between raw tokens
+    # and top_logprobs
+    # Slice off the first one, because there's no scoring associated with BOS
+    top_logprobs = completion.choices[0].logprobs.top_logprobs[1:]
+    top_logprob_keys = [
+        next(iter(logprob_by_tokens)) for logprob_by_tokens in top_logprobs
+    ]
+    assert token_strs[1:] == top_logprob_keys
+
+    # Check that decoding the tokens gives the expected text
+    tokens = [int(token.removeprefix("token_id:")) for token in token_strs]
+    assert text == tokenizer.decode(tokens, skip_special_tokens=True)
+
+
+@pytest.mark.asyncio
+async def test_chat_return_tokens_as_token_ids_completion(
+        server_with_return_tokens_as_token_ids_flag):
+    client = server_with_return_tokens_as_token_ids_flag.get_async_client()
+    response = await client.chat.completions.create(
+        model=MODEL_NAME,
+        # Include Unicode characters to test for dividing a single
+        # character across multiple tokens: 🎉 is [28705, 31862] for the
+        # Zephyr tokenizer
+        messages=[{
+            "role": "system",
+            "content": "You like to respond in only emojis, like 🎉"
+        }, {
+            "role": "user",
+            "content": "Please write some emojis: 🐱🐶🎉"
+        }],
+        temperature=0,
+        max_tokens=8,
+        logprobs=True)
+
+    text = response.choices[0].message.content
+    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+    token_ids = []
+    for logprob_content in response.choices[0].logprobs.content:
+        token_ids.append(int(logprob_content.token.removeprefix("token_id:")))
+    assert tokenizer.decode(token_ids, skip_special_tokens=True) == text
@@ -254,6 +254,7 @@ async def build_server(
         prompt_adapters=args.prompt_adapters,
         request_logger=request_logger,
         chat_template=args.chat_template,
+        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
     )
     openai_serving_completion = OpenAIServingCompletion(
         engine,
@@ -262,6 +263,7 @@ async def build_server(
         lora_modules=args.lora_modules,
         prompt_adapters=args.prompt_adapters,
         request_logger=request_logger,
+        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
     )
     openai_serving_embedding = OpenAIServingEmbedding(
         engine,
@@ -128,6 +128,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         "using @app.middleware('http'). "
         "If a class is provided, vLLM will add it to the server "
         "using app.add_middleware(). ")
+    parser.add_argument(
+        "--return-tokens-as-token-ids",
+        action="store_true",
+        help="When --max-logprobs is specified, represents single tokens as "
+        "strings of the form 'token_id:{token_id}' so that tokens that "
+        "are not JSON-encodable can be identified.")

     parser = AsyncEngineArgs.add_cli_args(parser)

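A hedged client-side sketch of what the flag produces (the token IDs and logprob values below are invented for illustration, not taken from the diff): responses keep the OpenAI logprobs shape, but every token key carries the fixed prefix, which a client strips to recover integer IDs, just as the new tests do.

# Illustrative logprobs entries as returned with --return-tokens-as-token-ids
# enabled; values are made up for the example.
example_top_logprobs = [
    {"token_id:22557": -0.01},   # would otherwise be {"Hello": -0.01}
    {"token_id:28725": -0.20},   # would otherwise be {",": -0.20}
]
# Strip the fixed "token_id:" prefix to recover the integer token IDs.
token_ids = [
    int(next(iter(entry)).removeprefix("token_id:"))
    for entry in example_top_logprobs
]
assert token_ids == [22557, 28725]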
@@ -50,13 +50,15 @@ class OpenAIServingChat(OpenAIServing):
         prompt_adapters: Optional[List[PromptAdapterPath]],
         request_logger: Optional[RequestLogger],
         chat_template: Optional[str],
+        return_tokens_as_token_ids: bool = False,
     ):
         super().__init__(engine=engine,
                          model_config=model_config,
                          served_model_names=served_model_names,
                          lora_modules=lora_modules,
                          prompt_adapters=prompt_adapters,
-                         request_logger=request_logger)
+                         request_logger=request_logger,
+                         return_tokens_as_token_ids=return_tokens_as_token_ids)

         self.response_role = response_role

@@ -522,11 +524,14 @@ class OpenAIServingChat(OpenAIServing):
             self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
             tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
         return [
-            ChatCompletionLogProb(
-                token=(token := self._get_decoded_token(p[1], p[0],
-                                                        tokenizer)),
-                logprob=max(p[1].logprob, -9999.0),
-                bytes=list(token.encode("utf-8", errors="replace")))
+            ChatCompletionLogProb(token=(token := self._get_decoded_token(
+                p[1],
+                p[0],
+                tokenizer,
+                return_as_token_id=self.return_tokens_as_token_ids)),
+                                  logprob=max(p[1].logprob, -9999.0),
+                                  bytes=list(
+                                      token.encode("utf-8", errors="replace")))
             for i, p in enumerate(logprobs.items())
             if top_logprobs and i < top_logprobs
         ]
@@ -546,6 +551,8 @@ class OpenAIServingChat(OpenAIServing):
             step_top_logprobs = top_logprobs[i]
             if step_top_logprobs is None:
                 token = tokenizer.decode(token_id)
+                if self.return_tokens_as_token_ids:
+                    token = f"token_id:{token_id}"
                 logprobs_content.append(
                     ChatCompletionLogProbsContent(
                         token=token,
@@ -553,7 +560,9 @@ class OpenAIServingChat(OpenAIServing):
             else:
                 logprobs_content.append(
                     ChatCompletionLogProbsContent(
-                        token=step_top_logprobs[token_id].decoded_token,
+                        token=self._get_decoded_token(
+                            step_top_logprobs[token_id], token_id, tokenizer,
+                            self.return_tokens_as_token_ids),
                         logprob=max(step_top_logprobs[token_id].logprob,
                                     -9999.0),
                         bytes=list(
@@ -51,13 +51,15 @@ class OpenAIServingCompletion(OpenAIServing):
         lora_modules: Optional[List[LoRAModulePath]],
         prompt_adapters: Optional[List[PromptAdapterPath]],
         request_logger: Optional[RequestLogger],
+        return_tokens_as_token_ids: bool = False,
     ):
         super().__init__(engine=engine,
                          model_config=model_config,
                          served_model_names=served_model_names,
                          lora_modules=lora_modules,
                          prompt_adapters=prompt_adapters,
-                         request_logger=request_logger)
+                         request_logger=request_logger,
+                         return_tokens_as_token_ids=return_tokens_as_token_ids)

     async def create_completion(self, request: CompletionRequest,
                                 raw_request: Request):
@@ -430,12 +432,17 @@ class OpenAIServingCompletion(OpenAIServing):
             step_top_logprobs = top_logprobs[i]
             if step_top_logprobs is None:
                 token = tokenizer.decode(token_id)
+                if self.return_tokens_as_token_ids:
+                    token = f"token_id:{token_id}"
                 out_tokens.append(token)
                 out_token_logprobs.append(None)
                 out_top_logprobs.append(None)
             else:
-                token = self._get_decoded_token(step_top_logprobs[token_id],
-                                                token_id, tokenizer)
+                token = self._get_decoded_token(
+                    step_top_logprobs[token_id],
+                    token_id,
+                    tokenizer,
+                    return_as_token_id=self.return_tokens_as_token_ids)
                 token_logprob = max(step_top_logprobs[token_id].logprob,
                                     -9999.0)
                 out_tokens.append(token)
@@ -448,7 +455,11 @@ class OpenAIServingCompletion(OpenAIServing):
                 out_top_logprobs.append({
                     # Convert float("-inf") to the
                     # JSON-serializable float that OpenAI uses
-                    self._get_decoded_token(top_lp[1], top_lp[0], tokenizer):
+                    self._get_decoded_token(
+                        top_lp[1],
+                        top_lp[0],
+                        tokenizer,
+                        return_as_token_id=self.return_tokens_as_token_ids):
                     max(top_lp[1].logprob, -9999.0)
                     for i, top_lp in enumerate(step_top_logprobs.items())
                     if num_output_top_logprobs >= i
@@ -68,6 +68,7 @@ class OpenAIServing:
         lora_modules: Optional[List[LoRAModulePath]],
         prompt_adapters: Optional[List[PromptAdapterPath]],
         request_logger: Optional[RequestLogger],
+        return_tokens_as_token_ids: bool = False,
     ):
         super().__init__()

@@ -102,6 +103,7 @@ class OpenAIServing:
                     prompt_adapter_num_virtual_tokens=num_virtual_tokens))

         self.request_logger = request_logger
+        self.return_tokens_as_token_ids = return_tokens_as_token_ids

     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
@@ -384,11 +386,13 @@ class OpenAIServing:
         )

     @staticmethod
-    def _get_decoded_token(
-        logprob: Logprob,
-        token_id: int,
-        tokenizer: AnyTokenizer,
-    ) -> str:
+    def _get_decoded_token(logprob: Logprob,
+                           token_id: int,
+                           tokenizer: AnyTokenizer,
+                           return_as_token_id: bool = False) -> str:
+        if return_as_token_id:
+            return f"token_id:{token_id}"
+
         if logprob.decoded_token is not None:
             return logprob.decoded_token
         return tokenizer.decode(token_id)
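To make the precedence of the final helper explicit, here is a self-contained sketch (FakeLogprob and FakeTokenizer are stand-ins invented for this example, not vLLM classes): the flag takes priority, then a cached decoded_token, and only then a tokenizer.decode fallback.

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeLogprob:          # stand-in for vLLM's Logprob (illustrative)
    logprob: float
    decoded_token: Optional[str] = None


class FakeTokenizer:        # stand-in tokenizer (illustrative)
    def decode(self, token_id: int) -> str:
        return f"<tok {token_id}>"


def get_decoded_token(logprob, token_id, tokenizer,
                      return_as_token_id: bool = False) -> str:
    # Same branching order as the updated _get_decoded_token above.
    if return_as_token_id:
        return f"token_id:{token_id}"
    if logprob.decoded_token is not None:
        return logprob.decoded_token
    return tokenizer.decode(token_id)


tok = FakeTokenizer()
assert get_decoded_token(FakeLogprob(-0.1, "Hello"), 22557, tok) == "Hello"
assert get_decoded_token(FakeLogprob(-0.1), 22557, tok) == "<tok 22557>"
assert get_decoded_token(FakeLogprob(-0.1, "Hello"), 22557, tok,
                         return_as_token_id=True) == "token_id:22557"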