From d9417096d1347ead26b7c0cc1afb22fc0028e6e8 Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Mon, 8 Dec 2025 19:31:57 -0500 Subject: [PATCH] [Feature] Batch invariant: Enable `TRITON_MLA` without prefix-caching (#29125) Signed-off-by: yewentao256 --- tests/v1/determinism/test_batch_invariance.py | 6 +--- .../test_online_batch_invariance.py | 5 ++- tests/v1/determinism/utils.py | 1 + vllm/attention/layer.py | 36 +++++++++++++++++++ vllm/model_executor/layers/batch_invariant.py | 2 +- 5 files changed, 43 insertions(+), 7 deletions(-) diff --git a/tests/v1/determinism/test_batch_invariance.py b/tests/v1/determinism/test_batch_invariance.py index 4311547baccf1..fc953a66f0820 100644 --- a/tests/v1/determinism/test_batch_invariance.py +++ b/tests/v1/determinism/test_batch_invariance.py @@ -185,7 +185,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( llm = LLM( model=model_name, tensor_parallel_size=tp_size, - enable_prefix_caching=False, + # enable_prefix_caching=False, max_num_seqs=32, max_model_len=8192, dtype="bfloat16", # not everything is supported @@ -393,7 +393,6 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch): gpu_memory_utilization=0.9, max_model_len=2048, dtype="bfloat16", - enable_prefix_caching=False, ) prompt = "the capital of france is" @@ -457,7 +456,6 @@ def test_logprobs_without_batch_invariance_should_fail( llm = LLM( model=model_name, tensor_parallel_size=tp_size, - enable_prefix_caching=False, max_num_seqs=32, max_model_len=8192, dtype="bfloat16", @@ -681,7 +679,6 @@ def test_decode_logprobs_match_prefill_logprobs( llm = LLM( model=model_name, tensor_parallel_size=tp_size, - enable_prefix_caching=False, max_num_seqs=32, max_model_len=8192, dtype="bfloat16", @@ -928,7 +925,6 @@ def LLM_with_max_seqs( max_model_len=max_model_len, dtype="bfloat16", tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")), - enable_prefix_caching=False, # Enable for MOE models # enable_expert_parallel=True, ) diff --git a/tests/v1/determinism/test_online_batch_invariance.py b/tests/v1/determinism/test_online_batch_invariance.py index d74b435797f8f..5e3b997364949 100644 --- a/tests/v1/determinism/test_online_batch_invariance.py +++ b/tests/v1/determinism/test_online_batch_invariance.py @@ -153,7 +153,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( } tp_size = os.getenv("VLLM_TP_SIZE", "1") - server_args: list[str] = [] + server_args: list[str] = [ + "--max-model-len=8192", + "--max-num-seqs=32", + ] if tp_size: server_args += ["-tp", tp_size] diff --git a/tests/v1/determinism/utils.py b/tests/v1/determinism/utils.py index 0d7da107728b4..6aab50cf84ab6 100644 --- a/tests/v1/determinism/utils.py +++ b/tests/v1/determinism/utils.py @@ -17,6 +17,7 @@ skip_unsupported = pytest.mark.skipif( BACKENDS: list[str] = [ "FLASH_ATTN", + "TRITON_MLA", ] if has_flashinfer(): diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 340b161ea1e15..7e5adfe0742d3 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -25,6 +25,7 @@ from vllm.config.vllm import VllmConfig from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant from vllm.model_executor.layers.linear import ( ColumnParallelLinear, UnquantizedLinearMethod, @@ -251,6 +252,24 @@ class Attention(nn.Module, AttentionLayerBase): else: self.attn_backend = attn_backend + # prefix caching + batch invariance is currently not supported for + # FLASHINFER and TRITON_MLA. + if ( + cache_config is not None + and cache_config.enable_prefix_caching + and vllm_is_batch_invariant() + and ( + self.attn_backend.get_name() == "FLASHINFER" + or self.attn_backend.get_name() == "TRITON_MLA" + ) + ): + logger.warning_once( + "Disabling prefix caching for FLASHINFER/TRITON_MLA " + "with batch invariance, as it is not yet supported.", + scope="local", + ) + cache_config.enable_prefix_caching = False + impl_cls = self.attn_backend.get_impl_cls() self.impl = impl_cls( num_heads, @@ -628,6 +647,23 @@ class MLAAttention(nn.Module, AttentionLayerBase): use_mla=True, use_sparse=use_sparse, ) + + if ( + cache_config is not None + and cache_config.enable_prefix_caching + and vllm_is_batch_invariant() + and ( + self.attn_backend.get_name() == "TRITON_MLA" + or self.attn_backend.get_name() == "FLASHINFER" + ) + ): + logger.warning_once( + "Disabling prefix caching for TRITON_MLA / FLASHINFER " + "with batch invariance, as it is not yet supported.", + scope="local", + ) + cache_config.enable_prefix_caching = False + impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls()) self.impl = impl_cls( num_heads=self.num_heads, diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py index 4154122636dcf..4cab47f4192a2 100644 --- a/vllm/model_executor/layers/batch_invariant.py +++ b/vllm/model_executor/layers/batch_invariant.py @@ -1006,11 +1006,11 @@ def override_envs_for_invariance(): "FLASH_ATTN", # best supported backend "FLASHINFER", "FLASH_ATTN_MLA", + "TRITON_MLA", # Not yet supported MLA backends # "FLASHMLA", # "FLEX_ATTENTION", # IMA issue even if we disable batch invariance # "FLASHINFER_MLA", https://github.com/vllm-project/vllm/pull/28967 - # "TRITON_MLA", ] if curr_attn_backend not in supported_backends: error = (