[Feature] Batch invariant: Enable TRITON_MLA without prefix-caching (#29125)

Signed-off-by: yewentao256 <zhyanwentao@126.com>

Author:     Wentao Ye
Date:       2025-12-08 19:31:57 -05:00
Committer:  GitHub
Parent:     9d6235ca9a
Commit:     d9417096d1
Signature:  GPG Key ID B5690EEEBB952194 (no known key found for this signature in database)

5 changed files with 43 additions and 7 deletions
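What the change does: TRITON_MLA is added to the batch-invariance test matrix and to the supported-backend override list, and the explicit enable_prefix_caching=False arguments are dropped from the tests, because the Attention/MLAAttention constructors now detect the unsupported combination (batch invariance plus prefix caching on FLASHINFER or TRITON_MLA) and disable prefix caching themselves with a one-time warning. A minimal usage sketch follows; the environment variable names are assumptions (VLLM_BATCH_INVARIANT is assumed to be the toggle behind vllm_is_batch_invariant(), VLLM_ATTENTION_BACKEND the backend override), and the model choice is hypothetical.

    import os

    # Assumed toggles: VLLM_BATCH_INVARIANT is assumed to back
    # vllm_is_batch_invariant(); VLLM_ATTENTION_BACKEND selects the backend.
    os.environ["VLLM_BATCH_INVARIANT"] = "1"
    os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_MLA"

    from vllm import LLM, SamplingParams

    # No enable_prefix_caching=False needed: with this commit the attention
    # layer disables prefix caching itself (with a one-time warning) when
    # batch invariance is on for TRITON_MLA or FLASHINFER.
    llm = LLM(
        model="deepseek-ai/DeepSeek-V2-Lite",  # hypothetical MLA model choice
        max_model_len=8192,
        max_num_seqs=32,
        dtype="bfloat16",
    )
    out = llm.generate(["the capital of france is"], SamplingParams(temperature=0.0))
    print(out[0].outputs[0].text)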

View File

@@ -185,7 +185,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
     llm = LLM(
         model=model_name,
         tensor_parallel_size=tp_size,
-        enable_prefix_caching=False,
+        # enable_prefix_caching=False,
         max_num_seqs=32,
         max_model_len=8192,
         dtype="bfloat16",  # not everything is supported
@@ -393,7 +393,6 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
         gpu_memory_utilization=0.9,
         max_model_len=2048,
         dtype="bfloat16",
-        enable_prefix_caching=False,
     )
     prompt = "the capital of france is"
@@ -457,7 +456,6 @@ def test_logprobs_without_batch_invariance_should_fail(
     llm = LLM(
         model=model_name,
         tensor_parallel_size=tp_size,
-        enable_prefix_caching=False,
         max_num_seqs=32,
         max_model_len=8192,
         dtype="bfloat16",
@@ -681,7 +679,6 @@ def test_decode_logprobs_match_prefill_logprobs(
     llm = LLM(
         model=model_name,
         tensor_parallel_size=tp_size,
-        enable_prefix_caching=False,
         max_num_seqs=32,
         max_model_len=8192,
         dtype="bfloat16",
@@ -928,7 +925,6 @@ def LLM_with_max_seqs(
         max_model_len=max_model_len,
         dtype="bfloat16",
         tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
-        enable_prefix_caching=False,
         # Enable for MOE models
         # enable_expert_parallel=True,
     )

View File

@@ -153,7 +153,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
     }
     tp_size = os.getenv("VLLM_TP_SIZE", "1")
-    server_args: list[str] = []
+    server_args: list[str] = [
+        "--max-model-len=8192",
+        "--max-num-seqs=32",
+    ]
     if tp_size:
         server_args += ["-tp", tp_size]
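For reference, a sketch of how the expanded server_args would be consumed in the server-mode variant of the test. RemoteOpenAIServer is the harness vLLM's tests use elsewhere for launching an OpenAI-compatible server, but its use here is an assumption, and model_name is a hypothetical stand-in.

    import os

    from tests.utils import RemoteOpenAIServer  # assumed import path

    model_name = "deepseek-ai/DeepSeek-V2-Lite"  # hypothetical stand-in

    server_args: list[str] = [
        "--max-model-len=8192",
        "--max-num-seqs=32",
    ]
    tp_size = os.getenv("VLLM_TP_SIZE", "1")
    if tp_size:  # always truthy with the "1" default, so -tp is always appended
        server_args += ["-tp", tp_size]

    # Launch the server with the capped context and batch sizes, then query it.
    with RemoteOpenAIServer(model_name, server_args) as server:
        client = server.get_client()
        resp = client.completions.create(
            model=model_name, prompt="the capital of france is", max_tokens=8
        )
        print(resp.choices[0].text)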

View File

@@ -17,6 +17,7 @@ skip_unsupported = pytest.mark.skipif(
 BACKENDS: list[str] = [
     "FLASH_ATTN",
+    "TRITON_MLA",
 ]
 if has_flashinfer():

View File

@@ -25,6 +25,7 @@ from vllm.config.vllm import VllmConfig
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     UnquantizedLinearMethod,
@@ -251,6 +252,24 @@ class Attention(nn.Module, AttentionLayerBase):
         else:
             self.attn_backend = attn_backend

+        # prefix caching + batch invariance is currently not supported for
+        # FLASHINFER and TRITON_MLA.
+        if (
+            cache_config is not None
+            and cache_config.enable_prefix_caching
+            and vllm_is_batch_invariant()
+            and (
+                self.attn_backend.get_name() == "FLASHINFER"
+                or self.attn_backend.get_name() == "TRITON_MLA"
+            )
+        ):
+            logger.warning_once(
+                "Disabling prefix caching for FLASHINFER/TRITON_MLA "
+                "with batch invariance, as it is not yet supported.",
+                scope="local",
+            )
+            cache_config.enable_prefix_caching = False
+
         impl_cls = self.attn_backend.get_impl_cls()
         self.impl = impl_cls(
             num_heads,
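The same guard appears twice in this file, here and in MLAAttention below. As a reading aid only, here is the logic factored into a standalone helper; the helper name and module-level constant are hypothetical and not part of the diff:

    from vllm.logger import init_logger
    from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant

    logger = init_logger(__name__)

    # Backends that cannot yet combine prefix caching with batch invariance.
    _NO_PREFIX_CACHE_WITH_INVARIANCE = ("FLASHINFER", "TRITON_MLA")


    def _maybe_disable_prefix_caching(cache_config, backend_name: str) -> None:
        """Turn off prefix caching (warning once) when the selected backend
        cannot support it under batch invariance."""
        if (
            cache_config is not None
            and cache_config.enable_prefix_caching
            and vllm_is_batch_invariant()
            and backend_name in _NO_PREFIX_CACHE_WITH_INVARIANCE
        ):
            logger.warning_once(
                "Disabling prefix caching for FLASHINFER/TRITON_MLA "
                "with batch invariance, as it is not yet supported.",
                scope="local",
            )
            cache_config.enable_prefix_caching = False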
@@ -628,6 +647,23 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             use_mla=True,
             use_sparse=use_sparse,
         )
+
+        if (
+            cache_config is not None
+            and cache_config.enable_prefix_caching
+            and vllm_is_batch_invariant()
+            and (
+                self.attn_backend.get_name() == "TRITON_MLA"
+                or self.attn_backend.get_name() == "FLASHINFER"
+            )
+        ):
+            logger.warning_once(
+                "Disabling prefix caching for TRITON_MLA / FLASHINFER "
+                "with batch invariance, as it is not yet supported.",
+                scope="local",
+            )
+            cache_config.enable_prefix_caching = False
+
         impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
         self.impl = impl_cls(
             num_heads=self.num_heads,
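Design note: this MLAAttention guard is a near-verbatim copy of the Attention guard above, differing only in the order of the backend-name checks and the warning wording; both collapse to the helper sketched after the previous hunk, which a follow-up could share.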

View File

@@ -1006,11 +1006,11 @@ def override_envs_for_invariance():
         "FLASH_ATTN",  # best supported backend
         "FLASHINFER",
         "FLASH_ATTN_MLA",
+        "TRITON_MLA",
         # Not yet supported MLA backends
         # "FLASHMLA",
         # "FLEX_ATTENTION",  # IMA issue even if we disable batch invariance
         # "FLASHINFER_MLA", https://github.com/vllm-project/vllm/pull/28967
-        # "TRITON_MLA",
     ]
     if curr_attn_backend not in supported_backends:
         error = (