[Bugfix] fix logit processor exceed vocab size issue (#6927)

This commit is contained in:
Fei 2024-07-31 01:16:01 -07:00 committed by GitHub
parent 533d1932d2
commit c0644cf9ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -58,6 +58,12 @@ def get_logits_processors(
"Found token_id in logit_bias that is not "
"an integer or string representing an integer") from exc
# Check if token_id is within the vocab size
for token_id, bias in clamped_logit_bias.items():
if token_id < 0 or token_id >= tokenizer.vocab_size:
raise ValueError("token_id in logit_bias contains "
"out-of-vocab token id")
def logit_bias_logits_processor(token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
for token_id, bias in clamped_logit_bias.items():