[XPU]fix Kimi-VL-A3B-thinking on xpu (#29309)
Signed-off-by: Yan Ma <yan.ma@intel.com>
This commit is contained in:
parent 4d6afcaddc
commit 3cfa63ad99
@@ -56,10 +56,13 @@ from transformers.utils import is_flash_attn_2_available
 from vllm.model_executor.layers.conv import Conv2dLayer
 from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.models.utils import maybe_prefix
+from vllm.platforms import current_platform
 from vllm.transformers_utils.configs.moonvit import MoonViTConfig
 
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_varlen_func
+elif current_platform.is_xpu():
+    from vllm.attention.utils.fa_utils import flash_attn_varlen_func
 else:
     flash_attn_varlen_func = None
 
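For context, a minimal standalone sketch of the import fallback this hunk introduces: on CUDA the upstream flash-attn kernel is used, on XPU vLLM's own wrapper stands in, and otherwise the symbol is left as None. The module paths mirror the diff; the has_varlen_flash_attn helper is a hypothetical addition for illustration only, not part of the commit.

# Platform-gated import, as in the hunk above.
from transformers.utils import is_flash_attn_2_available

from vllm.platforms import current_platform

if is_flash_attn_2_available():
    # CUDA path: upstream flash-attn varlen kernel.
    from flash_attn import flash_attn_varlen_func
elif current_platform.is_xpu():
    # XPU path: vLLM's wrapper provides a compatible entry point.
    from vllm.attention.utils.fa_utils import flash_attn_varlen_func
else:
    # No varlen flash attention available on this platform.
    flash_attn_varlen_func = None


def has_varlen_flash_attn() -> bool:
    # Hypothetical helper: callers can branch to an eager attention path
    # when no varlen kernel was imported.
    return flash_attn_varlen_func is not None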
@@ -106,10 +109,10 @@ def multihead_attention(
         q,
         k,
         v,
-        q_cu_seqlens,
-        k_cu_seqlens,
-        max_seqlen_q,
-        max_seqlen_k,
+        cu_seqlens_q=q_cu_seqlens,
+        cu_seqlens_k=k_cu_seqlens,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
         causal=False,
     )
     attn_out = attn_out.flatten(start_dim=-2)
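The switch from positional to keyword arguments above is presumably what keeps this call working against both the upstream flash-attn signature and the XPU wrapper imported in the first hunk: binding by name is independent of parameter order. A toy, runnable illustration with hypothetical backend signatures (not the real flash-attn API):

# Two hypothetical backends that declare the seqlen parameters in a
# different positional order.
def backend_a(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
              causal=False):
    return ("a", cu_seqlens_q, max_seqlen_q)


def backend_b(q, k, v, max_seqlen_q, max_seqlen_k, cu_seqlens_q, cu_seqlens_k,
              causal=False):
    return ("b", cu_seqlens_q, max_seqlen_q)


for fn in (backend_a, backend_b):
    # The same keyword call binds correctly against either ordering; a purely
    # positional call would silently swap the cu_seqlens_* and max_seqlen_*
    # values for one of the two.
    print(fn("q", "k", "v",
             cu_seqlens_q=[0, 4], cu_seqlens_k=[0, 4],
             max_seqlen_q=4, max_seqlen_k=4))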
@@ -291,7 +294,12 @@ class Rope2DPosEmb(nn.Module):
     """
 
     def __init__(
-        self, dim: int, max_height: int, max_width: int, theta_base=10000, device="cuda"
+        self,
+        dim: int,
+        max_height: int,
+        max_width: int,
+        theta_base=10000,
+        device=current_platform.device_type,
     ):
         super().__init__()
         self.dim = dim
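Besides reflowing the signature, the substantive change in this hunk is the device default: instead of hard-coding "cuda", the constructor now uses current_platform.device_type, which resolves to the active backend's device string (for example "cuda" or "xpu"). A minimal sketch of the same idiom; the buffer shape is a placeholder, not the real rotary table:

import torch

from vllm.platforms import current_platform

# device_type is a plain string such as "cuda", "xpu", or "cpu"; whether a
# given string is usable depends on the installed torch backend.
device = torch.device(current_platform.device_type)

# Placeholder buffer allocated on the active backend instead of a hard-coded
# "cuda" device; the shape is arbitrary, chosen only for illustration.
freqs = torch.zeros(16, 8, device=device)
print(freqs.device)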
@@ -437,7 +445,7 @@ class MoonVitEncoderLayer(nn.Module):
         self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
         self.attn_implementation = attn_implementation
         # use fa2 in vllm by default
-        if is_flash_attn_2_available():
+        if is_flash_attn_2_available() or current_platform.is_xpu():
             self.attn_implementation = "flash_attention_2"
 
         self.norm0 = nn.LayerNorm(hidden_dim)
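With the guard above, the encoder layer now selects the flash-attention code path on XPU as well, relying on the wrapper imported in the first hunk rather than on the upstream flash-attn package being installed. A small sketch of that selection logic in isolation; pick_attn_implementation is a hypothetical name used only for this illustration:

from transformers.utils import is_flash_attn_2_available

from vllm.platforms import current_platform


def pick_attn_implementation(requested: str) -> str:
    # Mirror of the guard in the hunk above: prefer the flash-attention path
    # whenever a varlen kernel is available, either via upstream flash-attn
    # (CUDA) or via vLLM's XPU wrapper.
    if is_flash_attn_2_available() or current_platform.is_xpu():
        return "flash_attention_2"
    return requested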