mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-12 06:17:13 +08:00
fix full cudagraphs for cutlass mla
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
090f485aa1
commit
143b09e6be
@ -14,18 +14,15 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
|||||||
MLACommonImpl,
|
MLACommonImpl,
|
||||||
MLACommonMetadata,
|
MLACommonMetadata,
|
||||||
MLACommonMetadataBuilder)
|
MLACommonMetadataBuilder)
|
||||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
|
class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
|
||||||
# enable full CUDA Graph support for decode-only capture
|
# enable full CUDA Graph support for decode-only capture
|
||||||
full_cudagraph_supported: ClassVar[bool] = True # Decode-only
|
attn_cudagraph_support: ClassVar[
|
||||||
|
AttentionCGSupport] = AttentionCGSupport.PURE_DECODE_ONLY
|
||||||
def can_run_in_cudagraph(
|
|
||||||
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
|
|
||||||
return common_attn_metadata.max_query_len == 1
|
|
||||||
|
|
||||||
|
|
||||||
class CutlassMLABackend(MLACommonBackend):
|
class CutlassMLABackend(MLACommonBackend):
|
||||||
|
|||||||
@ -2031,7 +2031,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
|
|
||||||
# Run the model.
|
# Run the model.
|
||||||
# Use persistent buffers for CUDA graphs.
|
# Use persistent buffers for CUDA graphs.
|
||||||
<<<<<<< HEAD
|
|
||||||
# when DBO is enabled, `num_tokens_after_padding`
|
# when DBO is enabled, `num_tokens_after_padding`
|
||||||
# represents the per-ubatch DP token count.
|
# represents the per-ubatch DP token count.
|
||||||
dp_tokens_for_forward = num_tokens_after_padding
|
dp_tokens_for_forward = num_tokens_after_padding
|
||||||
@ -2045,9 +2044,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
num_tokens_across_dp=dp_tokens_for_forward,
|
num_tokens_across_dp=dp_tokens_for_forward,
|
||||||
skip_cuda_graphs=skip_cuda_graphs):
|
skip_cuda_graphs=skip_cuda_graphs):
|
||||||
self.maybe_setup_kv_connector(scheduler_output)
|
self.maybe_setup_kv_connector(scheduler_output)
|
||||||
=======
|
|
||||||
self.maybe_setup_kv_connector(scheduler_output)
|
|
||||||
>>>>>>> db77e4a92 (revert kv connector fix)
|
|
||||||
model_output = self._run_model(
|
model_output = self._run_model(
|
||||||
attn_metadata=attn_metadata,
|
attn_metadata=attn_metadata,
|
||||||
num_scheduled_tokens=num_input_tokens,
|
num_scheduled_tokens=num_input_tokens,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user