[Bugfix][XPU] fix silu_and_mul (#11823)

Signed-off-by: yan ma <yan.ma@intel.com>
Yan Ma 2025-01-09 00:11:50 +08:00 committed by GitHub
parent 2f7024987e
commit 78f4590b60
2 changed files with 4 additions and 5 deletions

vllm/model_executor/layers/activation.py

@@ -64,8 +64,8 @@ class SiluAndMul(CustomOp):
         if current_platform.is_cuda_alike() or current_platform.is_cpu():
             self.op = torch.ops._C.silu_and_mul
         elif current_platform.is_xpu():
-            import intel_extension_for_pytorch as ipex
-            self.op = ipex.llm.functional.silu_and_mul
+            from vllm._ipex_ops import ipex_ops
+            self.op = ipex_ops.silu_and_mul
 
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""

vllm/plugins/__init__.py

@@ -63,8 +63,8 @@ def load_general_plugins():
     from vllm.platforms import current_platform
 
     if current_platform.is_xpu():
-        # see https://github.com/pytorch/pytorch/blob/8cada5cbe5450e17c26fb8b358116785324537b2/torch/_dynamo/config.py#L158 # noqa
-        os.environ['TORCH_COMPILE_DISABLE'] = 'True'
+        # see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158
+        torch._dynamo.config.disable = True
     if current_platform.is_hpu():
         # NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1)
         # does not support torch.compile
@@ -72,7 +72,6 @@ def load_general_plugins():
         # torch.compile support
         is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '1') == '1'
         if is_lazy:
-            # see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158
             torch._dynamo.config.disable = True
             # NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only)
             # requires enabling lazy collectives
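For context on the plugins change: the linked torch/_dynamo/config.py reads TORCH_COMPILE_DISABLE from the environment when that module is first imported, so setting the variable after torch is already loaded generally has no effect, whereas assigning torch._dynamo.config.disable = True takes effect at runtime and makes torch.compile fall back to eager execution. A small sketch of the difference (assumption: torch is imported before plugins run, as in vLLM):

import os
import torch
import torch._dynamo  # ensure the submodule is loaded on older torch 2.x

# Read only when torch._dynamo.config is first imported; usually too late here.
os.environ['TORCH_COMPILE_DISABLE'] = 'True'

# Effective at runtime: this is what the diff above switches to.
torch._dynamo.config.disable = True

@torch.compile
def f(x: torch.Tensor) -> torch.Tensor:
    return x * torch.sigmoid(x)

# With dynamo disabled, the compiled wrapper simply runs eagerly.
print(f(torch.randn(4)))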