[Bugfix] Fix crash with llama 3.2 vision models and guided decoding (#9631)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Co-authored-by: pavlo-ruban <pavlo.ruban@servicenow.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
This commit is contained in:
Travis Johnson 2024-10-25 16:42:56 -06:00 committed by GitHub
parent 228cfbd03f
commit 6567e13724
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -15,11 +15,11 @@
# limitations under the License.
import copy
import json
import math
from collections import defaultdict
from functools import lru_cache
from typing import Callable, DefaultDict, Dict, List, Union
import numpy as np
import torch
from lark import Lark
from outlines import grammars
@ -77,9 +77,17 @@ class BaseLogitsProcessor:
f"Unsupported instruction type {type(instruction)}")
mask = torch.full((scores.shape[-1], ),
-math.inf,
-torch.inf,
device=scores.device)
mask[allowed_tokens] = 0
# The tokenizer may support more token ids than the model can generate,
# eg. Llama 3.2 Vision models have an `<|image|>` token with id 128256
# but scores.shape == torch.Size([128256])
# Using NumPy is faster for filtering token ids
allowed_tokens = np.array(allowed_tokens, dtype=np.int64)
allowed_tokens = torch.tensor(allowed_tokens, device=scores.device)
allowed_tokens = allowed_tokens.masked_select(
allowed_tokens < scores.shape[-1])
mask.index_fill_(0, allowed_tokens, 0)
scores.add_(mask)
return scores