mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 03:34:57 +08:00
Fix repetition penalty to align with HuggingFace (#1577)
This commit is contained in:
parent 4cea74c73b
commit de23687d16
@@ -21,7 +21,7 @@ class Sampler(nn.Module):
     1. Discard the hidden states that are not used for sampling (i.e., all
         tokens except the final one in each prompt).
     2. Compute the logits for the next tokens.
-    3. Apply presence and frequency penalties.
+    3. Apply presence, frequency and repetition penalties.
     4. Apply temperature scaling.
     5. Apply top-p and top-k truncation.
     6. Sample the next tokens.
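For reference, the semantics this change aligns with are those of the HuggingFace Transformers repetition penalty: the logit of every previously seen token is divided by the penalty when it is positive and multiplied by it when it is negative, so a penalty above 1.0 always makes repeats less likely. A minimal single-sequence sketch of that rule in plain PyTorch (the function and variable names are illustrative, not the vLLM code):

import torch

def apply_repetition_penalty(logits: torch.Tensor,
                             seen_token_ids: torch.Tensor,
                             penalty: float) -> torch.Tensor:
    """HuggingFace-style repetition penalty for one sequence (sketch only).

    logits:         [vocab_size] next-token logits
    seen_token_ids: 1-D tensor of token ids already in the prompt/output
    penalty:        > 1.0 discourages repetition, < 1.0 encourages it
    """
    score = logits[seen_token_ids]
    # Dividing a positive logit (or multiplying a negative one) by the
    # penalty always lowers the token's probability when penalty > 1.
    score = torch.where(score > 0, score / penalty, score * penalty)
    logits = logits.clone()
    logits[seen_token_ids] = score
    return logits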
@@ -50,14 +50,12 @@ class Sampler(nn.Module):
         # Apply logits processors (if any).
         logits = _apply_logits_processors(logits, input_metadata)
         # Apply presence and frequency penalties.
-        output_tokens = _get_output_tokens(input_metadata)
-        assert len(output_tokens) == logits.shape[0]
         presence_penalties, frequency_penalties, repetition_penalties = (
             _get_penalties(input_metadata))
         assert len(presence_penalties) == logits.shape[0]
         assert len(frequency_penalties) == logits.shape[0]
         assert len(repetition_penalties) == logits.shape[0]
-        logits = _apply_penalties(logits, output_tokens, presence_penalties,
+        logits = _apply_penalties(logits, input_metadata, presence_penalties,
                                   frequency_penalties, repetition_penalties)

         # Apply temperature scaling.
@@ -146,7 +144,10 @@ def _get_penalties(
     return presence_penalties, frequency_penalties, repetition_penalties


-def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
+def _get_prompt_and_output_tokens(
+    input_metadata: InputMetadata
+) -> Tuple[List[List[int]], List[List[int]]]:
+    prompt_tokens: List[List[int]] = []
     output_tokens: List[List[int]] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
@@ -155,11 +156,39 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
             # NOTE: prompt token positions do not need output tokens to
             # compute penalties.
             prompt_len = input_metadata.prompt_lens[i]
+            prompt_tokens.extend([] for _ in range(prompt_len - 1))
             output_tokens.extend([] for _ in range(prompt_len - 1))
         for seq_id in seq_ids:
             seq_data = input_metadata.seq_data[seq_id]
+            prompt_tokens.append(seq_data.prompt_token_ids)
             output_tokens.append(seq_data.output_token_ids)
-    return output_tokens
+    return prompt_tokens, output_tokens
+
+
+def _get_bin_counts_and_mask(
+    logits: torch.Tensor,
+    tokens: List[List[int]],
+    vocab_size: int,
+    num_seqs: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    max_len = max(len(tokens) for tokens in tokens)
+    padded_tokens = [
+        tokens + [vocab_size] * (max_len - len(tokens)) for tokens in tokens
+    ]
+    tokens_tensor = torch.tensor(padded_tokens,
+                                 dtype=torch.long,
+                                 device=logits.device)
+
+    # Compute the bin counts for the tokens.
+    # vocab_size + 1 for padding.
+    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
+                             dtype=torch.long,
+                             device=logits.device)
+    bin_counts.scatter_add_(1, tokens_tensor, torch.ones_like(tokens_tensor))
+    bin_counts = bin_counts[:, :vocab_size]
+    mask = bin_counts > 0
+
+    return bin_counts, mask
+
+
 def _apply_logits_processors(logits: torch.Tensor,
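The new _get_bin_counts_and_mask helper counts per-sequence token occurrences over a ragged list of token-id lists: every row is padded with the out-of-range id vocab_size, a scatter_add accumulates ones, and the extra padding bin is dropped before building the occurrence mask. A small self-contained sketch of the same trick with toy sizes (variable names are illustrative):

import torch

vocab_size = 6
token_lists = [[1, 1, 4], [2], []]          # ragged per-sequence token ids
num_seqs = len(token_lists)

max_len = max(len(t) for t in token_lists)
# Pad with vocab_size, an id that only ever lands in the extra bin.
padded = [t + [vocab_size] * (max_len - len(t)) for t in token_lists]
tokens = torch.tensor(padded, dtype=torch.long)

bin_counts = torch.zeros((num_seqs, vocab_size + 1), dtype=torch.long)
bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
bin_counts = bin_counts[:, :vocab_size]      # drop the padding bin
mask = bin_counts > 0

print(bin_counts)  # row 0 -> [0, 2, 0, 0, 1, 0]; row 2 stays all zeros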
@@ -186,15 +215,13 @@ def _apply_logits_processors(logits: torch.Tensor,

 def _apply_penalties(
     logits: torch.Tensor,
-    output_tokens: List[List[int]],
+    input_metadata: InputMetadata,
     presence_penalties: List[float],
     frequency_penalties: List[float],
     repetition_penalties: List[float],
 ) -> torch.Tensor:
     num_seqs, vocab_size = logits.shape
     for i in range(num_seqs):
-        if not output_tokens[i]:
-            continue
         p = presence_penalties[i]
         f = frequency_penalties[i]
         r = repetition_penalties[i]
@@ -206,24 +233,15 @@ def _apply_penalties(
         # Return early if all sequences have zero penalties.
         return logits

-    max_output_len = max(len(tokens) for tokens in output_tokens)
-    padded_output_tokens = [
-        tokens + [vocab_size] * (max_output_len - len(tokens))
-        for tokens in output_tokens
-    ]
-    output_tokens_tensor = torch.tensor(padded_output_tokens,
-                                        dtype=torch.long,
-                                        device=logits.device)
-
-    # Compute the bin counts for the output tokens.
-    # vocab_size + 1 for padding.
-    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
-                             dtype=torch.long,
-                             device=logits.device)
-    bin_counts.scatter_add_(1, output_tokens_tensor,
-                            torch.ones_like(output_tokens_tensor))
-    bin_counts = bin_counts[:, :vocab_size]  # Remove the padding bin.
-    mask = bin_counts > 0
+    prompt_tokens, output_tokens = (
+        _get_prompt_and_output_tokens(input_metadata))
+    assert len(prompt_tokens) == logits.shape[0]
+    assert len(output_tokens) == logits.shape[0]
+
+    prompt_bin_counts, prompt_mask = _get_bin_counts_and_mask(
+        logits, prompt_tokens, vocab_size, num_seqs)
+    output_bin_counts, output_mask = _get_bin_counts_and_mask(
+        logits, output_tokens, vocab_size, num_seqs)

     repetition_penalties = torch.tensor(repetition_penalties,
                                         dtype=logits.dtype,
@@ -236,14 +254,14 @@ def _apply_penalties(
                                      device=logits.device)

     repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
-    repetition_penalties[~mask] = 1.0
+    repetition_penalties[~(prompt_mask | output_mask)] = 1.0
     logits = torch.where(logits > 0, logits / repetition_penalties,
                          logits * repetition_penalties)

     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
-    logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
-    logits -= presence_penalties.unsqueeze(dim=1) * mask
+    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
+    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
     return logits

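Taken together, the updated _apply_penalties applies the repetition penalty multiplicatively to every token that appeared in either the prompt or the generated output, while the OpenAI-style frequency and presence penalties remain additive and keyed only to output tokens. A condensed standalone sketch of that combination in plain PyTorch, with the per-sequence penalties collapsed to scalars for brevity (names are illustrative, not the vLLM API):

import torch

def penalize(logits, prompt_mask, output_mask, output_counts,
             repetition_penalty, frequency_penalty, presence_penalty):
    """Apply all three penalties to [num_seqs, vocab_size] logits (sketch)."""
    # Repetition penalty (HuggingFace semantics): scale the logit of any
    # token that appeared in the prompt OR the generated output.
    rep = torch.full_like(logits, repetition_penalty)
    rep[~(prompt_mask | output_mask)] = 1.0
    logits = torch.where(logits > 0, logits / rep, logits * rep)

    # Frequency/presence penalties (OpenAI semantics): subtract based on
    # output-token counts and occurrence only.
    logits = logits - frequency_penalty * output_counts
    logits = logits - presence_penalty * output_mask.to(logits.dtype)
    return logits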
@@ -42,9 +42,9 @@ class SamplingParams:
             model to use new tokens, while values < 0 encourage the model to
             repeat tokens.
         repetition_penalty: Float that penalizes new tokens based on whether
-            they appear in the generated text so far. Values > 1 encourage the
-            model to use new tokens, while values < 1 encourage the model to
-            repeat tokens.
+            they appear in the prompt and the generated text so far. Values > 1
+            encourage the model to use new tokens, while values < 1 encourage
+            the model to repeat tokens.
         temperature: Float that controls the randomness of the sampling. Lower
             values make the model more deterministic, while higher values make
             the model more random. Zero means greedy sampling.
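With prompt tokens now included, a repetition_penalty above 1.0 discourages echoing the prompt as well as earlier output. A hypothetical usage sketch of the parameter through the public API (model name and parameter values are illustrative):

from vllm import LLM, SamplingParams

# repetition_penalty > 1.0 penalizes tokens already present in the prompt
# or in the generated text; < 1.0 encourages repeating them.
params = SamplingParams(temperature=0.8,
                        top_p=0.95,
                        repetition_penalty=1.2,
                        max_tokens=128)

llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)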