mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 09:25:44 +08:00
Fix repetition penalty aligned with huggingface (#1577)
This commit is contained in:
parent
4cea74c73b
commit
de23687d16
@ -21,7 +21,7 @@ class Sampler(nn.Module):
|
||||
1. Discard the hidden states that are not used for sampling (i.e., all
|
||||
tokens except the final one in each prompt).
|
||||
2. Compute the logits for the next tokens.
|
||||
3. Apply presence and frequency penalties.
|
||||
3. Apply presence, frequency and repetition penalties.
|
||||
4. Apply temperature scaling.
|
||||
5. Apply top-p and top-k truncation.
|
||||
6. Sample the next tokens.
|
||||
@ -50,14 +50,12 @@ class Sampler(nn.Module):
|
||||
# Apply logits processors (if any).
|
||||
logits = _apply_logits_processors(logits, input_metadata)
|
||||
# Apply presence and frequency penalties.
|
||||
output_tokens = _get_output_tokens(input_metadata)
|
||||
assert len(output_tokens) == logits.shape[0]
|
||||
presence_penalties, frequency_penalties, repetition_penalties = (
|
||||
_get_penalties(input_metadata))
|
||||
assert len(presence_penalties) == logits.shape[0]
|
||||
assert len(frequency_penalties) == logits.shape[0]
|
||||
assert len(repetition_penalties) == logits.shape[0]
|
||||
logits = _apply_penalties(logits, output_tokens, presence_penalties,
|
||||
logits = _apply_penalties(logits, input_metadata, presence_penalties,
|
||||
frequency_penalties, repetition_penalties)
|
||||
|
||||
# Apply temperature scaling.
|
||||
@ -146,7 +144,10 @@ def _get_penalties(
|
||||
return presence_penalties, frequency_penalties, repetition_penalties
|
||||
|
||||
|
||||
def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
|
||||
def _get_prompt_and_output_tokens(
|
||||
input_metadata: InputMetadata
|
||||
) -> Tuple[List[List[int]], List[List[int]]]:
|
||||
prompt_tokens: List[List[int]] = []
|
||||
output_tokens: List[List[int]] = []
|
||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||
seq_ids, sampling_params = seq_group
|
||||
@ -155,11 +156,39 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
|
||||
# NOTE: prompt token positions do not need output tokens to
|
||||
# compute penalties.
|
||||
prompt_len = input_metadata.prompt_lens[i]
|
||||
prompt_tokens.extend([] for _ in range(prompt_len - 1))
|
||||
output_tokens.extend([] for _ in range(prompt_len - 1))
|
||||
for seq_id in seq_ids:
|
||||
seq_data = input_metadata.seq_data[seq_id]
|
||||
prompt_tokens.append(seq_data.prompt_token_ids)
|
||||
output_tokens.append(seq_data.output_token_ids)
|
||||
return output_tokens
|
||||
return prompt_tokens, output_tokens
|
||||
|
||||
|
||||
def _get_bin_counts_and_mask(
|
||||
logits: torch.Tensor,
|
||||
tokens: List[List[int]],
|
||||
vocab_size: int,
|
||||
num_seqs: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
max_len = max(len(tokens) for tokens in tokens)
|
||||
padded_tokens = [
|
||||
tokens + [vocab_size] * (max_len - len(tokens)) for tokens in tokens
|
||||
]
|
||||
tokens_tensor = torch.tensor(padded_tokens,
|
||||
dtype=torch.long,
|
||||
device=logits.device)
|
||||
|
||||
# Compute the bin counts for the tokens.
|
||||
# vocab_size + 1 for padding.
|
||||
bin_counts = torch.zeros((num_seqs, vocab_size + 1),
|
||||
dtype=torch.long,
|
||||
device=logits.device)
|
||||
bin_counts.scatter_add_(1, tokens_tensor, torch.ones_like(tokens_tensor))
|
||||
bin_counts = bin_counts[:, :vocab_size]
|
||||
mask = bin_counts > 0
|
||||
|
||||
return bin_counts, mask
|
||||
|
||||
|
||||
def _apply_logits_processors(logits: torch.Tensor,
|
||||
@ -186,15 +215,13 @@ def _apply_logits_processors(logits: torch.Tensor,
|
||||
|
||||
def _apply_penalties(
|
||||
logits: torch.Tensor,
|
||||
output_tokens: List[List[int]],
|
||||
input_metadata: InputMetadata,
|
||||
presence_penalties: List[float],
|
||||
frequency_penalties: List[float],
|
||||
repetition_penalties: List[float],
|
||||
) -> torch.Tensor:
|
||||
num_seqs, vocab_size = logits.shape
|
||||
for i in range(num_seqs):
|
||||
if not output_tokens[i]:
|
||||
continue
|
||||
p = presence_penalties[i]
|
||||
f = frequency_penalties[i]
|
||||
r = repetition_penalties[i]
|
||||
@ -206,24 +233,15 @@ def _apply_penalties(
|
||||
# Return early if all sequences have zero penalties.
|
||||
return logits
|
||||
|
||||
max_output_len = max(len(tokens) for tokens in output_tokens)
|
||||
padded_output_tokens = [
|
||||
tokens + [vocab_size] * (max_output_len - len(tokens))
|
||||
for tokens in output_tokens
|
||||
]
|
||||
output_tokens_tensor = torch.tensor(padded_output_tokens,
|
||||
dtype=torch.long,
|
||||
device=logits.device)
|
||||
prompt_tokens, output_tokens = (
|
||||
_get_prompt_and_output_tokens(input_metadata))
|
||||
assert len(prompt_tokens) == logits.shape[0]
|
||||
assert len(output_tokens) == logits.shape[0]
|
||||
|
||||
# Compute the bin counts for the output tokens.
|
||||
# vocab_size + 1 for padding.
|
||||
bin_counts = torch.zeros((num_seqs, vocab_size + 1),
|
||||
dtype=torch.long,
|
||||
device=logits.device)
|
||||
bin_counts.scatter_add_(1, output_tokens_tensor,
|
||||
torch.ones_like(output_tokens_tensor))
|
||||
bin_counts = bin_counts[:, :vocab_size] # Remove the padding bin.
|
||||
mask = bin_counts > 0
|
||||
prompt_bin_counts, prompt_mask = _get_bin_counts_and_mask(
|
||||
logits, prompt_tokens, vocab_size, num_seqs)
|
||||
output_bin_counts, output_mask = _get_bin_counts_and_mask(
|
||||
logits, output_tokens, vocab_size, num_seqs)
|
||||
|
||||
repetition_penalties = torch.tensor(repetition_penalties,
|
||||
dtype=logits.dtype,
|
||||
@ -236,14 +254,14 @@ def _apply_penalties(
|
||||
device=logits.device)
|
||||
|
||||
repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
|
||||
repetition_penalties[~mask] = 1.0
|
||||
repetition_penalties[~(prompt_mask | output_mask)] = 1.0
|
||||
logits = torch.where(logits > 0, logits / repetition_penalties,
|
||||
logits * repetition_penalties)
|
||||
|
||||
# We follow the definition in OpenAI API.
|
||||
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
|
||||
logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
|
||||
logits -= presence_penalties.unsqueeze(dim=1) * mask
|
||||
logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
|
||||
logits -= presence_penalties.unsqueeze(dim=1) * output_mask
|
||||
return logits
|
||||
|
||||
|
||||
|
||||
@ -42,9 +42,9 @@ class SamplingParams:
|
||||
model to use new tokens, while values < 0 encourage the model to
|
||||
repeat tokens.
|
||||
repetition_penalty: Float that penalizes new tokens based on whether
|
||||
they appear in the generated text so far. Values > 1 encourage the
|
||||
model to use new tokens, while values < 1 encourage the model to
|
||||
repeat tokens.
|
||||
they appear in the prompt and the generated text so far. Values > 1
|
||||
encourage the model to use new tokens, while values < 1 encourage
|
||||
the model to repeat tokens.
|
||||
temperature: Float that controls the randomness of the sampling. Lower
|
||||
values make the model more deterministic, while higher values make
|
||||
the model more random. Zero means greedy sampling.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user