From d52c5096d7305abc7f266026cb042121ff5bccda Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Sat, 20 Dec 2025 09:03:35 +0800 Subject: [PATCH] [Bugfix] fix the alias bug of AttentionBackendEnum when register CUSTOM attention backend to vllm (#30869) Signed-off-by: zejunchen-zejun --- tests/test_attention_backend_registry.py | 169 +++++++++++++++++++++++ vllm/attention/backends/registry.py | 6 +- 2 files changed, 173 insertions(+), 2 deletions(-) create mode 100644 tests/test_attention_backend_registry.py diff --git a/tests/test_attention_backend_registry.py b/tests/test_attention_backend_registry.py new file mode 100644 index 0000000000000..7b90b949aa457 --- /dev/null +++ b/tests/test_attention_backend_registry.py @@ -0,0 +1,169 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, +) +from vllm.attention.backends.registry import ( + AttentionBackendEnum, + MambaAttentionBackendEnum, + register_backend, +) + + +class CustomAttentionImpl(AttentionImpl): + """Mock custom attention implementation for testing.""" + + def __init__(self, *args, **kwargs): + super().__init__() + + def forward(self, *args, **kwargs): + """Mock forward pass.""" + pass + + +class CustomAttentionBackend(AttentionBackend): + """Mock custom attention backend for testing.""" + + @staticmethod + def get_name(): + return "CUSTOM" + + @staticmethod + def get_impl_cls(): + return CustomAttentionImpl + + @staticmethod + def get_builder_cls(): + """Mock builder class.""" + return None + + @staticmethod + def get_required_kv_cache_layout(): + """Mock KV cache layout.""" + return None + + +class CustomMambaAttentionImpl(AttentionImpl): + """Mock custom mamba attention implementation for testing.""" + + def __init__(self, *args, **kwargs): + super().__init__() + + def forward(self, *args, **kwargs): + """Mock forward pass.""" + pass + + +class CustomMambaAttentionBackend(AttentionBackend): + """Mock custom mamba attention backend for testing.""" + + @staticmethod + def get_name(): + return "CUSTOM_MAMBA" + + @staticmethod + def get_impl_cls(): + return CustomMambaAttentionImpl + + @staticmethod + def get_builder_cls(): + """Mock builder class.""" + return None + + @staticmethod + def get_required_kv_cache_layout(): + """Mock KV cache layout.""" + return None + + +def test_custom_is_not_alias_of_any_backend(): + # Get all members of AttentionBackendEnum + all_backends = list(AttentionBackendEnum) + + # Find any aliases of CUSTOM + aliases = [] + for backend in all_backends: + if backend.name != "CUSTOM" and backend is AttentionBackendEnum.CUSTOM: + aliases.append(backend.name) + + # CUSTOM should not be an alias of any other backend + assert len(aliases) == 0, ( + f"BUG! CUSTOM is an alias of: {', '.join(aliases)}!\n" + f"CUSTOM.value = {repr(AttentionBackendEnum.CUSTOM.value)}\n" + f"This happens when CUSTOM has the same value as another backend.\n" + f"When you register to CUSTOM, you're actually registering to {aliases[0]}!\n" + f"All backend values:\n" + + "\n".join(f" {b.name}: {repr(b.value)}" for b in all_backends) + ) + + # Verify CUSTOM has its own unique identity + assert AttentionBackendEnum.CUSTOM.name == "CUSTOM", ( + f"CUSTOM.name should be 'CUSTOM', but got '{AttentionBackendEnum.CUSTOM.name}'" + ) + + +def test_register_custom_backend_with_class_path(): + # Register with explicit class path + register_backend( + backend=AttentionBackendEnum.CUSTOM, + class_path="tests.test_attention_backend_registry.CustomAttentionBackend", + is_mamba=False, + ) + + # Check that CUSTOM backend is registered + assert AttentionBackendEnum.CUSTOM.is_overridden(), ( + "CUSTOM should be overridden after registration" + ) + + # Get the registered class path + class_path = AttentionBackendEnum.CUSTOM.get_path() + assert class_path == "tests.test_attention_backend_registry.CustomAttentionBackend" + + # Get the backend class + backend_cls = AttentionBackendEnum.CUSTOM.get_class() + assert backend_cls.get_name() == "CUSTOM" + assert backend_cls.get_impl_cls() == CustomAttentionImpl + + +def test_mamba_custom_is_not_alias_of_any_backend(): + # Get all mamba backends + all_backends = list(MambaAttentionBackendEnum) + + # Find any aliases of CUSTOM + aliases = [] + for backend in all_backends: + if backend.name != "CUSTOM" and backend is MambaAttentionBackendEnum.CUSTOM: + aliases.append(backend.name) + + # CUSTOM should not be an alias of any other backend + assert len(aliases) == 0, ( + f"BUG! MambaAttentionBackendEnum.CUSTOM is an alias of: {', '.join(aliases)}!\n" + f"CUSTOM.value = {repr(MambaAttentionBackendEnum.CUSTOM.value)}\n" + f"All mamba backend values:\n" + + "\n".join(f" {b.name}: {repr(b.value)}" for b in all_backends) + ) + + +def test_register_custom_mamba_backend_with_class_path(): + # Register with explicit class path + register_backend( + backend=MambaAttentionBackendEnum.CUSTOM, + class_path="tests.test_attention_backend_registry.CustomMambaAttentionBackend", + is_mamba=True, + ) + + # Check that the backend is registered + assert MambaAttentionBackendEnum.CUSTOM.is_overridden() + + # Get the registered class path + class_path = MambaAttentionBackendEnum.CUSTOM.get_path() + assert ( + class_path + == "tests.test_attention_backend_registry.CustomMambaAttentionBackend" + ) + + # Get the backend class + backend_cls = MambaAttentionBackendEnum.CUSTOM.get_class() + assert backend_cls.get_name() == "CUSTOM_MAMBA" + assert backend_cls.get_impl_cls() == CustomMambaAttentionImpl diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py index ed0021db204ac..416b996df9f22 100644 --- a/vllm/attention/backends/registry.py +++ b/vllm/attention/backends/registry.py @@ -77,7 +77,8 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta): ) CPU_ATTN = "vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend" # Placeholder for third-party/custom backends - must be registered before use - CUSTOM = "" + # set to None to avoid alias with other backend, whose value is an empty string + CUSTOM = None def get_path(self, include_classname: bool = True) -> str: """Get the class path for this backend (respects overrides). @@ -139,7 +140,8 @@ class MambaAttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta): LINEAR = "vllm.v1.attention.backends.linear_attn.LinearAttentionBackend" GDN_ATTN = "vllm.v1.attention.backends.gdn_attn.GDNAttentionBackend" # Placeholder for third-party/custom backends - must be registered before use - CUSTOM = "" + # set to None to avoid alias with other backend, whose value is an empty string + CUSTOM = None def get_path(self, include_classname: bool = True) -> str: """Get the class path for this backend (respects overrides).