[Bugfix] using len(tokenizer) instead of tokenizer.vocab_size in AllowedTokenIdsLogitsProcessor (#11156)

zhangjf 2024-12-13 23:56:19 +08:00 committed by GitHub
parent c31d4a57a6
commit 5b0ed8391d

@@ -71,7 +71,7 @@ def get_logits_processors(
         # Check if token_id is within the vocab size
         for token_id, bias in clamped_logit_bias.items():
-            if token_id < 0 or token_id >= tokenizer.vocab_size:
+            if token_id < 0 or token_id >= len(tokenizer):
                 raise ValueError(f"token_id {token_id} in logit_bias contains "
                                  "out-of-vocab token id")
@@ -81,6 +81,6 @@ def get_logits_processors(
     if allowed_token_ids is not None:
         logits_processors.append(
             _get_allowed_token_ids_logits_processor(
-                frozenset(allowed_token_ids), tokenizer.vocab_size))
+                frozenset(allowed_token_ids), len(tokenizer)))
     return logits_processors
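
In the second hunk, the size passed to `_get_allowed_token_ids_logits_processor` bounds which token ids are accepted and how wide the logits mask is. A hypothetical mask-based processor (illustrative only, not vLLM's actual implementation) shows why that size must be the full vocabulary length:

```python
# Hypothetical sketch of a mask-based allowed-token-ids processor; names and
# structure are illustrative, not vLLM's actual implementation.
import torch


class AllowedTokenIdsMask:

    def __init__(self, allowed_token_ids: frozenset, vocab_len: int):
        # With vocab_len = tokenizer.vocab_size, an allowed id that refers to
        # an added token (id >= vocab_size) would be rejected here even though
        # the model can emit it; len(tokenizer) covers the full range.
        if not all(0 <= tid < vocab_len for tid in allowed_token_ids):
            raise ValueError("allowed_token_ids contains out-of-vocab token id")
        self.mask = torch.full((vocab_len, ), float("-inf"))
        self.mask[list(allowed_token_ids)] = 0.0

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        # Disallowed tokens get -inf added to their logits.
        return logits + self.mask
```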