diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py
index 2517a59718382..a3dec0dbda9f8 100644
--- a/vllm/attention/backends/mla/common.py
+++ b/vllm/attention/backends/mla/common.py
@@ -1091,7 +1091,14 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
                 q,
                 k,
                 maybe_padded_v,
-                **kwargs,
+                None,  # output
+                kwargs["cu_seqlens_q"],
+                kwargs["cu_seqlens_k"],
+                kwargs["max_seqlen_q"],
+                kwargs["max_seqlen_k"],
+                kwargs["causal"],
+                softmax_scale,
+                None,  # bias
             )
         if is_vllm_fa:
             attn_out = self.flash_attn_varlen_func(