[Feature] Batch invariant: Enable TRITON_MLA without prefix-caching (#29125)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 9d6235ca9a
commit d9417096d1
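For context, a minimal way to exercise this change, as a sketch only: the VLLM_BATCH_INVARIANT toggle and the model name below are assumptions for illustration and are not part of this commit. With this change, a batch-invariant run on the TRITON_MLA backend no longer needs to pass enable_prefix_caching=False explicitly; the attention layers disable prefix caching themselves and log a warning.

# Sketch: running batch-invariant generation on the TRITON_MLA backend.
# VLLM_BATCH_INVARIANT and the model name are assumptions for illustration.
import os

# Set environment toggles before importing vllm.
os.environ["VLLM_BATCH_INVARIANT"] = "1"            # assumed batch-invariance toggle
os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_MLA"

from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",            # hypothetical MLA model
    dtype="bfloat16",
    max_model_len=8192,
    max_num_seqs=32,
    # enable_prefix_caching=False is no longer required here; the attention
    # layer disables prefix caching itself and emits a warning.
)
print(llm.generate(["the capital of france is"], SamplingParams(max_tokens=8)))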
@@ -185,7 +185,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
     llm = LLM(
         model=model_name,
         tensor_parallel_size=tp_size,
-        enable_prefix_caching=False,
+        # enable_prefix_caching=False,
         max_num_seqs=32,
         max_model_len=8192,
         dtype="bfloat16",  # not everything is supported
@@ -393,7 +393,6 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
         gpu_memory_utilization=0.9,
         max_model_len=2048,
         dtype="bfloat16",
-        enable_prefix_caching=False,
     )

     prompt = "the capital of france is"
@@ -457,7 +456,6 @@ def test_logprobs_without_batch_invariance_should_fail(
     llm = LLM(
         model=model_name,
         tensor_parallel_size=tp_size,
-        enable_prefix_caching=False,
         max_num_seqs=32,
         max_model_len=8192,
         dtype="bfloat16",
@@ -681,7 +679,6 @@ def test_decode_logprobs_match_prefill_logprobs(
     llm = LLM(
         model=model_name,
         tensor_parallel_size=tp_size,
-        enable_prefix_caching=False,
         max_num_seqs=32,
         max_model_len=8192,
         dtype="bfloat16",
@@ -928,7 +925,6 @@ def LLM_with_max_seqs(
         max_model_len=max_model_len,
         dtype="bfloat16",
         tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
-        enable_prefix_caching=False,
         # Enable for MOE models
         # enable_expert_parallel=True,
     )
@@ -153,7 +153,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
     }

     tp_size = os.getenv("VLLM_TP_SIZE", "1")
-    server_args: list[str] = []
+    server_args: list[str] = [
+        "--max-model-len=8192",
+        "--max-num-seqs=32",
+    ]
     if tp_size:
         server_args += ["-tp", tp_size]

@@ -17,6 +17,7 @@ skip_unsupported = pytest.mark.skipif(

 BACKENDS: list[str] = [
     "FLASH_ATTN",
+    "TRITON_MLA",
 ]

 if has_flashinfer():
@@ -25,6 +25,7 @@ from vllm.config.vllm import VllmConfig
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     UnquantizedLinearMethod,
@@ -251,6 +252,24 @@ class Attention(nn.Module, AttentionLayerBase):
         else:
             self.attn_backend = attn_backend

+        # prefix caching + batch invariance is currently not supported for
+        # FLASHINFER and TRITON_MLA.
+        if (
+            cache_config is not None
+            and cache_config.enable_prefix_caching
+            and vllm_is_batch_invariant()
+            and (
+                self.attn_backend.get_name() == "FLASHINFER"
+                or self.attn_backend.get_name() == "TRITON_MLA"
+            )
+        ):
+            logger.warning_once(
+                "Disabling prefix caching for FLASHINFER/TRITON_MLA "
+                "with batch invariance, as it is not yet supported.",
+                scope="local",
+            )
+            cache_config.enable_prefix_caching = False
+
         impl_cls = self.attn_backend.get_impl_cls()
         self.impl = impl_cls(
             num_heads,
@@ -628,6 +647,23 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             use_mla=True,
             use_sparse=use_sparse,
         )

+        if (
+            cache_config is not None
+            and cache_config.enable_prefix_caching
+            and vllm_is_batch_invariant()
+            and (
+                self.attn_backend.get_name() == "TRITON_MLA"
+                or self.attn_backend.get_name() == "FLASHINFER"
+            )
+        ):
+            logger.warning_once(
+                "Disabling prefix caching for TRITON_MLA / FLASHINFER "
+                "with batch invariance, as it is not yet supported.",
+                scope="local",
+            )
+            cache_config.enable_prefix_caching = False
+
         impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
         self.impl = impl_cls(
             num_heads=self.num_heads,
@@ -1006,11 +1006,11 @@ def override_envs_for_invariance():
         "FLASH_ATTN",  # best supported backend
         "FLASHINFER",
         "FLASH_ATTN_MLA",
+        "TRITON_MLA",
         # Not yet supported MLA backends
         # "FLASHMLA",
         # "FLEX_ATTENTION",  # IMA issue even if we disable batch invariance
         # "FLASHINFER_MLA", https://github.com/vllm-project/vllm/pull/28967
-        # "TRITON_MLA",
     ]
     if curr_attn_backend not in supported_backends:
         error = (