From 3cfa63ad991665b2440155cd29352342024072fd Mon Sep 17 00:00:00 2001
From: Yan Ma
Date: Tue, 25 Nov 2025 05:02:21 +0800
Subject: [PATCH] [XPU]fix Kimi-VL-A3B-thinking on xpu (#29309)

Signed-off-by: Yan Ma
---
 vllm/model_executor/models/moonvit.py | 20 ++++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/models/moonvit.py b/vllm/model_executor/models/moonvit.py
index 2e3e6dc166ad..63ea6b259a71 100644
--- a/vllm/model_executor/models/moonvit.py
+++ b/vllm/model_executor/models/moonvit.py
@@ -56,10 +56,13 @@ from transformers.utils import is_flash_attn_2_available
 from vllm.model_executor.layers.conv import Conv2dLayer
 from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.models.utils import maybe_prefix
+from vllm.platforms import current_platform
 from vllm.transformers_utils.configs.moonvit import MoonViTConfig
 
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_varlen_func
+elif current_platform.is_xpu():
+    from vllm.attention.utils.fa_utils import flash_attn_varlen_func
 else:
     flash_attn_varlen_func = None
 
@@ -106,10 +109,10 @@ def multihead_attention(
         q,
         k,
         v,
-        q_cu_seqlens,
-        k_cu_seqlens,
-        max_seqlen_q,
-        max_seqlen_k,
+        cu_seqlens_q=q_cu_seqlens,
+        cu_seqlens_k=k_cu_seqlens,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
         causal=False,
     )
     attn_out = attn_out.flatten(start_dim=-2)
@@ -291,7 +294,12 @@ class Rope2DPosEmb(nn.Module):
     """
 
     def __init__(
-        self, dim: int, max_height: int, max_width: int, theta_base=10000, device="cuda"
+        self,
+        dim: int,
+        max_height: int,
+        max_width: int,
+        theta_base=10000,
+        device=current_platform.device_type,
     ):
         super().__init__()
         self.dim = dim
@@ -437,7 +445,7 @@ class MoonVitEncoderLayer(nn.Module):
         self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
         self.attn_implementation = attn_implementation
         # use fa2 in vllm by default
-        if is_flash_attn_2_available():
+        if is_flash_attn_2_available() or current_platform.is_xpu():
            self.attn_implementation = "flash_attention_2"
 
         self.norm0 = nn.LayerNorm(hidden_dim)
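
Note on the change: the XPU build ships no flash-attn wheel, so moonvit.py
now falls back to vLLM's own flash_attn_varlen_func wrapper from
vllm.attention.utils.fa_utils, and the Rope2DPosEmb default device switches
from the hard-coded "cuda" to current_platform.device_type so buffers land
on the active accelerator. Because the same function name can now be bound
to either provider, the call site also passes the varlen metadata by
keyword instead of by position. The snippet below is a minimal,
self-contained sketch of why keyword binding is the safer choice here; the
two signatures are hypothetical stand-ins, not the real flash-attn or vLLM
APIs:

    # Two hypothetical providers of the same function name, with the
    # varlen parameters in different positional orders.
    def varlen_func_a(q, k, v, cu_seqlens_q, cu_seqlens_k,
                      max_seqlen_q, max_seqlen_k, causal=False):
        return cu_seqlens_q, max_seqlen_q

    def varlen_func_b(q, k, v, max_seqlen_q, max_seqlen_k,
                      cu_seqlens_q=None, cu_seqlens_k=None, causal=False):
        return cu_seqlens_q, max_seqlen_q

    for fn in (varlen_func_a, varlen_func_b):
        # Keyword arguments bind correctly against either signature;
        # positional ones would silently feed cu_seqlens into the
        # max_seqlen slots of the second provider.
        cu_q, max_q = fn(
            "q", "k", "v",
            cu_seqlens_q="cu_q", cu_seqlens_k="cu_k",
            max_seqlen_q=8, max_seqlen_k=8,
            causal=False,
        )
        assert (cu_q, max_q) == ("cu_q", 8)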