diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py
index 1626b72282072..26e2d29ffd04c 100644
--- a/tests/samplers/test_sampler.py
+++ b/tests/samplers/test_sampler.py
@@ -1,3 +1,4 @@
+import itertools
 import random
 from typing import List, Optional, Tuple
 from unittest.mock import patch
@@ -194,11 +195,15 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
 
     def create_sampling_params(min_tokens,
                                eos_token_id=0,
-                               stop_token_ids=None):
+                               *,
+                               stop_token_ids: Optional[List[int]] = None,
+                               prompt_logprobs: Optional[int] = None):
         sampling_params = SamplingParams(
             min_tokens=min_tokens,
             max_tokens=9999,  # keep higher than max of min_tokens
             stop_token_ids=stop_token_ids,
+            # requesting prompt_logprobs changes the structure of `logits`
+            prompt_logprobs=prompt_logprobs,
         )
         sampling_params.eos_token_id = eos_token_id
         return sampling_params
@@ -217,9 +222,9 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
 
         expected_penalization = []
         sequence_metadata_list = []
+        # 20% chance to generate seq group metadata list with all prompts
+        is_prompt = random.random() < 0.2
         while batch_size > 0:
-            # 20% chance to generate prompt seq group with single sequence
-            is_prompt = random.random() < 0.2
             num_seqs = 1 if is_prompt else random.randint(1, batch_size)
 
             eos_token_id = random.randint(0, VOCAB_SIZE - 1)
@@ -240,7 +245,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
             seq_group_penalization = []
             for _ in range(num_seqs):
                 num_input = random.randint(1, 100)
-                num_generated = random.randint(1, 100) if not is_prompt else 0
+                num_generated = 0 if is_prompt else random.randint(1, 100)
                 seq_data[next(seq_id_counter)] = create_sequence_data(
                     num_input=num_input, num_generated=num_generated)
                 seq_group_penalization.append(num_generated < min_tokens)
@@ -292,6 +297,21 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
         ]
     }
 
+    prompt_with_penalization_and_prompt_logprobs = {
+        "expected_penalization": [False, False, True],
+        "seq_group_metadata_list": [
+            SequenceGroupMetadata(
+                request_id="test_1",
+                is_prompt=True,
+                seq_data={
+                    next(seq_id_counter): create_sequence_data(num_input=3),
+                },
+                sampling_params=create_sampling_params(1, prompt_logprobs=3),
+                block_tables={},
+            ),
+        ]
+    }
+
     stop_penalizing_after_min_tokens = {
         "expected_penalization": [False],
         "seq_group_metadata_list": [
@@ -309,8 +329,34 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
     }
 
     stop_token_ids = [42, 99, 42, 0]  # intentional duplication
-    simple_combination = {
-        "expected_penalization": [True, False, False],
+    prompt_combination = {
+        "expected_penalization": [False, True, False],
+        "seq_group_metadata_list": [
+            SequenceGroupMetadata(
+                request_id="test_2",
+                is_prompt=True,
+                seq_data={
+                    next(seq_id_counter): create_sequence_data(num_input=2),
+                },
+                sampling_params=create_sampling_params(1, prompt_logprobs=3),
+                block_tables={},
+            ),
+            SequenceGroupMetadata(
+                request_id="test_3",
+                is_prompt=True,
+                seq_data={
+                    next(seq_id_counter): create_sequence_data(),
+                },
+                sampling_params=create_sampling_params(
+                    0, stop_token_ids=stop_token_ids),
+                block_tables={},
+            )
+        ]
+    }
+
+    stop_token_ids = [1, 999, 37, 37]  # intentional duplication
+    decode_combination = {
+        "expected_penalization": [True, False, False, True, False],
         "seq_group_metadata_list": [
             SequenceGroupMetadata(
                 request_id="test_1",
@@ -327,14 +373,19 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
             ),
             SequenceGroupMetadata(
                 request_id="test_2",
-                is_prompt=True,
+                is_prompt=False,
                 seq_data={
-                    next(seq_id_counter): create_sequence_data(),
+                    next(seq_id_counter):
+                    create_sequence_data(num_generated=20),
+                    next(seq_id_counter):
+                    create_sequence_data(num_generated=1),
+                    next(seq_id_counter):
+                    create_sequence_data(num_generated=10),
                 },
                 sampling_params=create_sampling_params(
-                    0, stop_token_ids=stop_token_ids),
+                    10, prompt_logprobs=5, stop_token_ids=stop_token_ids),
                 block_tables={},
-            )
+            ),
         ]
     }
 
@@ -342,8 +393,10 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
         test_cases = [
             prompt_without_penalization,
             prompt_with_penalization,
+            prompt_with_penalization_and_prompt_logprobs,
             stop_penalizing_after_min_tokens,
-            simple_combination,
+            prompt_combination,
+            decode_combination,
         ]
     else:
         test_cases = [generate_test_case()]
@@ -351,30 +404,49 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
 
     def run_test_case(*,
                       expected_penalization=None,
                       seq_group_metadata_list=None):
-        assert expected_penalization, "Invalid test case"
-        assert seq_group_metadata_list, "Invalid test case"
+        assert expected_penalization, \
+            "Invalid test case, need expected_penalization"
+        assert seq_group_metadata_list, \
+            "Invalid test case, need seq_group_metadata_list"
 
         batch_size = 0
         prompt_lens = []
-        sampling_params_per_seq = []
+        sampling_params_per_row = []
         for sgm in seq_group_metadata_list:
-            num_seqs = len(sgm.seq_data)
-            batch_size += num_seqs
             sampling_params = sgm.sampling_params
-            for seq_id in sgm.seq_data:
-                prompt_lens.append(sgm.seq_data[seq_id].get_prompt_len())
-                sampling_params_per_seq.append(sampling_params)
+
+            num_rows = len(sgm.seq_data)
+            if sgm.is_prompt:
+                # a prompt seq_group has only one sequence
+                seq_data = next(iter(sgm.seq_data.values()))
+                prompt_len = seq_data.get_prompt_len()
+                prompt_lens.append(prompt_len)
+
+                if sgm.sampling_params.prompt_logprobs:
+                    # with prompt_logprobs each token in the prompt has a
+                    # row in logits
+                    num_rows = prompt_len
+
+            batch_size += num_rows
+            sampling_params_per_row.extend(
+                itertools.repeat(sampling_params, num_rows))
+
+        assert len(
+            expected_penalization
+        ) == batch_size, \
+            ("Invalid test case, expected_penalization does not match "
+             "computed batch size")
 
         _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
         sampling_metadata = model_runner._prepare_sample(
             seq_group_metadata_list,
-            prompt_lens=prompt_lens,
-            subquery_lens=prompt_lens)
+            prompt_lens=prompt_lens if prompt_lens else None,
+            subquery_lens=prompt_lens if prompt_lens else None)
         # the logits tensor is modified in-place by the sampler
         _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
 
         for logits_idx, (should_penalize, sampling_params) in enumerate(
-                zip(expected_penalization, sampling_params_per_seq)):
+                zip(expected_penalization, sampling_params_per_row)):
             tokens_to_check = [sampling_params.eos_token_id]
             if sampling_params.stop_token_ids:
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index cb1480de03e3a..03bf38caebe0e 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -27,6 +27,12 @@ class Sampler(nn.Module):
     6. Sample the next tokens.
     Here, each sequence group within the batch can have different sampling
     parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
+
+    The structure of the logits tensor is coupled with the seq_groups in
+    sampling_metadata. Typically, each sequence in each seq_group has one
+    row in logits for the next token to be sampled; however, for a seq_group
+    with a prompt request with the prompt_logprobs sampling parameter, there
+    are rows in logits for each token in the input prompt.
     """
 
     def forward(
@@ -106,7 +112,16 @@ def _apply_min_tokens_penalty(
     # list of indices in logits that will be set to -inf
     logits_to_penalize = []
     start_idx = 0
-    for seq_ids, sampling_params in sampling_metadata.seq_groups:
+    for i, seq_group in enumerate(sampling_metadata.seq_groups):
+        seq_ids, sampling_params = seq_group
+
+        # handle prompt_logprobs by skipping rows in logits added for the
+        # prompt tokens (prompt logprobs are not penalized)
+        if (i < sampling_metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            assert len(seq_ids) == 1
+            start_idx += sampling_metadata.prompt_lens[i] - 1
+
         min_tokens = sampling_params.min_tokens
         if min_tokens > 0:
             seqs_to_penalize = []
@@ -132,6 +147,8 @@ def _apply_min_tokens_penalty(
     # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
     logits[tuple(zip(*logits_to_penalize))] = -float("inf")
 
+    # verifies that no rows in logits were missed unexpectedly
+    assert start_idx == logits.shape[0]
     return logits
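
The row accounting above is the crux of the change: a decode seq_group contributes one logits row per sequence, while a prompt seq_group that requests prompt_logprobs contributes one row per prompt token, the last of which is the row for the token about to be sampled. Here is a small self-contained sketch of that accounting; the helper name and tuple layout are hypothetical, for illustration only:

```python
from typing import List, Optional, Tuple


def expected_logits_rows(
        groups: List[Tuple[bool, int, int, Optional[int]]]) -> int:
    """Each tuple is (is_prompt, num_seqs, prompt_len, prompt_logprobs)."""
    total = 0
    for is_prompt, num_seqs, prompt_len, prompt_logprobs in groups:
        if is_prompt and prompt_logprobs is not None:
            # one logits row per prompt token; the last of these rows is
            # for the next token to be sampled
            total += prompt_len
        else:
            # one logits row per sequence in the group
            total += num_seqs
    return total


# a decode group with 3 sequences plus a prompt of length 4 requesting
# prompt_logprobs: 3 + 4 = 7 rows in logits
assert expected_logits_rows([(False, 3, 0, None), (True, 1, 4, 5)]) == 7
```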
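On the sampler side, `_apply_min_tokens_penalty` walks logits with a running start_idx, skipping the extra prompt rows before masking stop/eos tokens, and the new trailing assert verifies that every row was visited. Below is a minimal, self-contained sketch of that walk; `SeqGroup` and its fields are simplified hypothetical stand-ins, not vLLM's actual metadata classes:

```python
from dataclasses import dataclass
from typing import Dict, List, Optional

import torch


@dataclass
class SeqGroup:
    seq_ids: List[int]
    min_tokens: int
    tokens_to_penalize: List[int]  # eos_token_id plus any stop_token_ids
    num_generated: Dict[int, int]  # seq_id -> tokens generated so far
    is_prompt: bool = False
    prompt_len: int = 0
    prompt_logprobs: Optional[int] = None


def apply_min_tokens_penalty(logits: torch.Tensor,
                             seq_groups: List[SeqGroup]) -> torch.Tensor:
    start_idx = 0
    for group in seq_groups:
        if group.is_prompt and group.prompt_logprobs is not None:
            # prompt_logprobs adds one logits row per prompt token; only
            # the last of those rows samples the next token, so skip the
            # first prompt_len - 1 rows
            start_idx += group.prompt_len - 1
        if group.min_tokens > 0:
            for offset, seq_id in enumerate(group.seq_ids):
                if group.num_generated[seq_id] < group.min_tokens:
                    # mask stop/eos tokens for sequences still under
                    # min_tokens
                    logits[start_idx + offset,
                           group.tokens_to_penalize] = -float("inf")
        start_idx += len(group.seq_ids)
    # every row in logits must have been accounted for
    assert start_idx == logits.shape[0]
    return logits


# prompt of length 3 with prompt_logprobs, then one decode sequence:
# rows 0-1 are prompt rows, row 2 samples the prompt's next token, and
# row 3 belongs to the decode sequence
logits = torch.zeros(4, 10)
groups = [
    SeqGroup(seq_ids=[0], min_tokens=2, tokens_to_penalize=[9],
             num_generated={0: 0}, is_prompt=True, prompt_len=3,
             prompt_logprobs=3),
    SeqGroup(seq_ids=[1], min_tokens=2, tokens_to_penalize=[9],
             num_generated={1: 5}),
]
apply_min_tokens_penalty(logits, groups)
assert logits[2, 9] == -float("inf")  # only the sampling row is penalized
assert torch.isfinite(logits[0]).all() and torch.isfinite(logits[3]).all()
```

The trailing assert mirrors the one added in the diff: after the walk, start_idx must equal logits.shape[0], so a seq_group whose rows were skipped or double-counted fails fast instead of silently penalizing the wrong rows.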