Fix repetition penalty aligned with huggingface (#1577)

ljss 2023-11-23 06:41:44 +08:00 committed by GitHub
parent 4cea74c73b
commit de23687d16
2 changed files with 50 additions and 32 deletions


@@ -21,7 +21,7 @@ class Sampler(nn.Module):
     1. Discard the hidden states that are not used for sampling (i.e., all
         tokens except the final one in each prompt).
     2. Compute the logits for the next tokens.
-    3. Apply presence and frequency penalties.
+    3. Apply presence, frequency and repetition penalties.
     4. Apply temperature scaling.
     5. Apply top-p and top-k truncation.
     6. Sample the next tokens.
@@ -50,14 +50,12 @@ class Sampler(nn.Module):
         # Apply logits processors (if any).
         logits = _apply_logits_processors(logits, input_metadata)
         # Apply presence and frequency penalties.
-        output_tokens = _get_output_tokens(input_metadata)
-        assert len(output_tokens) == logits.shape[0]
         presence_penalties, frequency_penalties, repetition_penalties = (
             _get_penalties(input_metadata))
         assert len(presence_penalties) == logits.shape[0]
         assert len(frequency_penalties) == logits.shape[0]
         assert len(repetition_penalties) == logits.shape[0]
-        logits = _apply_penalties(logits, output_tokens, presence_penalties,
+        logits = _apply_penalties(logits, input_metadata, presence_penalties,
                                   frequency_penalties, repetition_penalties)
 
         # Apply temperature scaling.
@@ -146,7 +144,10 @@ def _get_penalties(
     return presence_penalties, frequency_penalties, repetition_penalties
 
 
-def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
+def _get_prompt_and_output_tokens(
+    input_metadata: InputMetadata
+) -> Tuple[List[List[int]], List[List[int]]]:
+    prompt_tokens: List[List[int]] = []
     output_tokens: List[List[int]] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
@@ -155,11 +156,39 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
             # NOTE: prompt token positions do not need output tokens to
             # compute penalties.
             prompt_len = input_metadata.prompt_lens[i]
+            prompt_tokens.extend([] for _ in range(prompt_len - 1))
             output_tokens.extend([] for _ in range(prompt_len - 1))
         for seq_id in seq_ids:
             seq_data = input_metadata.seq_data[seq_id]
+            prompt_tokens.append(seq_data.prompt_token_ids)
             output_tokens.append(seq_data.output_token_ids)
-    return output_tokens
+    return prompt_tokens, output_tokens
+
+
+def _get_bin_counts_and_mask(
+    logits: torch.Tensor,
+    tokens: List[List[int]],
+    vocab_size: int,
+    num_seqs: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    max_len = max(len(tokens) for tokens in tokens)
+    padded_tokens = [
+        tokens + [vocab_size] * (max_len - len(tokens)) for tokens in tokens
+    ]
+    tokens_tensor = torch.tensor(padded_tokens,
+                                 dtype=torch.long,
+                                 device=logits.device)
+
+    # Compute the bin counts for the tokens.
+    # vocab_size + 1 for padding.
+    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
+                             dtype=torch.long,
+                             device=logits.device)
+    bin_counts.scatter_add_(1, tokens_tensor, torch.ones_like(tokens_tensor))
+    bin_counts = bin_counts[:, :vocab_size]
+    mask = bin_counts > 0
+
+    return bin_counts, mask
 
 
 def _apply_logits_processors(logits: torch.Tensor,
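The new _get_bin_counts_and_mask helper is the heart of the change: it turns ragged per-sequence token lists into per-sequence occurrence counts with a single scatter_add_, using an extra "padding" bin at index vocab_size so sequences of different lengths can share one tensor. Below is a minimal standalone sketch of that trick; the vocabulary size and token lists are illustrative, not taken from the diff.

import torch

# Illustrative values (not from the diff): 2 sequences over a 5-token vocab.
vocab_size = 5
token_lists = [[0, 2, 2], [4]]
num_seqs = len(token_lists)

# Pad the ragged lists with the out-of-range id `vocab_size` so they form a
# rectangular tensor; the extra bin simply absorbs the padding counts.
max_len = max(len(t) for t in token_lists)
padded = [t + [vocab_size] * (max_len - len(t)) for t in token_lists]
tokens = torch.tensor(padded, dtype=torch.long)

bin_counts = torch.zeros((num_seqs, vocab_size + 1), dtype=torch.long)
bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
bin_counts = bin_counts[:, :vocab_size]  # drop the padding bin
mask = bin_counts > 0

print(bin_counts)  # tensor([[1, 0, 2, 0, 0], [0, 0, 0, 0, 1]])
print(mask)        # True wherever a token occurred at least once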
@@ -186,15 +215,13 @@ def _apply_logits_processors(logits: torch.Tensor,
 
 def _apply_penalties(
     logits: torch.Tensor,
-    output_tokens: List[List[int]],
+    input_metadata: InputMetadata,
     presence_penalties: List[float],
     frequency_penalties: List[float],
     repetition_penalties: List[float],
 ) -> torch.Tensor:
     num_seqs, vocab_size = logits.shape
     for i in range(num_seqs):
-        if not output_tokens[i]:
-            continue
         p = presence_penalties[i]
         f = frequency_penalties[i]
         r = repetition_penalties[i]
@@ -206,24 +233,15 @@ def _apply_penalties(
         # Return early if all sequences have zero penalties.
         return logits
 
-    max_output_len = max(len(tokens) for tokens in output_tokens)
-    padded_output_tokens = [
-        tokens + [vocab_size] * (max_output_len - len(tokens))
-        for tokens in output_tokens
-    ]
-    output_tokens_tensor = torch.tensor(padded_output_tokens,
-                                        dtype=torch.long,
-                                        device=logits.device)
+    prompt_tokens, output_tokens = (
+        _get_prompt_and_output_tokens(input_metadata))
+    assert len(prompt_tokens) == logits.shape[0]
+    assert len(output_tokens) == logits.shape[0]
 
-    # Compute the bin counts for the output tokens.
-    # vocab_size + 1 for padding.
-    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
-                             dtype=torch.long,
-                             device=logits.device)
-    bin_counts.scatter_add_(1, output_tokens_tensor,
-                            torch.ones_like(output_tokens_tensor))
-    bin_counts = bin_counts[:, :vocab_size]  # Remove the padding bin.
-    mask = bin_counts > 0
+    prompt_bin_counts, prompt_mask = _get_bin_counts_and_mask(
+        logits, prompt_tokens, vocab_size, num_seqs)
+    output_bin_counts, output_mask = _get_bin_counts_and_mask(
+        logits, output_tokens, vocab_size, num_seqs)
 
     repetition_penalties = torch.tensor(repetition_penalties,
                                         dtype=logits.dtype,
@@ -236,14 +254,14 @@ def _apply_penalties(
                                        device=logits.device)
 
     repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
-    repetition_penalties[~mask] = 1.0
+    repetition_penalties[~(prompt_mask | output_mask)] = 1.0
     logits = torch.where(logits > 0, logits / repetition_penalties,
                          logits * repetition_penalties)
 
     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
-    logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
-    logits -= presence_penalties.unsqueeze(dim=1) * mask
+    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
+    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
     return logits
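Taken together, the updated _apply_penalties matches HuggingFace's repetition_penalty semantics (penalize any token seen in the prompt or the output, dividing positive logits and multiplying negative ones) while keeping the OpenAI-style frequency and presence penalties, which look only at generated output tokens. Below is a minimal standalone sketch of that math; the helper name apply_penalties_sketch and the toy tensors are illustrative and not part of the diff.

import torch

def apply_penalties_sketch(logits, prompt_mask, output_mask, output_counts,
                           repetition_penalty, frequency_penalty,
                           presence_penalty):
    # HuggingFace-style repetition penalty: any token seen in the prompt OR
    # the generated output has its positive logit divided by the penalty and
    # its negative logit multiplied by it.
    rep = torch.where(prompt_mask | output_mask,
                      torch.full_like(logits, repetition_penalty),
                      torch.ones_like(logits))
    logits = torch.where(logits > 0, logits / rep, logits * rep)
    # OpenAI-style penalties: based only on generated output tokens.
    logits = logits - frequency_penalty * output_counts
    logits = logits - presence_penalty * output_mask.float()
    return logits

# One sequence over a 4-token vocab: token 0 appears in the prompt,
# tokens 1 and 2 in the output (token 2 twice).
logits = torch.tensor([[2.0, -1.0, 0.5, 3.0]])
prompt_mask = torch.tensor([[True, False, False, False]])
output_mask = torch.tensor([[False, True, True, False]])
output_counts = torch.tensor([[0.0, 1.0, 2.0, 0.0]])
print(apply_penalties_sketch(logits, prompt_mask, output_mask, output_counts,
                             repetition_penalty=1.2, frequency_penalty=0.1,
                             presence_penalty=0.2))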


@@ -42,9 +42,9 @@ class SamplingParams:
             model to use new tokens, while values < 0 encourage the model to
             repeat tokens.
         repetition_penalty: Float that penalizes new tokens based on whether
-            they appear in the generated text so far. Values > 1 encourage the
-            model to use new tokens, while values < 1 encourage the model to
-            repeat tokens.
+            they appear in the prompt and the generated text so far. Values > 1
+            encourage the model to use new tokens, while values < 1 encourage
+            the model to repeat tokens.
         temperature: Float that controls the randomness of the sampling. Lower
             values make the model more deterministic, while higher values make
             the model more random. Zero means greedy sampling.
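For reference, a hypothetical usage sketch: the LLM entry point, model name, and prompt below are assumptions drawn from vLLM's public API rather than part of this diff. With this fix, repetition_penalty > 1 also discourages tokens that already appear in the prompt.

# Hypothetical usage sketch: the LLM entry point, model name, and prompt are
# assumptions for illustration, not part of this diff.
from vllm import LLM, SamplingParams

params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    repetition_penalty=1.2,  # > 1: avoid tokens seen in the prompt or output
    frequency_penalty=0.0,   # OpenAI-style, counts output tokens only
    presence_penalty=0.0,    # OpenAI-style, output-token presence only
)

llm = LLM(model="facebook/opt-125m")  # placeholder model
outputs = llm.generate(["The quick brown fox"], params)
print(outputs[0].outputs[0].text)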