diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 7a6edc0add8bf..619c1c2794daa 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -355,6 +355,7 @@ class MLACommonPrefillMetadata: max_query_len: int chunked_context: ChunkedContextMetadata | None = None query_seq_lens: torch.Tensor | None = None + workspace_buffer: torch.Tensor | None = None q_data_type: torch.dtype | None = None @@ -986,6 +987,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): prefill_metadata.query_seq_lens = ( prefill_query_start_loc[1:] - prefill_query_start_loc[:-1] ) + prefill_metadata.workspace_buffer = self._workspace_buffer decode_metadata = None if num_decodes > 0: @@ -1567,6 +1569,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): from flashinfer.prefill import trtllm_ragged_attention_deepseek assert prefill.query_seq_lens is not None + assert prefill.workspace_buffer is not None if fp8_attention: logger.debug_once("Running TRT-LLM ragged prefill in FP8") @@ -1579,7 +1582,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): query=q, key=k, value=v, - workspace_buffer=self._workspace_buffer, + workspace_buffer=prefill.workspace_buffer, seq_lens=prefill.query_seq_lens, max_q_len=prefill.max_query_len, max_kv_len=prefill.max_query_len, @@ -1615,6 +1618,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): assert prefill.chunked_context is not None assert prefill.chunked_context.seq_lens[chunk_idx] is not None + assert prefill.workspace_buffer is not None out = torch.zeros( q.shape[0], @@ -1623,7 +1627,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): device=q.device, dtype=q.dtype, ) - self._workspace_buffer.fill_(0) + prefill.workspace_buffer.fill_(0) if fp8_attention: logger.debug_once("Running TRT-LLM ragged prefill context chunk in FP8") @@ -1636,7 +1640,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): query=q, key=k, value=v, - workspace_buffer=self._workspace_buffer, + workspace_buffer=prefill.workspace_buffer, seq_lens=prefill.chunked_context.seq_lens[chunk_idx], max_q_len=prefill.max_query_len, max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx],