Fix repetition penalty aligned with huggingface (#1577)

ljss 2023-11-23 06:41:44 +08:00 committed by GitHub
parent 4cea74c73b
commit de23687d16
2 changed files with 50 additions and 32 deletions

vllm/model_executor/layers/sampler.py

@@ -21,7 +21,7 @@ class Sampler(nn.Module):
1. Discard the hidden states that are not used for sampling (i.e., all
tokens except the final one in each prompt).
2. Compute the logits for the next tokens.
3. Apply presence and frequency penalties.
3. Apply presence, frequency and repetition penalties.
4. Apply temperature scaling.
5. Apply top-p and top-k truncation.
6. Sample the next tokens.
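For context on step 3: the three penalties act on the logits in different ways. Roughly, per token and ignoring the prompt/output distinction handled further down in this diff, the rules look like the sketch below (a hypothetical helper for illustration, not part of the vLLM code):

```python
def penalized_logit(logit: float, count: int, presence_penalty: float,
                    frequency_penalty: float,
                    repetition_penalty: float) -> float:
    """Illustrative only: adjust one logit given how many times its token
    has already appeared (`count`)."""
    if count > 0:
        # Repetition penalty (HuggingFace-style): divide a positive logit,
        # multiply a negative one, so the token always becomes less likely.
        if logit > 0:
            logit = logit / repetition_penalty
        else:
            logit = logit * repetition_penalty
    # Frequency penalty scales with the number of occurrences (OpenAI-style);
    # presence penalty is a flat subtraction once the token has appeared.
    logit -= frequency_penalty * count
    logit -= presence_penalty * (1.0 if count > 0 else 0.0)
    return logit
```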
@@ -50,14 +50,12 @@ class Sampler(nn.Module):
# Apply logits processors (if any).
logits = _apply_logits_processors(logits, input_metadata)
# Apply presence and frequency penalties.
output_tokens = _get_output_tokens(input_metadata)
assert len(output_tokens) == logits.shape[0]
presence_penalties, frequency_penalties, repetition_penalties = (
_get_penalties(input_metadata))
assert len(presence_penalties) == logits.shape[0]
assert len(frequency_penalties) == logits.shape[0]
assert len(repetition_penalties) == logits.shape[0]
logits = _apply_penalties(logits, output_tokens, presence_penalties,
logits = _apply_penalties(logits, input_metadata, presence_penalties,
frequency_penalties, repetition_penalties)
# Apply temperature scaling.
@@ -146,7 +144,10 @@ def _get_penalties(
return presence_penalties, frequency_penalties, repetition_penalties
def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
def _get_prompt_and_output_tokens(
input_metadata: InputMetadata
) -> Tuple[List[List[int]], List[List[int]]]:
prompt_tokens: List[List[int]] = []
output_tokens: List[List[int]] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
@@ -155,11 +156,39 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
# NOTE: prompt token positions do not need output tokens to
# compute penalties.
prompt_len = input_metadata.prompt_lens[i]
prompt_tokens.extend([] for _ in range(prompt_len - 1))
output_tokens.extend([] for _ in range(prompt_len - 1))
for seq_id in seq_ids:
seq_data = input_metadata.seq_data[seq_id]
prompt_tokens.append(seq_data.prompt_token_ids)
output_tokens.append(seq_data.output_token_ids)
return output_tokens
return prompt_tokens, output_tokens
def _get_bin_counts_and_mask(
logits: torch.Tensor,
tokens: List[List[int]],
vocab_size: int,
num_seqs: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
max_len = max(len(tokens) for tokens in tokens)
padded_tokens = [
tokens + [vocab_size] * (max_len - len(tokens)) for tokens in tokens
]
tokens_tensor = torch.tensor(padded_tokens,
dtype=torch.long,
device=logits.device)
# Compute the bin counts for the tokens.
# vocab_size + 1 for padding.
bin_counts = torch.zeros((num_seqs, vocab_size + 1),
dtype=torch.long,
device=logits.device)
bin_counts.scatter_add_(1, tokens_tensor, torch.ones_like(tokens_tensor))
bin_counts = bin_counts[:, :vocab_size]
mask = bin_counts > 0
return bin_counts, mask
def _apply_logits_processors(logits: torch.Tensor,
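The new `_get_bin_counts_and_mask` helper pads the ragged per-sequence token lists with the out-of-range id `vocab_size`, scatter-adds ones into `vocab_size + 1` bins, and then drops the last bin, so the padding never leaks into the real counts. A small standalone repro of the trick (toy sizes, illustrative only):

```python
import torch

vocab_size = 6
token_lists = [[1, 3, 3], [2]]  # ragged per-sequence token ids
max_len = max(len(t) for t in token_lists)
padded = [t + [vocab_size] * (max_len - len(t)) for t in token_lists]
tokens = torch.tensor(padded, dtype=torch.long)  # shape (2, 3)

# One extra bin (index == vocab_size) absorbs the padding ids.
bin_counts = torch.zeros((len(token_lists), vocab_size + 1), dtype=torch.long)
bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
bin_counts = bin_counts[:, :vocab_size]  # drop the padding bin

print(bin_counts)      # tensor([[0, 1, 0, 2, 0, 0], [0, 0, 1, 0, 0, 0]])
print(bin_counts > 0)  # mask of tokens that appeared at least once
```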
@@ -186,15 +215,13 @@ def _apply_logits_processors(logits: torch.Tensor,
def _apply_penalties(
logits: torch.Tensor,
output_tokens: List[List[int]],
input_metadata: InputMetadata,
presence_penalties: List[float],
frequency_penalties: List[float],
repetition_penalties: List[float],
) -> torch.Tensor:
num_seqs, vocab_size = logits.shape
for i in range(num_seqs):
if not output_tokens[i]:
continue
p = presence_penalties[i]
f = frequency_penalties[i]
r = repetition_penalties[i]
@@ -206,24 +233,15 @@ def _apply_penalties(
# Return early if all sequences have zero penalties.
return logits
max_output_len = max(len(tokens) for tokens in output_tokens)
padded_output_tokens = [
tokens + [vocab_size] * (max_output_len - len(tokens))
for tokens in output_tokens
]
output_tokens_tensor = torch.tensor(padded_output_tokens,
dtype=torch.long,
device=logits.device)
prompt_tokens, output_tokens = (
_get_prompt_and_output_tokens(input_metadata))
assert len(prompt_tokens) == logits.shape[0]
assert len(output_tokens) == logits.shape[0]
# Compute the bin counts for the output tokens.
# vocab_size + 1 for padding.
bin_counts = torch.zeros((num_seqs, vocab_size + 1),
dtype=torch.long,
device=logits.device)
bin_counts.scatter_add_(1, output_tokens_tensor,
torch.ones_like(output_tokens_tensor))
bin_counts = bin_counts[:, :vocab_size] # Remove the padding bin.
mask = bin_counts > 0
prompt_bin_counts, prompt_mask = _get_bin_counts_and_mask(
logits, prompt_tokens, vocab_size, num_seqs)
output_bin_counts, output_mask = _get_bin_counts_and_mask(
logits, output_tokens, vocab_size, num_seqs)
repetition_penalties = torch.tensor(repetition_penalties,
dtype=logits.dtype,
@@ -236,14 +254,14 @@ def _apply_penalties(
device=logits.device)
repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
repetition_penalties[~mask] = 1.0
repetition_penalties[~(prompt_mask | output_mask)] = 1.0
logits = torch.where(logits > 0, logits / repetition_penalties,
logits * repetition_penalties)
# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
logits -= presence_penalties.unsqueeze(dim=1) * mask
logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
logits -= presence_penalties.unsqueeze(dim=1) * output_mask
return logits
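This is the core of the fix: the repetition penalty now matches HuggingFace's `RepetitionPenaltyLogitsProcessor` (positive logits are divided by the penalty, negative ones multiplied) and is driven by tokens from the prompt as well as the generation, while the OpenAI-style frequency and presence penalties keep using only the generated tokens. A rough standalone sketch of the same math with scalar penalties (names are illustrative, not the vLLM API):

```python
import torch

def apply_penalties_sketch(
    logits: torch.Tensor,         # (num_seqs, vocab_size)
    prompt_mask: torch.Tensor,    # bool, tokens present in the prompt
    output_counts: torch.Tensor,  # int, per-token counts in the generation
    repetition_penalty: float,
    frequency_penalty: float,
    presence_penalty: float,
) -> torch.Tensor:
    output_mask = output_counts > 0

    # Repetition penalty over prompt *and* output tokens: dividing a
    # positive logit or multiplying a negative one both make it smaller.
    rep = torch.where(prompt_mask | output_mask,
                      torch.full_like(logits, repetition_penalty),
                      torch.ones_like(logits))
    logits = torch.where(logits > 0, logits / rep, logits * rep)

    # Frequency/presence penalties (OpenAI definition) use output tokens only.
    logits = logits - frequency_penalty * output_counts
    logits = logits - presence_penalty * output_mask.to(logits.dtype)
    return logits
```

In the committed code the masks and counts come from `_get_bin_counts_and_mask` over the prompt and output token lists, and the penalties are per-sequence tensors broadcast over the vocabulary rather than scalars.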

vllm/sampling_params.py

@@ -42,9 +42,9 @@ class SamplingParams:
model to use new tokens, while values < 0 encourage the model to
repeat tokens.
repetition_penalty: Float that penalizes new tokens based on whether
they appear in the generated text so far. Values > 1 encourage the
model to use new tokens, while values < 1 encourage the model to
repeat tokens.
they appear in the prompt and the generated text so far. Values > 1
encourage the model to use new tokens, while values < 1 encourage
the model to repeat tokens.
temperature: Float that controls the randomness of the sampling. Lower
values make the model more deterministic, while higher values make
the model more random. Zero means greedy sampling.
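With this change in place, a `repetition_penalty` above 1.0 discourages any token already present in either the prompt or the generation. An illustrative end-to-end usage (model name and values are examples only):

```python
from vllm import LLM, SamplingParams

# repetition_penalty > 1.0 now also penalizes tokens that appear in the
# prompt, matching the HuggingFace behavior this commit aligns with.
sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    repetition_penalty=1.2,
    presence_penalty=0.0,
    frequency_penalty=0.0,
)

llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(["The quick brown fox"], sampling_params)
print(outputs[0].outputs[0].text)
```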