mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 21:55:32 +08:00
[Hardware][Gaudi][Bugfix] Fix error for guided decoding (#12317)
This commit is contained in:
parent
7734e9a291
commit
c9e2d644e7
@ -32,6 +32,8 @@ from outlines_core.fsm.json_schema import build_regex_from_schema
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
|
||||||
class BaseLogitsProcessor:
|
class BaseLogitsProcessor:
|
||||||
|
|
||||||
@ -91,6 +93,13 @@ class BaseLogitsProcessor:
|
|||||||
allowed_tokens = allowed_tokens.masked_select(
|
allowed_tokens = allowed_tokens.masked_select(
|
||||||
allowed_tokens < scores.shape[-1])
|
allowed_tokens < scores.shape[-1])
|
||||||
mask.index_fill_(0, allowed_tokens, 0)
|
mask.index_fill_(0, allowed_tokens, 0)
|
||||||
|
if current_platform.is_hpu():
|
||||||
|
# Workaround for HPU bug where add_() raises RuntimeError:
|
||||||
|
# synNodeCreateWithId failed for node: strided_insert
|
||||||
|
# with synStatus 1 [Invalid argument], hopefully it will
|
||||||
|
# be fixed in future releases of the HPU runtime.
|
||||||
|
scores = scores.add(mask)
|
||||||
|
else:
|
||||||
scores.add_(mask)
|
scores.add_(mask)
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user