mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 23:35:52 +08:00
[Hardware][Gaudi][Bugfix] Fix error for guided decoding (#12317)
This commit is contained in:
parent
7734e9a291
commit
c9e2d644e7
@ -32,6 +32,8 @@ from outlines_core.fsm.json_schema import build_regex_from_schema
|
||||
from pydantic import BaseModel
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class BaseLogitsProcessor:
|
||||
|
||||
@ -91,7 +93,14 @@ class BaseLogitsProcessor:
|
||||
allowed_tokens = allowed_tokens.masked_select(
|
||||
allowed_tokens < scores.shape[-1])
|
||||
mask.index_fill_(0, allowed_tokens, 0)
|
||||
scores.add_(mask)
|
||||
if current_platform.is_hpu():
|
||||
# Workaround for HPU bug where add_() raise RuntimeError:
|
||||
# synNodeCreateWithId failed for node: strided_insert
|
||||
# with synStatus 1 [Invalid argument], hopefully it will
|
||||
# be fixed in the future releases of the HPU runtime.
|
||||
scores = scores.add(mask)
|
||||
else:
|
||||
scores.add_(mask)
|
||||
return scores
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user