mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-17 04:34:28 +08:00
[Bug] Fix 'CutlassMLAImpl' object has no attribute '_workspace_buffer' (#31173)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
de71747655
commit
5312a7284e
@@ -355,6 +355,7 @@ class MLACommonPrefillMetadata:
|
||||
max_query_len: int
|
||||
chunked_context: ChunkedContextMetadata | None = None
|
||||
query_seq_lens: torch.Tensor | None = None
|
||||
workspace_buffer: torch.Tensor | None = None
|
||||
q_data_type: torch.dtype | None = None
|
||||
|
||||
|
||||
@@ -986,6 +987,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
prefill_metadata.query_seq_lens = (
|
||||
prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
|
||||
)
|
||||
prefill_metadata.workspace_buffer = self._workspace_buffer
|
||||
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
@@ -1567,6 +1569,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
from flashinfer.prefill import trtllm_ragged_attention_deepseek
|
||||
|
||||
assert prefill.query_seq_lens is not None
|
||||
assert prefill.workspace_buffer is not None
|
||||
|
||||
if fp8_attention:
|
||||
logger.debug_once("Running TRT-LLM ragged prefill in FP8")
|
||||
@@ -1579,7 +1582,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
workspace_buffer=self._workspace_buffer,
|
||||
workspace_buffer=prefill.workspace_buffer,
|
||||
seq_lens=prefill.query_seq_lens,
|
||||
max_q_len=prefill.max_query_len,
|
||||
max_kv_len=prefill.max_query_len,
|
||||
@@ -1615,6 +1618,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
|
||||
assert prefill.chunked_context is not None
|
||||
assert prefill.chunked_context.seq_lens[chunk_idx] is not None
|
||||
assert prefill.workspace_buffer is not None
|
||||
|
||||
out = torch.zeros(
|
||||
q.shape[0],
|
||||
@@ -1623,7 +1627,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
device=q.device,
|
||||
dtype=q.dtype,
|
||||
)
|
||||
self._workspace_buffer.fill_(0)
|
||||
prefill.workspace_buffer.fill_(0)
|
||||
|
||||
if fp8_attention:
|
||||
logger.debug_once("Running TRT-LLM ragged prefill context chunk in FP8")
|
||||
@@ -1636,7 +1640,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
workspace_buffer=self._workspace_buffer,
|
||||
workspace_buffer=prefill.workspace_buffer,
|
||||
seq_lens=prefill.chunked_context.seq_lens[chunk_idx],
|
||||
max_q_len=prefill.max_query_len,
|
||||
max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx],
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user