From de23687d168ebeaa8872c27f05b8292bab0fac71 Mon Sep 17 00:00:00 2001
From: ljss <31004720+beginlner@users.noreply.github.com>
Date: Thu, 23 Nov 2023 06:41:44 +0800
Subject: [PATCH] Fix repetition penalty aligned with huggingface (#1577)

---
 vllm/model_executor/layers/sampler.py | 76 +++++++++++++++++----------
 vllm/sampling_params.py               |  6 +--
 2 files changed, 50 insertions(+), 32 deletions(-)

diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 9fcc2f20675c0..c874ec5921155 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -21,7 +21,7 @@ class Sampler(nn.Module):
     1. Discard the hidden states that are not used for sampling (i.e., all
         tokens except the final one in each prompt).
     2. Compute the logits for the next tokens.
-    3. Apply presence and frequency penalties.
+    3. Apply presence, frequency and repetition penalties.
     4. Apply temperature scaling.
     5. Apply top-p and top-k truncation.
     6. Sample the next tokens.
@@ -50,14 +50,12 @@ class Sampler(nn.Module):
         # Apply logits processors (if any).
         logits = _apply_logits_processors(logits, input_metadata)
         # Apply presence and frequency penalties.
-        output_tokens = _get_output_tokens(input_metadata)
-        assert len(output_tokens) == logits.shape[0]
         presence_penalties, frequency_penalties, repetition_penalties = (
             _get_penalties(input_metadata))
         assert len(presence_penalties) == logits.shape[0]
         assert len(frequency_penalties) == logits.shape[0]
         assert len(repetition_penalties) == logits.shape[0]
-        logits = _apply_penalties(logits, output_tokens, presence_penalties,
+        logits = _apply_penalties(logits, input_metadata, presence_penalties,
                                   frequency_penalties, repetition_penalties)
 
         # Apply temperature scaling.
@@ -146,7 +144,10 @@ def _get_penalties(
     return presence_penalties, frequency_penalties, repetition_penalties
 
 
-def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
+def _get_prompt_and_output_tokens(
+    input_metadata: InputMetadata
+) -> Tuple[List[List[int]], List[List[int]]]:
+    prompt_tokens: List[List[int]] = []
     output_tokens: List[List[int]] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
@@ -155,11 +156,39 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
             # NOTE: prompt token positions do not need output tokens to
             # compute penalties.
             prompt_len = input_metadata.prompt_lens[i]
+            prompt_tokens.extend([] for _ in range(prompt_len - 1))
             output_tokens.extend([] for _ in range(prompt_len - 1))
         for seq_id in seq_ids:
             seq_data = input_metadata.seq_data[seq_id]
+            prompt_tokens.append(seq_data.prompt_token_ids)
             output_tokens.append(seq_data.output_token_ids)
-    return output_tokens
+    return prompt_tokens, output_tokens
+
+
+def _get_bin_counts_and_mask(
+    logits: torch.Tensor,
+    tokens: List[List[int]],
+    vocab_size: int,
+    num_seqs: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    max_len = max(len(tokens) for tokens in tokens)
+    padded_tokens = [
+        tokens + [vocab_size] * (max_len - len(tokens)) for tokens in tokens
+    ]
+    tokens_tensor = torch.tensor(padded_tokens,
+                                 dtype=torch.long,
+                                 device=logits.device)
+
+    # Compute the bin counts for the tokens.
+    # vocab_size + 1 for padding.
+    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
+                             dtype=torch.long,
+                             device=logits.device)
+    bin_counts.scatter_add_(1, tokens_tensor, torch.ones_like(tokens_tensor))
+    bin_counts = bin_counts[:, :vocab_size]
+    mask = bin_counts > 0
+
+    return bin_counts, mask
 
 
 def _apply_logits_processors(logits: torch.Tensor,
@@ -186,15 +215,13 @@ def _apply_logits_processors(logits: torch.Tensor,
 
 def _apply_penalties(
     logits: torch.Tensor,
-    output_tokens: List[List[int]],
+    input_metadata: InputMetadata,
     presence_penalties: List[float],
     frequency_penalties: List[float],
     repetition_penalties: List[float],
 ) -> torch.Tensor:
     num_seqs, vocab_size = logits.shape
     for i in range(num_seqs):
-        if not output_tokens[i]:
-            continue
         p = presence_penalties[i]
         f = frequency_penalties[i]
         r = repetition_penalties[i]
@@ -206,24 +233,15 @@ def _apply_penalties(
         # Return early if all sequences have zero penalties.
         return logits
 
-    max_output_len = max(len(tokens) for tokens in output_tokens)
-    padded_output_tokens = [
-        tokens + [vocab_size] * (max_output_len - len(tokens))
-        for tokens in output_tokens
-    ]
-    output_tokens_tensor = torch.tensor(padded_output_tokens,
-                                        dtype=torch.long,
-                                        device=logits.device)
+    prompt_tokens, output_tokens = (
+        _get_prompt_and_output_tokens(input_metadata))
+    assert len(prompt_tokens) == logits.shape[0]
+    assert len(output_tokens) == logits.shape[0]
 
-    # Compute the bin counts for the output tokens.
-    # vocab_size + 1 for padding.
-    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
-                             dtype=torch.long,
-                             device=logits.device)
-    bin_counts.scatter_add_(1, output_tokens_tensor,
-                            torch.ones_like(output_tokens_tensor))
-    bin_counts = bin_counts[:, :vocab_size]  # Remove the padding bin.
-    mask = bin_counts > 0
+    prompt_bin_counts, prompt_mask = _get_bin_counts_and_mask(
+        logits, prompt_tokens, vocab_size, num_seqs)
+    output_bin_counts, output_mask = _get_bin_counts_and_mask(
+        logits, output_tokens, vocab_size, num_seqs)
 
     repetition_penalties = torch.tensor(repetition_penalties,
                                         dtype=logits.dtype,
@@ -236,14 +254,14 @@ def _apply_penalties(
                                      device=logits.device)
 
     repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
-    repetition_penalties[~mask] = 1.0
+    repetition_penalties[~(prompt_mask | output_mask)] = 1.0
     logits = torch.where(logits > 0, logits / repetition_penalties,
                          logits * repetition_penalties)
 
     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
-    logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
-    logits -= presence_penalties.unsqueeze(dim=1) * mask
+    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
+    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
 
     return logits
 
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index f9eca1a9fc43c..5a08169c48a36 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -42,9 +42,9 @@ class SamplingParams:
             model to use new tokens, while values < 0 encourage the model to
             repeat tokens.
         repetition_penalty: Float that penalizes new tokens based on whether
-            they appear in the generated text so far. Values > 1 encourage the
-            model to use new tokens, while values < 1 encourage the model to
-            repeat tokens.
+            they appear in the prompt and the generated text so far. Values > 1
+            encourage the model to use new tokens, while values < 1 encourage
+            the model to repeat tokens.
         temperature: Float that controls the randomness of the sampling.
             Lower values make the model more deterministic, while higher values
             make the model more random. Zero means greedy sampling.
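The snippet below is a minimal, self-contained sketch of the penalty math that the modified _apply_penalties implements after this patch: the HuggingFace-style repetition penalty divides positive logits and multiplies negative ones for every token that appears in the prompt or the generated output, while the OpenAI-style frequency and presence penalties are still computed only from the output tokens. The function and variable names below are illustrative, not vLLM's, and the example values are made up.

# Illustrative sketch only (not part of the patch); no vLLM internals are used.
from typing import List, Tuple

import torch


def _counts_and_mask(token_lists: List[List[int]], num_seqs: int,
                     vocab_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # Per-sequence histogram of token ids and a boolean "seen" mask.
    counts = torch.zeros((num_seqs, vocab_size), dtype=torch.long)
    for i, token_ids in enumerate(token_lists):
        for token_id in token_ids:
            counts[i, token_id] += 1
    return counts, counts > 0


def apply_penalties_sketch(
    logits: torch.Tensor,                # [num_seqs, vocab_size]
    prompt_tokens: List[List[int]],
    output_tokens: List[List[int]],
    repetition_penalties: torch.Tensor,  # [num_seqs], HF-style, > 1 discourages repeats
    frequency_penalties: torch.Tensor,   # [num_seqs], OpenAI-style
    presence_penalties: torch.Tensor,    # [num_seqs], OpenAI-style
) -> torch.Tensor:
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = _counts_and_mask(prompt_tokens, num_seqs, vocab_size)
    output_counts, output_mask = _counts_and_mask(output_tokens, num_seqs, vocab_size)

    # Repetition penalty (HuggingFace definition): applied to tokens seen in the
    # prompt OR the output; positive logits are divided, negative ones multiplied.
    rep = repetition_penalties[:, None].repeat(1, vocab_size)
    rep[~(prompt_mask | output_mask)] = 1.0
    logits = torch.where(logits > 0, logits / rep, logits * rep)

    # Frequency/presence penalties (OpenAI definition): based only on output tokens.
    logits = logits - frequency_penalties[:, None] * output_counts
    logits = logits - presence_penalties[:, None] * output_mask
    return logits


# Tiny example: vocab of 4; token 2 appears in the prompt, token 1 in the output.
logits = torch.tensor([[1.0, -1.0, 2.0, 0.5]])
print(apply_penalties_sketch(
    logits,
    prompt_tokens=[[2]],
    output_tokens=[[1]],
    repetition_penalties=torch.tensor([2.0]),
    frequency_penalties=torch.tensor([0.5]),
    presence_penalties=torch.tensor([0.1]),
))
# Expected: [[1.0, -2.6, 1.0, 0.5]]
#   token 2 (in prompt):  2.0 / 2.0 = 1.0
#   token 1 (in output): -1.0 * 2.0 - 0.5 * 1 - 0.1 * 1 = -2.6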