mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-19 15:37:02 +08:00
[Hardware][Intel-Gaudi] Enable FusedSDPA support for Intel Gaudi (HPU)
This commit is contained in:
parent
4c3aac51e1
commit
af8486de49
@ -10,7 +10,8 @@ from typing import Any, Dict, List, Optional, Tuple, Type
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import vllm_hpu_extension.ops as ops
|
import vllm_hpu_extension.ops as ops
|
||||||
from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
|
from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
|
||||||
|
VLLMKVCache)
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
AttentionLayer,
|
AttentionLayer,
|
||||||
@ -137,9 +138,17 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
|
|||||||
|
|
||||||
self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
|
self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
|
||||||
'0').lower() in ['1', 'true']
|
'0').lower() in ['1', 'true']
|
||||||
|
self.fused_scaled_dot_product_attention = None
|
||||||
if self.prefill_usefusedsdpa:
|
if self.prefill_usefusedsdpa:
|
||||||
assert alibi_slopes is None, \
|
assert alibi_slopes is None, \
|
||||||
'Prefill with FusedSDPA not supported with alibi slopes!'
|
'Prefill with FusedSDPA not supported with alibi slopes!'
|
||||||
|
try:
|
||||||
|
from habana_frameworks.torch.hpex.kernels import FusedSDPA
|
||||||
|
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(
|
||||||
|
FusedSDPA)
|
||||||
|
except ImportError:
|
||||||
|
logger().warning("Could not import HPU FusedSDPA kernel. "
|
||||||
|
"vLLM will use native implementation.")
|
||||||
|
|
||||||
suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes()
|
suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes()
|
||||||
if head_size not in suppored_head_sizes:
|
if head_size not in suppored_head_sizes:
|
||||||
@ -227,6 +236,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
|
|||||||
matmul_qk_op=self.matmul_qk,
|
matmul_qk_op=self.matmul_qk,
|
||||||
softmax_op=self.softmax,
|
softmax_op=self.softmax,
|
||||||
matmul_av_op=self.matmul_av,
|
matmul_av_op=self.matmul_av,
|
||||||
|
fsdpa_op=self.fused_scaled_dot_product_attention,
|
||||||
)
|
)
|
||||||
output = out.reshape(batch_size, seq_len, hidden_size)
|
output = out.reshape(batch_size, seq_len, hidden_size)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user