mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-07 01:52:20 +08:00
[Bugfix] fixed top_logprobs: -1 does not appear to work as intended (#26470)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
parent
cddce79fda
commit
910abdbd08
@ -7,12 +7,23 @@ import openai # use the official client for correctness check
|
|||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
|
||||||
from ...utils import RemoteOpenAIServer
|
from ...utils import RemoteOpenAIServer
|
||||||
|
|
||||||
# # any model with a chat template should work here
|
# # any model with a chat template should work here
|
||||||
MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
|
MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
|
||||||
|
|
||||||
|
|
||||||
|
def get_vocab_size(model_name):
    """Return the vocabulary size that ModelConfig reports for ``model_name``.

    A ModelConfig is built for the given model with a fixed seed of 0 and
    the "float16" dtype, and its ``get_vocab_size()`` result is returned
    unchanged.
    """
    return ModelConfig(
        model=model_name,
        seed=0,
        dtype="float16",
    ).get_vocab_size()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def server():
|
def server():
|
||||||
args = [
|
args = [
|
||||||
@ -107,6 +118,7 @@ async def test_top_logprobs(client: openai.AsyncOpenAI):
|
|||||||
completion = await client.chat.completions.create(
|
completion = await client.chat.completions.create(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
max_tokens=1,
|
||||||
extra_body={
|
extra_body={
|
||||||
"top_logprobs": -1,
|
"top_logprobs": -1,
|
||||||
"logprobs": "true",
|
"logprobs": "true",
|
||||||
@ -115,3 +127,6 @@ async def test_top_logprobs(client: openai.AsyncOpenAI):
|
|||||||
assert completion.choices[0].logprobs is not None
|
assert completion.choices[0].logprobs is not None
|
||||||
assert completion.choices[0].logprobs.content is not None
|
assert completion.choices[0].logprobs.content is not None
|
||||||
assert len(completion.choices[0].logprobs.content) > 0
|
assert len(completion.choices[0].logprobs.content) > 0
|
||||||
|
assert len(
|
||||||
|
completion.choices[0].logprobs.content[0].top_logprobs
|
||||||
|
) == get_vocab_size(MODEL_NAME)
|
||||||
|
|||||||
@ -1643,7 +1643,7 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
bytes=list(token.encode("utf-8", errors="replace")),
|
bytes=list(token.encode("utf-8", errors="replace")),
|
||||||
)
|
)
|
||||||
for i, p in enumerate(logprobs.items())
|
for i, p in enumerate(logprobs.items())
|
||||||
if top_logprobs and i < top_logprobs
|
if (top_logprobs and i < top_logprobs or top_logprobs == -1)
|
||||||
]
|
]
|
||||||
|
|
||||||
def _create_chat_logprobs(
|
def _create_chat_logprobs(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user