[Bug] Fix 'CutlassMLAImpl' object has no attribute '_workspace_buffer' (#31173)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye 2025-12-22 17:24:27 -05:00 committed by GitHub
parent de71747655
commit 5312a7284e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -355,6 +355,7 @@ class MLACommonPrefillMetadata:
max_query_len: int
chunked_context: ChunkedContextMetadata | None = None
query_seq_lens: torch.Tensor | None = None
workspace_buffer: torch.Tensor | None = None
q_data_type: torch.dtype | None = None
@ -986,6 +987,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
prefill_metadata.query_seq_lens = (
prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
)
prefill_metadata.workspace_buffer = self._workspace_buffer
decode_metadata = None
if num_decodes > 0:
@ -1567,6 +1569,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
from flashinfer.prefill import trtllm_ragged_attention_deepseek
assert prefill.query_seq_lens is not None
assert prefill.workspace_buffer is not None
if fp8_attention:
logger.debug_once("Running TRT-LLM ragged prefill in FP8")
@ -1579,7 +1582,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
query=q,
key=k,
value=v,
workspace_buffer=self._workspace_buffer,
workspace_buffer=prefill.workspace_buffer,
seq_lens=prefill.query_seq_lens,
max_q_len=prefill.max_query_len,
max_kv_len=prefill.max_query_len,
@ -1615,6 +1618,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
assert prefill.chunked_context is not None
assert prefill.chunked_context.seq_lens[chunk_idx] is not None
assert prefill.workspace_buffer is not None
out = torch.zeros(
q.shape[0],
@ -1623,7 +1627,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
device=q.device,
dtype=q.dtype,
)
self._workspace_buffer.fill_(0)
prefill.workspace_buffer.fill_(0)
if fp8_attention:
logger.debug_once("Running TRT-LLM ragged prefill context chunk in FP8")
@ -1636,7 +1640,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
query=q,
key=k,
value=v,
workspace_buffer=self._workspace_buffer,
workspace_buffer=prefill.workspace_buffer,
seq_lens=prefill.chunked_context.seq_lens[chunk_idx],
max_q_len=prefill.max_query_len,
max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx],