[Hardware][Gaudi][Bugfix] Fix error for guided decoding (#12317)

This commit is contained in:
Yu-Zhou 2025-02-14 20:36:49 +08:00 committed by GitHub
parent 7734e9a291
commit c9e2d644e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -32,6 +32,8 @@ from outlines_core.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase
from vllm.platforms import current_platform
class BaseLogitsProcessor:
@ -91,7 +93,14 @@ class BaseLogitsProcessor:
allowed_tokens = allowed_tokens.masked_select(
allowed_tokens < scores.shape[-1])
mask.index_fill_(0, allowed_tokens, 0)
scores.add_(mask)
if current_platform.is_hpu():
# Workaround for an HPU bug where in-place add_() raises RuntimeError:
# synNodeCreateWithId failed for node: strided_insert
# with synStatus 1 [Invalid argument]. Hopefully it will
# be fixed in a future release of the HPU runtime.
scores = scores.add(mask)
else:
scores.add_(mask)
return scores