mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 06:35:00 +08:00)
[Bugfix] Support prompt_logprobs==0 (#5217)
This commit is contained in:
parent f775a07e30
commit 06b2550cbb
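logprobs=0 and prompt_logprobs=0 are valid requests: return only the logprob of the token that was actually chosen, with no additional top candidates. Several code paths, however, guarded logprob handling with truthiness checks such as "if request.logprobs" or "if sampling_params.prompt_logprobs", and since 0 is falsy in Python those paths behaved as if no logprobs had been requested at all. This commit replaces the truthiness checks with explicit "is not None" comparisons and tightens the tests to cover the 0 case.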
@@ -224,7 +224,7 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
     assert choice.logprobs is not None
     assert choice.logprobs.token_logprobs is not None
     assert choice.logprobs.top_logprobs is not None
-    assert len(choice.logprobs.top_logprobs[0]) <= 1
+    assert len(choice.logprobs.top_logprobs[0]) == 1


 @pytest.mark.asyncio
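For illustration, a minimal sketch of the request this test exercises, assuming a vLLM OpenAI-compatible server is already running locally; the base URL, API key, and model name below are placeholders, not part of this commit:

    import openai

    # Placeholder endpoint and model; point these at your own vLLM server.
    client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    completion = client.completions.create(
        model="HuggingFaceH4/zephyr-7b-beta",
        prompt="Hello, my name is",
        max_tokens=5,
        temperature=0.0,
        logprobs=0,  # valid after this fix: report only the sampled token
    )

    choice = completion.choices[0]
    # With logprobs=0, each position carries exactly one entry: the sampled token.
    assert len(choice.logprobs.top_logprobs[0]) == 1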
@@ -246,7 +246,7 @@ async def test_some_logprobs(server, client: openai.AsyncOpenAI,
     assert choice.logprobs is not None
     assert choice.logprobs.token_logprobs is not None
     assert choice.logprobs.top_logprobs is not None
-    assert len(choice.logprobs.top_logprobs[0]) <= 6
+    assert 5 <= len(choice.logprobs.top_logprobs[0]) <= 6


 @pytest.mark.asyncio
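The tightened 5 <= ... <= 6 bound reflects how top logprobs are reported: with logprobs=5 the server returns the five most likely candidates plus the token that was actually sampled, which may or may not already be among those five, so each position carries five or six entries.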
@@ -1217,8 +1217,9 @@ number: "1" | "2"
     "model_name",
     [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
 )
+@pytest.mark.parametrize("logprobs_arg", [1, 0])
 async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
-                                       model_name: str):
+                                       model_name: str, logprobs_arg: int):
     tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
     # test using text and token IDs
     for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
@@ -1227,7 +1228,7 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
            max_tokens=5,
            temperature=0.0,
            echo=True,
-            logprobs=1)
+            logprobs=logprobs_arg)

        prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
                                                             list) else prompt
@@ -1240,6 +1241,9 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
                 and logprobs.token_logprobs[0] is None)
         assert (len(logprobs.top_logprobs) > 5
                 and logprobs.top_logprobs[0] is None)
+        for top_logprobs in logprobs.top_logprobs[1:]:
+            assert max(logprobs_arg,
+                       1) <= len(top_logprobs) <= logprobs_arg + 1
         assert len(logprobs.tokens) > 5


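The echoed-prompt assertions follow the same rule: the first prompt position has no logprob (hence the leading None entries), and every later position carries between max(logprobs_arg, 1) and logprobs_arg + 1 top-logprob entries, because the sampled or prompt token itself is always included alongside the requested candidates. In particular, logprobs_arg=0 still yields exactly one entry per position.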
@@ -312,7 +312,7 @@ class OpenAIServingCompletion(OpenAIServing):
             elif request.echo and request.max_tokens > 0:
                 token_ids = prompt_token_ids + output.token_ids
                 top_logprobs = (prompt_logprobs + output.logprobs
-                                if request.logprobs else None)
+                                if request.logprobs is not None else None)
                 output_text = prompt_text + output.text
             else:
                 token_ids = output.token_ids
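The server-side bug is a plain truthiness pitfall; a minimal, self-contained illustration of the old versus new check (not vLLM code, just the Python behaviour involved):

    # Old check: treats logprobs=0 the same as logprobs=None.
    def old_top_logprobs(request_logprobs, prompt_logprobs, output_logprobs):
        return (prompt_logprobs + output_logprobs) if request_logprobs else None

    # New check: 0 is an explicit request; only None means "not requested".
    def new_top_logprobs(request_logprobs, prompt_logprobs, output_logprobs):
        return ((prompt_logprobs + output_logprobs)
                if request_logprobs is not None else None)

    prompt_lp = [None, {"name": -0.1}]
    output_lp = [{"Anna": -0.2}]
    assert old_top_logprobs(0, prompt_lp, output_lp) is None  # silently dropped
    assert new_top_logprobs(0, prompt_lp, output_lp) == prompt_lp + output_lp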
@@ -233,7 +233,7 @@ def _prepare_seq_groups(
         logits = hidden_states[selected_token_indices]
     """

-        if sampling_params.prompt_logprobs:
+        if sampling_params.prompt_logprobs is not None:
             selected_token_indices.extend(
                 range(model_output_idx, model_output_idx + prompt_logprob_len))
             model_output_idx += prompt_logprob_len
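On the engine side this is what the fix enables end to end; a rough sketch using vLLM's offline API, with facebook/opt-125m as an example model (any supported model works):

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")
    # prompt_logprobs=0 now means "score the prompt tokens themselves, with no
    # extra top candidates"; previously the falsy 0 skipped the
    # selected_token_indices bookkeeping shown in the hunk above.
    params = SamplingParams(max_tokens=5, prompt_logprobs=0)

    output = llm.generate(["Hello, my name is"], params)[0]
    # The first entry is None: there is no logprob for the first prompt token.
    print(output.prompt_logprobs)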
@@ -427,7 +427,7 @@ class ModelRunner:
                     [lora_id] *
                     (query_len if seq_group_metadata.sampling_params
                      and seq_group_metadata.sampling_params.prompt_logprobs
-                     else 1))
+                     is not None else 1))

                 mm_data = seq_group_metadata.multi_modal_data
                 if mm_data is not None:
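The last hunk applies the same rule in the LoRA path: when prompt logprobs are requested, the LoRA index mapping needs one entry per prompt token (query_len entries) rather than a single entry, and with the old truthiness check a request with prompt_logprobs=0 incorrectly fell into the single-entry branch.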