fix full cudagraphs for cutlass mla

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Sage Moore 2025-08-13 15:00:40 -04:00
parent 090f485aa1
commit 143b09e6be
2 changed files with 3 additions and 10 deletions


@@ -14,18 +14,15 @@
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonImpl,
                                                    MLACommonMetadata,
                                                    MLACommonMetadataBuilder)
-from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 
 logger = init_logger(__name__)
 
 
 class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
-    # enable full CUDA Graph support for decode-only capture
-    full_cudagraph_supported: ClassVar[bool] = True  # Decode-only
-    def can_run_in_cudagraph(
-            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
-        return common_attn_metadata.max_query_len == 1
+    attn_cudagraph_support: ClassVar[
+        AttentionCGSupport] = AttentionCGSupport.PURE_DECODE_ONLY
 
 
 class CutlassMLABackend(MLACommonBackend):
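
The change collapses the old boolean flag plus per-builder `can_run_in_cudagraph()` hook into a single `AttentionCGSupport` enum value the runner can inspect directly. Below is a minimal, self-contained sketch of that gating, assuming simplified stand-ins for the enum and metadata types (the real definitions live in `vllm.v1.attention.backends.utils`); `can_use_cudagraph` is a hypothetical helper, not vLLM's actual runner code.

```python
import enum
from dataclasses import dataclass


class AttentionCGSupport(enum.Enum):
    """How much CUDA graph capture an attention backend supports (sketch)."""
    NEVER = 0             # never capture with this backend
    PURE_DECODE_ONLY = 1  # capture only decode-only batches
    ALWAYS = 2            # capture any batch shape


@dataclass
class CommonAttentionMetadata:
    max_query_len: int  # longest query in the batch; 1 means pure decode


def can_use_cudagraph(support: AttentionCGSupport,
                      meta: CommonAttentionMetadata) -> bool:
    """Centralized check replacing the per-backend can_run_in_cudagraph()."""
    if support is AttentionCGSupport.ALWAYS:
        return True
    if support is AttentionCGSupport.PURE_DECODE_ONLY:
        # Same condition the removed method expressed: every request in
        # the batch is decoding exactly one token.
        return meta.max_query_len == 1
    return False


if __name__ == "__main__":
    support = AttentionCGSupport.PURE_DECODE_ONLY  # CutlassMLAMetadataBuilder's setting
    print(can_use_cudagraph(support, CommonAttentionMetadata(max_query_len=1)))  # True
    print(can_use_cudagraph(support, CommonAttentionMetadata(max_query_len=8)))  # False
```

Declaring the capability as data rather than a method lets the runner decide about graph capture without instantiating or special-casing each backend.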


@@ -2031,7 +2031,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         # Run the model.
         # Use persistent buffers for CUDA graphs.
-<<<<<<< HEAD
         # when DBO is enabled, `num_tokens_after_padding`
         # represents the per-ubatch DP token count.
         dp_tokens_for_forward = num_tokens_after_padding
@@ -2045,9 +2044,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 num_tokens_across_dp=dp_tokens_for_forward,
                 skip_cuda_graphs=skip_cuda_graphs):
             self.maybe_setup_kv_connector(scheduler_output)
-=======
-            self.maybe_setup_kv_connector(scheduler_output)
->>>>>>> db77e4a92 (revert kv connector fix)
             model_output = self._run_model(
                 attn_metadata=attn_metadata,
                 num_scheduled_tokens=num_input_tokens,
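
For context on the side that survives the cleanup: under dual-batch overlap (DBO), `num_tokens_after_padding` already holds the per-ubatch DP token count, so it is passed straight through as `num_tokens_across_dp`. A hypothetical sketch of that bookkeeping follows; `set_forward_context` and `run_forward_step` here are stand-ins for illustration, not vLLM's real signatures.

```python
# Hypothetical sketch of the DP token plumbing kept by this commit; the
# context manager below is a stand-in, not vLLM's actual forward context.
from contextlib import contextmanager
from typing import Iterator, Optional

import torch


@contextmanager
def set_forward_context(num_tokens_across_dp: Optional[torch.Tensor],
                        skip_cuda_graphs: bool) -> Iterator[None]:
    # Real code would record these in per-step forward state so attention
    # layers and the CUDA graph dispatcher can read them during the pass.
    yield


def run_forward_step(num_tokens_after_padding: torch.Tensor,
                     skip_cuda_graphs: bool) -> None:
    # With DBO enabled, num_tokens_after_padding is already the per-ubatch
    # DP token count, so it is forwarded unchanged.
    dp_tokens_for_forward = num_tokens_after_padding
    with set_forward_context(num_tokens_across_dp=dp_tokens_for_forward,
                             skip_cuda_graphs=skip_cuda_graphs):
        # The KV connector is set up exactly once inside the context; had
        # the conflict markers been left in place, both branches' identical
        # calls would have run (after a SyntaxError was fixed) twice.
        pass  # model_output = self._run_model(...)
```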