[Bugfix] Fix crash with llama 3.2 vision models and guided decoding (#9631)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Co-authored-by: pavlo-ruban <pavlo.ruban@servicenow.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Author: Travis Johnson, 2024-10-25 16:42:56 -06:00 (committed by GitHub)
Parent: 228cfbd03f
Commit: 6567e13724

vllm/model_executor/guided_decoding/outlines_logits_processors.py

@@ -15,11 +15,11 @@
 # limitations under the License.
 import copy
 import json
-import math
 from collections import defaultdict
 from functools import lru_cache
 from typing import Callable, DefaultDict, Dict, List, Union
 
+import numpy as np
 import torch
 from lark import Lark
 from outlines import grammars
@@ -77,9 +77,17 @@ class BaseLogitsProcessor:
                 f"Unsupported instruction type {type(instruction)}")
 
         mask = torch.full((scores.shape[-1], ),
-                          -math.inf,
+                          -torch.inf,
                           device=scores.device)
-        mask[allowed_tokens] = 0
+        # The tokenizer may support more token ids than the model can generate,
+        # eg. Llama 3.2 Vision models have an `<|image|>` token with id 128256
+        # but scores.shape == torch.Size([128256])
+        # Using NumPy is faster for filtering token ids
+        allowed_tokens = np.array(allowed_tokens, dtype=np.int64)
+        allowed_tokens = torch.tensor(allowed_tokens, device=scores.device)
+        allowed_tokens = allowed_tokens.masked_select(
+            allowed_tokens < scores.shape[-1])
+        mask.index_fill_(0, allowed_tokens, 0)
         scores.add_(mask)
         return scores
 
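For context, here is a minimal standalone sketch (not vLLM code; the sizes and token ids are illustrative values taken from the diff's comment) of why the old `mask[allowed_tokens] = 0` indexing crashed and what the filtering change does instead:

import numpy as np
import torch

# Illustrative scenario: a Llama 3.2 Vision tokenizer defines `<|image|>`
# with id 128256, but the model only produces 128256 logits, so the valid
# index range for `scores` is 0..128255.
vocab_size = 128256
scores = torch.zeros(vocab_size)
allowed_tokens = [42, 1000, 128256]  # last id is out of range for `scores`

mask = torch.full((scores.shape[-1], ), -torch.inf)
# Old behavior: `mask[allowed_tokens] = 0` raises an IndexError on CPU (and
# a device-side assert on CUDA) because 128256 >= scores.shape[-1].

# Fixed behavior: drop out-of-range ids first, then zero the allowed slots.
allowed = torch.tensor(np.array(allowed_tokens, dtype=np.int64))
allowed = allowed.masked_select(allowed < scores.shape[-1])
mask.index_fill_(0, allowed, 0)
scores.add_(mask)  # disallowed positions become -inf; allowed ones keep their score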