mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-22 18:45:01 +08:00)
[Bugfix] Fix logits processor when prompt_logprobs is not None (#3899)
parent: c2e00af523
commit: b3104b2a10
tests/samplers/test_logits_processor.py | 62 (new file)
@@ -0,0 +1,62 @@
+import pytest
+import torch
+
+from vllm import SamplingParams
+
+MODELS = ["facebook/opt-125m"]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+def test_logits_processor_force_generate(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+) -> None:
+    vllm_model = vllm_runner(model, dtype=dtype)
+    tokenizer = vllm_model.model.get_tokenizer()
+    repeat_times = 2
+    enforced_answers = " vLLM"
+    vllm_token_ids = tokenizer.encode(enforced_answers,
+                                      add_special_tokens=False)
+    max_tokens = len(vllm_token_ids) * repeat_times
+
+    def pick_vllm(token_ids, logits):
+        token_id = vllm_token_ids[len(token_ids) % len(vllm_token_ids)]
+        logits[token_id] = torch.finfo(logits.dtype).max
+        return logits
+
+    params_with_logprobs = SamplingParams(
+        logits_processors=[pick_vllm],
+        prompt_logprobs=3,
+        max_tokens=max_tokens,
+    )
+
+    # test logits_processors when prompt_logprobs is not None
+    vllm_model.model._add_request(
+        prompt=example_prompts[0],
+        sampling_params=params_with_logprobs,
+        prompt_token_ids=None,
+    )
+
+    # test prompt_logprobs is not None
+    vllm_model.model._add_request(
+        prompt=example_prompts[1],
+        sampling_params=SamplingParams(
+            prompt_logprobs=3,
+            max_tokens=max_tokens,
+        ),
+        prompt_token_ids=None,
+    )
+
+    # test grouped requests
+    vllm_model.model._add_request(
+        prompt=example_prompts[2],
+        sampling_params=SamplingParams(max_tokens=max_tokens),
+        prompt_token_ids=None,
+    )
+
+    outputs = vllm_model.model._run_engine(False)
+
+    assert outputs[0].outputs[0].text == enforced_answers * repeat_times
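For reference, the code path this test drives through the internal _add_request and _run_engine helpers can also be reached through the public LLM API. The sketch below is not part of this commit; it assumes the facebook/opt-125m checkpoint can be loaded locally and reuses the same pick_vllm-style processor and SamplingParams fields as the test above.

# Minimal sketch (not from this commit): exercise logits_processors together
# with prompt_logprobs through the public vllm API.
import torch

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", dtype="half")
tokenizer = llm.get_tokenizer()
forced = tokenizer.encode(" vLLM", add_special_tokens=False)

def pick_vllm(token_ids, logits):
    # Force generation to cycle through the " vLLM" token ids, as in the test.
    logits[forced[len(token_ids) % len(forced)]] = torch.finfo(logits.dtype).max
    return logits

params = SamplingParams(
    logits_processors=[pick_vllm],
    prompt_logprobs=3,  # the combination this commit fixes
    max_tokens=len(forced) * 2,
)
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)  # should decode to " vLLM vLLM"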
@@ -86,8 +86,16 @@ def _apply_logits_processors(
 ) -> torch.Tensor:
     logits_row_idx = 0
     found_logits_processors = False
-    for seq_ids, sampling_params in sampling_metadata.seq_groups:
+    for i, seq_group in enumerate(sampling_metadata.seq_groups):
+        seq_ids, sampling_params = seq_group
         logits_processors = sampling_params.logits_processors
+        # handle prompt_logprobs by skipping rows in logits added for
+        # the prompt tokens (prompt logprobs are not processed)
+        if (i < sampling_metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            assert len(seq_ids) == 1
+            logits_row_idx += sampling_metadata.prompt_lens[i] - 1
+
         if logits_processors:
             found_logits_processors = True
             for seq_id in seq_ids:
@@ -100,5 +108,6 @@ def _apply_logits_processors(
         else:
             logits_row_idx += len(seq_ids)
     if found_logits_processors:
+        # verifies that no rows in logits were missed unexpectedly
         assert logits_row_idx == logits.shape[0]
     return logits
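To make the reasoning behind the skipped rows explicit: when a sequence group is still in its prompt run and has prompt_logprobs set, the logits tensor passed to _apply_logits_processors carries one row per prompt position, i.e. prompt_len - 1 extra rows ahead of the row that is actually sampled from, and a logits processor must only touch that final row. The following is a minimal standalone sketch of that indexing rule, using simplified stand-in structures rather than vllm's actual SamplingMetadata.

# Standalone sketch (not vllm code): which logits row each sequence's
# processors should operate on, given the prompt_logprobs row layout.
def processor_rows(seq_groups, prompt_lens, num_prompts):
    """Yield (seq_id, logits_row_idx) pairs.

    seq_groups: list of (seq_ids, sampling_params)-like pairs, prompt
        groups first, mirroring sampling_metadata.seq_groups.
    prompt_lens: prompt length for each of the first num_prompts groups.
    """
    row = 0
    for i, (seq_ids, params) in enumerate(seq_groups):
        if i < num_prompts and params.prompt_logprobs is not None:
            # prompt_logprobs adds prompt_lens[i] - 1 rows for the prompt
            # positions; they are skipped, not processed.
            row += prompt_lens[i] - 1
        for seq_id in seq_ids:
            yield seq_id, row
            row += 1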