From 799ce45cc160ffc0a3e1a0f958cc8e260b751808 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Sat, 1 Nov 2025 10:02:23 +0000 Subject: [PATCH 001/231] [Docs] Mock all imports for docs (#27873) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- docs/mkdocs/hooks/generate_argparse.py | 60 ++++++++++++++++++++------ requirements/docs.txt | 8 ---- vllm/utils/cache.py | 4 +- 3 files changed, 49 insertions(+), 23 deletions(-) diff --git a/docs/mkdocs/hooks/generate_argparse.py b/docs/mkdocs/hooks/generate_argparse.py index ea89108f01fc2..ce1c5c53cf35a 100644 --- a/docs/mkdocs/hooks/generate_argparse.py +++ b/docs/mkdocs/hooks/generate_argparse.py @@ -3,6 +3,7 @@ import importlib import logging import sys +import traceback from argparse import SUPPRESS, HelpFormatter from pathlib import Path from typing import Literal @@ -16,7 +17,30 @@ ROOT_DIR = Path(__file__).parent.parent.parent.parent ARGPARSE_DOC_DIR = ROOT_DIR / "docs/argparse" sys.path.insert(0, str(ROOT_DIR)) + + +# Mock custom op code +class MockCustomOp: + @staticmethod + def register(name): + def decorator(cls): + return cls + + return decorator + + +noop = lambda *a, **k: None sys.modules["vllm._C"] = MagicMock() +sys.modules["vllm.model_executor.custom_op"] = MagicMock(CustomOp=MockCustomOp) +sys.modules["vllm.utils.torch_utils"] = MagicMock(direct_register_custom_op=noop) + +# Mock any version checks by reading from compiled CI requirements +with open(ROOT_DIR / "requirements/test.txt") as f: + VERSIONS = dict(line.strip().split("==") for line in f if "==" in line) +importlib.metadata.version = lambda name: VERSIONS.get(name) or "0.0.0" + +# Make torch.nn.Parameter safe to inherit from +sys.modules["torch.nn"] = MagicMock(Parameter=object) class PydanticMagicMock(MagicMock): @@ -31,20 +55,17 @@ class PydanticMagicMock(MagicMock): return core_schema.any_schema() -def auto_mock(module, attr, max_mocks=50): +def auto_mock(module, attr, 
max_mocks=100): """Function that automatically mocks missing modules during imports.""" logger.info("Importing %s from %s", attr, module) for _ in range(max_mocks): try: # First treat attr as an attr, then as a submodule - with patch("importlib.metadata.version", return_value="0.0.0"): - return getattr( - importlib.import_module(module), - attr, - importlib.import_module(f"{module}.{attr}"), - ) - except importlib.metadata.PackageNotFoundError as e: - raise e + return getattr( + importlib.import_module(module), + attr, + importlib.import_module(f"{module}.{attr}"), + ) except ModuleNotFoundError as e: logger.info("Mocking %s for argparse doc generation", e.name) sys.modules[e.name] = PydanticMagicMock(name=e.name) @@ -139,10 +160,19 @@ def create_parser(add_cli_args, **kwargs) -> FlexibleArgumentParser: Returns: FlexibleArgumentParser: A parser with markdown formatting for the class. """ - parser = FlexibleArgumentParser(add_json_tip=False) - parser.formatter_class = MarkdownFormatter - with patch("vllm.config.DeviceConfig.__post_init__"): - _parser = add_cli_args(parser, **kwargs) + try: + parser = FlexibleArgumentParser(add_json_tip=False) + parser.formatter_class = MarkdownFormatter + with patch("vllm.config.DeviceConfig.__post_init__"): + _parser = add_cli_args(parser, **kwargs) + except ModuleNotFoundError as e: + # Auto-mock runtime imports + if tb_list := traceback.extract_tb(e.__traceback__): + path = Path(tb_list[-1].filename).relative_to(ROOT_DIR) + auto_mock(module=".".join(path.parent.parts), attr=path.stem) + return create_parser(add_cli_args, **kwargs) + else: + raise e # add_cli_args might be in-place so return parser if _parser is None return _parser or parser @@ -184,3 +214,7 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool): with open(doc_path, "w", encoding="utf-8") as f: f.write(super(type(parser), parser).format_help()) logger.info("Argparse generated: %s", doc_path.relative_to(ROOT_DIR)) + + +if __name__ == 
"__main__": + on_startup("build", False) diff --git a/requirements/docs.txt b/requirements/docs.txt index 00c314874016f..0fd6dbe22c512 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -9,12 +9,4 @@ mkdocs-git-revision-date-localized-plugin mkdocs-minify-plugin regex ruff - -# Required for argparse hook only --f https://download.pytorch.org/whl/cpu -cachetools -cloudpickle -py-cpuinfo -msgspec pydantic -torch diff --git a/vllm/utils/cache.py b/vllm/utils/cache.py index d5e08caa8a1ed..4338983f90601 100644 --- a/vllm/utils/cache.py +++ b/vllm/utils/cache.py @@ -3,7 +3,7 @@ from collections import UserDict from collections.abc import Callable, Hashable, Iterator, KeysView, Mapping from types import MappingProxyType -from typing import Generic, NamedTuple, TypeVar, cast, overload +from typing import NamedTuple, TypeVar, cast, overload import cachetools @@ -48,7 +48,7 @@ class CacheInfo(NamedTuple): ) -class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): +class LRUCache(cachetools.LRUCache[_K, _V]): def __init__(self, capacity: float, getsizeof: Callable[[_V], float] | None = None): super().__init__(capacity, getsizeof) From 30a14b034fa387470a512e8004527ad1c28af303 Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Sat, 1 Nov 2025 18:17:45 +0800 Subject: [PATCH 002/231] [V0 deprecation] Remove VLLM_USE_V1 usage in platform and v1 module (#27798) Signed-off-by: wangxiyuan Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/platforms/cuda.py | 190 ++++++++++++--------------- vllm/platforms/interface.py | 9 +- vllm/platforms/rocm.py | 84 +++++------- vllm/platforms/tpu.py | 4 - vllm/platforms/xpu.py | 9 +- vllm/v1/engine/async_llm.py | 16 --- vllm/v1/engine/llm_engine.py | 11 +- vllm/v1/executor/uniproc_executor.py | 9 +- 8 files changed, 128 insertions(+), 204 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index cc06f034fba32..32734c3aba5ef 100644 --- a/vllm/platforms/cuda.py +++ 
b/vllm/platforms/cuda.py @@ -276,17 +276,12 @@ class CudaPlatformBase(Platform): "FLASHMLA, FLASH_ATTN_MLA, or TRITON_MLA. Alternatively, set " "VLLM_MLA_DISABLE=1 to disable MLA for this model." ) - if not use_v1: - raise RuntimeError( - "MLA attention backends require the V1 engine. " - "Set VLLM_USE_V1=1 to enable them." - ) from vllm.attention.ops.flashmla import is_flashmla_dense_supported from vllm.attention.utils.fa_utils import flash_attn_supports_mla if use_sparse: - logger.info_once("Using Sparse MLA backend on V1 engine.") + logger.info_once("Using Sparse MLA backend.") return ( "vllm.v1.attention.backends.mla.flashmla_sparse." "FlashMLASparseBackend" @@ -313,15 +308,13 @@ class CudaPlatformBase(Platform): ) if use_cutlassmla: - logger.info_once( - "Using Cutlass MLA backend on V1 engine.", scope="local" - ) + logger.info_once("Using Cutlass MLA backend.", scope="local") return "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend" if use_flashinfermla: from vllm.v1.attention.backends.utils import set_kv_cache_layout set_kv_cache_layout("HND") - logger.info_once("Using FlashInfer MLA backend on V1 engine.") + logger.info_once("Using FlashInfer MLA backend.") return ( "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend" ) @@ -333,116 +326,107 @@ class CudaPlatformBase(Platform): block_size, ) else: - logger.info_once("Using FlashMLA backend on V1 engine.") + logger.info_once("Using FlashMLA backend.") return "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend" if use_flashattn: - logger.info_once("Using FlashAttention MLA backend on V1 engine.") + logger.info_once("Using FlashAttention MLA backend.") return ( "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend" ) if use_triton: - logger.info_once("Using Triton MLA backend on V1 engine.") + logger.info_once("Using Triton MLA backend.") return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend" - if use_v1: - FLASHINFER_V1 = 
"vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501 - FLEX_ATTENTION_V1 = ( - "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 - ) - TRITON_ATTN = ( - "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 - ) - FLASH_ATTN_V1 = ( - "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 - ) - TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501 - XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501 - use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith( - "fp8" - ) + FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501 + FLEX_ATTENTION_V1 = ( + "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 + ) + TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 + FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 + TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501 + XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501 - if selected_backend == _Backend.FLASHINFER: - logger.info_once("Using FlashInfer backend on V1 engine.") - if cls.has_device_capability(100): - from vllm.v1.attention.backends.utils import set_kv_cache_layout + use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith( + "fp8" + ) + + if selected_backend == _Backend.FLASHINFER: + logger.info_once("Using FlashInfer backend.") + if cls.has_device_capability(100): + from vllm.v1.attention.backends.utils import set_kv_cache_layout + + set_kv_cache_layout("HND") + return FLASHINFER_V1 + elif selected_backend == _Backend.FLEX_ATTENTION: + logger.info_once("Using FlexAttention backend.") + return FLEX_ATTENTION_V1 + elif selected_backend == _Backend.TRITON_ATTN: + logger.info_once("Using Triton backend.") + 
return TRITON_ATTN + elif selected_backend == _Backend.FLASH_ATTN: + logger.info_once("Using Flash Attention backend.") + return FLASH_ATTN_V1 + elif selected_backend == _Backend.TREE_ATTN: + logger.info_once("Using Tree Attention backend.") + return TREE_ATTN_V1 + elif selected_backend == _Backend.XFORMERS: + logger.info_once("Using XFormers backend.") + return XFORMERS_V1 + + from vllm.attention.selector import is_attn_backend_supported + + # Default backends for V1 engine + # Prefer FlashInfer for Blackwell GPUs if installed + if cls.is_device_capability(100): + if is_default_backend_supported := is_attn_backend_supported( + FLASHINFER_V1, head_size, dtype + ): + from vllm.v1.attention.backends.utils import set_kv_cache_layout + + logger.info_once( + "Using FlashInfer backend with HND KV cache layout on " + "V1 engine by default for Blackwell (SM 10.0) GPUs." + ) + set_kv_cache_layout("HND") - set_kv_cache_layout("HND") return FLASHINFER_V1 - elif selected_backend == _Backend.FLEX_ATTENTION: - logger.info_once("Using FlexAttention backend on V1 engine.") - return FLEX_ATTENTION_V1 - elif selected_backend == _Backend.TRITON_ATTN: - logger.info_once("Using Triton backend on V1 engine.") + + if not is_default_backend_supported.can_import: + logger.warning_once( + "FlashInfer failed to import on Blackwell (SM 10.0) GPUs; " + "it is recommended to install FlashInfer for better " + "performance." 
+ ) + + # FlashAttention is the default for SM 8.0+ GPUs + if cls.has_device_capability(80): + if (has_sink or use_fp8_kv_cache) and not cls.is_device_capability(90): + logger.info_once("Using Triton backend.") return TRITON_ATTN - elif selected_backend == _Backend.FLASH_ATTN: - logger.info_once("Using Flash Attention backend on V1 engine.") + elif is_default_backend_supported := is_attn_backend_supported( + FLASH_ATTN_V1, head_size, dtype, allow_import_error=False + ): + logger.info_once("Using Flash Attention backend.") return FLASH_ATTN_V1 - elif selected_backend == _Backend.TREE_ATTN: - logger.info_once("Using Tree Attention backend on V1 engine.") - return TREE_ATTN_V1 - elif selected_backend == _Backend.XFORMERS: - logger.info_once("Using XFormers backend on V1 engine.") - return XFORMERS_V1 - from vllm.attention.selector import is_attn_backend_supported - - # Default backends for V1 engine - # Prefer FlashInfer for Blackwell GPUs if installed - if cls.is_device_capability(100): - if is_default_backend_supported := is_attn_backend_supported( - FLASHINFER_V1, head_size, dtype - ): - from vllm.v1.attention.backends.utils import set_kv_cache_layout - - logger.info_once( - "Using FlashInfer backend with HND KV cache layout on " - "V1 engine by default for Blackwell (SM 10.0) GPUs." - ) - set_kv_cache_layout("HND") - - return FLASHINFER_V1 - - if not is_default_backend_supported.can_import: - logger.warning_once( - "FlashInfer failed to import for V1 engine on " - "Blackwell (SM 10.0) GPUs; it is recommended to " - "install FlashInfer for better performance." 
- ) - - # FlashAttention is the default for SM 8.0+ GPUs - if cls.has_device_capability(80): - if (has_sink or use_fp8_kv_cache) and not cls.is_device_capability(90): - logger.info_once("Using Triton backend on V1 engine.") - return TRITON_ATTN - elif is_default_backend_supported := is_attn_backend_supported( - FLASH_ATTN_V1, head_size, dtype, allow_import_error=False - ): - logger.info_once("Using Flash Attention backend on V1 engine.") - return FLASH_ATTN_V1 - - # FlexAttention is the default for older GPUs - else: - logger.info_once("Using FlexAttention backend on V1 engine.") - return FLEX_ATTENTION_V1 - - assert not is_default_backend_supported - - use_flex_attention_reason = {} - if not is_default_backend_supported.head_size: - use_flex_attention_reason["head_size"] = head_size - if not is_default_backend_supported.dtype: - use_flex_attention_reason["dtype"] = dtype - - logger.info_once( - "Using FlexAttention backend for %s on V1 engine.", - ", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()), - ) + # FlexAttention is the default for older GPUs + else: + logger.info_once("Using FlexAttention backend.") return FLEX_ATTENTION_V1 - raise RuntimeError( - "V0 attention backends have been removed. Set VLLM_USE_V1=1 " - "to select a supported backend." 
+ assert not is_default_backend_supported + + use_flex_attention_reason = {} + if not is_default_backend_supported.head_size: + use_flex_attention_reason["head_size"] = head_size + if not is_default_backend_supported.dtype: + use_flex_attention_reason["dtype"] = dtype + + logger.info_once( + "Using FlexAttention backend for %s.", + ", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()), ) + return FLEX_ATTENTION_V1 @classmethod def get_punica_wrapper(cls) -> str: diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 4462829564391..15e3b3a22bdee 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -467,14 +467,7 @@ class Platform: """ Whether to use allgather in LogitsProcessor to gather the logits. """ - import vllm.envs as envs - from vllm.config import get_current_vllm_config - - parallel_config = get_current_vllm_config().parallel_config - return ( - envs.VLLM_USE_V1 - or parallel_config.distributed_executor_backend == "external_launcher" - ) + return True @classmethod def use_custom_allreduce(cls) -> bool: diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index d3535c9781c48..0c03a5564db89 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -149,7 +149,7 @@ def use_rocm_custom_paged_attention( # disabled due to observed numerical discrepancy. 
if ON_GFX9: return ( - (not envs.VLLM_USE_V1 or sliding_window == 0 or sliding_window == (-1, -1)) + (sliding_window == 0 or sliding_window == (-1, -1)) and (qtype == torch.half or qtype == torch.bfloat16) and (head_size == 64 or head_size == 128) and (block_size == 16 or block_size == 32) @@ -163,11 +163,7 @@ def use_rocm_custom_paged_attention( else: return ( ON_GFX11_GFX12 - and ( - not envs.VLLM_USE_V1 - or sliding_window == 0 - or sliding_window == (-1, -1) - ) + and (sliding_window == 0 or sliding_window == (-1, -1)) and (qtype == torch.half or qtype == torch.bfloat16) and head_size == 128 and block_size == 16 @@ -236,12 +232,6 @@ class RocmPlatform(Platform): if use_sparse: raise NotImplementedError("Sparse Attention is not supported on ROCm.") if use_mla: - if not use_v1: - raise RuntimeError( - "MLA attention backends require the V1 engine. " - "Set VLLM_USE_V1=1 to enable them." - ) - from vllm.v1.attention.backends.mla.rocm_aiter_mla import ( is_aiter_mla_enabled, ) @@ -255,7 +245,7 @@ class RocmPlatform(Platform): if selected_backend == _Backend.TRITON_MLA: if block_size != 1: - logger.info_once("Using Triton MLA backend on V1 engine.") + logger.info_once("Using Triton MLA backend.") return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend" raise ValueError( f" The selected backend, {selected_backend.name}," @@ -263,7 +253,7 @@ class RocmPlatform(Platform): ) if selected_backend == _Backend.ROCM_AITER_MLA: if block_size == 1: - logger.info("Using AITER MLA backend on V1 engine.") + logger.info("Using AITER MLA backend.") return ( "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501 ) @@ -277,41 +267,33 @@ class RocmPlatform(Platform): f"is not MLA type while requested for MLA backend." 
) - if envs.VLLM_USE_V1: - if selected_backend == _Backend.FLEX_ATTENTION: - logger.info("Using FlexAttention backend on V1 engine.") - return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" - if ( - envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9() - ) or selected_backend == _Backend.ROCM_AITER_FA: - logger.info("Using Aiter Flash Attention backend on V1 engine.") - return ( - "vllm.v1.attention.backends." - "rocm_aiter_fa.AiterFlashAttentionBackend" - ) - if ( - envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION - ) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN: - logger.info("Using Aiter Unified Attention backend on V1 engine.") - return ( - "vllm.v1.attention.backends." - "rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend" - ) - if ( - envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION - or selected_backend == _Backend.ROCM_ATTN - ): - # rocm specific backend, with aiter and/or - # triton prefix-prefill - logger.info("Using Rocm Attention backend on V1 engine.") - return "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend" - # default case, using triton unified attention - logger.info("Using Triton Attention backend on V1 engine.") - return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" - raise RuntimeError( - "V0 attention backends have been removed. Set VLLM_USE_V1=1 " - "to select a supported backend." 
- ) + if selected_backend == _Backend.FLEX_ATTENTION: + logger.info("Using FlexAttention backend.") + return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" + if ( + envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9() + ) or selected_backend == _Backend.ROCM_AITER_FA: + logger.info("Using Aiter Flash Attention backend.") + return "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend" + if ( + envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION + ) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN: + logger.info("Using Aiter Unified Attention backend.") + return ( + "vllm.v1.attention.backends." + "rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend" + ) + if ( + envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION + or selected_backend == _Backend.ROCM_ATTN + ): + # rocm specific backend, with aiter and/or + # triton prefix-prefill + logger.info("Using Rocm Attention backend.") + return "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend" + # default case, using triton unified attention + logger.info("Using Triton Attention backend.") + return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" @classmethod def set_device(cls, device: torch.device) -> None: @@ -372,7 +354,6 @@ class RocmPlatform(Platform): parallel_config = vllm_config.parallel_config is_eager_execution = compilation_config == CUDAGraphMode.NONE - use_v1 = envs.VLLM_USE_V1 use_aiter_rms_norm = ( envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_RMSNORM ) @@ -384,8 +365,7 @@ class RocmPlatform(Platform): parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" # Aiter rms norm perform best when CUDA Graph capture is enabled. 
if ( - use_v1 - and use_aiter_rms_norm + use_aiter_rms_norm and not is_eager_execution and "-rms_norm" not in compilation_config.custom_ops ): diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 0a14ee011f7f2..1a4b67a1762f3 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -204,10 +204,6 @@ class TpuPlatform(Platform): def get_device_communicator_cls(cls) -> str: return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator" # noqa - @classmethod - def use_all_gather(cls) -> bool: - return True - @classmethod def validate_request( cls, diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 07ab759e4baa6..e4ecd0c807dac 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -66,16 +66,13 @@ class XPUPlatform(Platform): if use_sparse: raise NotImplementedError("Sparse Attention is not supported on XPU.") - use_v1 = envs.VLLM_USE_V1 - if not use_v1: - raise ValueError("XPU backend only supports V1.") TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 if selected_backend == _Backend.TRITON_ATTN: - logger.info_once("Using Triton backend on V1 engine.") + logger.info_once("Using Triton backend.") return TRITON_ATTN elif selected_backend == _Backend.FLASH_ATTN: - logger.info_once("Using Flash Attention backend on V1 engine.") + logger.info_once("Using Flash Attention backend.") return FLASH_ATTN elif selected_backend: raise ValueError( @@ -83,7 +80,7 @@ class XPUPlatform(Platform): f"with use_v1: {use_v1} use_mla: {use_mla}" ) - logger.info("Using Flash Attention backend on V1 engine.") + logger.info("Using Flash Attention backend.") return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" @classmethod diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index dc61d45015682..f0d5b77e8e183 100644 --- a/vllm/v1/engine/async_llm.py +++ 
b/vllm/v1/engine/async_llm.py @@ -88,14 +88,6 @@ class AsyncLLM(EngineClient): Returns: None """ - if not envs.VLLM_USE_V1: - raise ValueError( - "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. " - "This should not happen. As a workaround, try using " - "AsyncLLMEngine.from_vllm_config(...) or explicitly set " - "VLLM_USE_V1=0 or 1 and report this issue on Github." - ) - # Ensure we can serialize custom transformer configs maybe_register_config_serialize_by_value() @@ -206,14 +198,6 @@ class AsyncLLM(EngineClient): client_index: int = 0, disable_log_requests: bool = True, # Deprecated, will be removed ) -> "AsyncLLM": - if not envs.VLLM_USE_V1: - raise ValueError( - "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. " - "This should not happen. As a workaround, try using " - "AsyncLLMEngine.from_vllm_config(...) or explicitly set " - "VLLM_USE_V1=0 or 1 and report this issue on Github." - ) - # Create the LLMEngine. return cls( vllm_config=vllm_config, diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index c2ca9579d55ea..f44b6b2070d9f 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -58,18 +58,9 @@ class LLMEngine: use_cached_outputs: bool = False, multiprocess_mode: bool = False, ) -> None: - if not envs.VLLM_USE_V1: - raise ValueError( - "Using V1 LLMEngine, but envs.VLLM_USE_V1=False. " - "This should not happen. As a workaround, try using " - "LLMEngine.from_vllm_config(...) or explicitly set " - "VLLM_USE_V1=0 or 1 and report this issue on Github." - ) - if stat_loggers is not None: raise NotImplementedError( - "Passing StatLoggers to LLMEngine in V1 is not yet supported. " - "Set VLLM_USE_V1=0 and file and issue on Github." + "Passing StatLoggers to LLMEngine is not yet supported." 
) self.vllm_config = vllm_config diff --git a/vllm/v1/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py index f17d3c3092701..32f00949b7f74 100644 --- a/vllm/v1/executor/uniproc_executor.py +++ b/vllm/v1/executor/uniproc_executor.py @@ -124,11 +124,10 @@ class ExecutorWithExternalLauncher(UniProcExecutor): def _init_executor(self) -> None: """Initialize the worker and load the model.""" - if envs.VLLM_USE_V1: - assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, ( - "To get deterministic execution in V1, " - "please set VLLM_ENABLE_V1_MULTIPROCESSING=0" - ) + assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, ( + "To get deterministic execution, " + "please set VLLM_ENABLE_V1_MULTIPROCESSING=0" + ) super()._init_executor() def _distributed_args(self) -> tuple[str, int, int]: From d811b442d305b33b3aca2836c5d7f761effe76de Mon Sep 17 00:00:00 2001 From: Haco <75477391+xiaohajiayou@users.noreply.github.com> Date: Sat, 1 Nov 2025 22:52:43 +0800 Subject: [PATCH 003/231] [Bugfix] DeepSeek V3.2 MTP metadata & CUDA graph issues (#26779) Signed-off-by: xiaohajiayou <923390377@qq.com> --- vllm/v1/spec_decode/eagle.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 35c2e73e8ee2c..1e18eea2330a4 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -109,6 +109,7 @@ class EagleProposer: else [] ) + self.use_cuda_graph = self.use_cuda_graph and bool(self.cudagraph_batch_sizes) # persistent buffers for cuda graph self.input_ids = torch.zeros( self.max_num_tokens, dtype=torch.int32, device=device @@ -939,7 +940,7 @@ class EagleProposer: self.vllm_config, DeepseekV32IndexerCache ) draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names - self.attn_layer_names = list(draft_attn_layer_names) + self.attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names) self.indexer_layer_names = list(draft_indexer_layer_names) 
if self.indexer_layer_names: @@ -1050,16 +1051,18 @@ class EagleProposer: num_tokens: int, use_cudagraphs=True, ) -> None: - if use_cudagraphs and num_tokens <= self.cudagraph_batch_sizes[-1]: + # Determine if CUDA graphs should be used for this run. + cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph + if cudagraphs_enabled and num_tokens <= self.cudagraph_batch_sizes[-1]: num_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) with set_forward_context( None, self.vllm_config, num_tokens=num_tokens, - cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE - if use_cudagraphs - else CUDAGraphMode.NONE, + cudagraph_runtime_mode=( + CUDAGraphMode.PIECEWISE if cudagraphs_enabled else CUDAGraphMode.NONE + ), ): if self.supports_mm_inputs: input_ids = None From 99d69af9ece094acb94901439925f8468b32326a Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 1 Nov 2025 23:28:54 +0800 Subject: [PATCH 004/231] [Bugfix] Python 3.10 compatibility for `Self` (#27918) Signed-off-by: DarkLight1337 --- vllm/config/structured_outputs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/config/structured_outputs.py b/vllm/config/structured_outputs.py index 85b6e42264a42..eb1cc7220b8fe 100644 --- a/vllm/config/structured_outputs.py +++ b/vllm/config/structured_outputs.py @@ -2,10 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import hashlib -from typing import Any, Literal, Self +from typing import Any, Literal from pydantic import model_validator from pydantic.dataclasses import dataclass +from typing_extensions import Self from vllm.config.utils import config From af6e19f50f1d5d0c3801948c3ab17b2af231c259 Mon Sep 17 00:00:00 2001 From: wenxindongwork <161090399+wenxindongwork@users.noreply.github.com> Date: Sat, 1 Nov 2025 11:14:44 -0600 Subject: [PATCH 005/231] [Core][TPU] Support TPU Data Parallalism (#27365) Signed-off-by: wenxindongwork --- vllm/entrypoints/llm.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 
deletion(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 758e16c89e694..b0b996ab2fec5 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -67,6 +67,7 @@ from vllm.outputs import ( RequestOutput, ScoringRequestOutput, ) +from vllm.platforms import current_platform from vllm.pooling_params import PoolingParams from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams from vllm.tasks import PoolingTask @@ -289,7 +290,11 @@ class LLM: # warn about single-process data parallel usage. _dp_size = int(kwargs.get("data_parallel_size", 1)) _distributed_executor_backend = kwargs.get("distributed_executor_backend") - if _dp_size > 1 and not _distributed_executor_backend == "external_launcher": + if ( + _dp_size > 1 + and not _distributed_executor_backend == "external_launcher" + and not current_platform.is_tpu() + ): raise ValueError( f"LLM(data_parallel_size={_dp_size}) is not supported for single-" "process usage and may hang. Please use " From c2ed069b32e2805c05a858c6157f4c6393b145a8 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sat, 1 Nov 2025 10:51:24 -0700 Subject: [PATCH 006/231] [BugFix] Fix mixed penalties batch with async scheduling (#27910) Signed-off-by: Nick Hill --- vllm/v1/sample/ops/penalties.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/vllm/v1/sample/ops/penalties.py b/vllm/v1/sample/ops/penalties.py index 898b90d41abae..241d9de957ea2 100644 --- a/vllm/v1/sample/ops/penalties.py +++ b/vllm/v1/sample/ops/penalties.py @@ -21,6 +21,14 @@ def apply_all_penalties( """ _, vocab_size = logits.shape output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size, logits.device) + + # In the async scheduling case, rows that won't have penalties applied may contain + # -1 placeholder token ids. We must replace these with valid token ids so that the + # scatter done in apply_penalties is valid. 
+ # NOTE(nick): The penalties implementation is currently quite inefficient and + # will be reworked anyhow. + output_tokens_t.masked_fill_(output_tokens_t == -1, vocab_size) + return apply_penalties( logits, prompt_token_ids, From 1e88fb751bce13c74355d177fd06035858ce77c4 Mon Sep 17 00:00:00 2001 From: Benjamin Bartels Date: Sat, 1 Nov 2025 19:45:42 +0000 Subject: [PATCH 007/231] Adds anthropic /v1/messages endpoint to openai api_server (#27882) Signed-off-by: bbartels Signed-off-by: Benjamin Bartels --- tests/entrypoints/anthropic/__init__.py | 0 .../{anthropic => openai}/test_messages.py | 72 ++--- tests/utils.py | 142 +-------- vllm/entrypoints/anthropic/api_server.py | 301 ------------------ vllm/entrypoints/openai/api_server.py | 86 +++++ 5 files changed, 139 insertions(+), 462 deletions(-) delete mode 100644 tests/entrypoints/anthropic/__init__.py rename tests/entrypoints/{anthropic => openai}/test_messages.py (68%) delete mode 100644 vllm/entrypoints/anthropic/api_server.py diff --git a/tests/entrypoints/anthropic/__init__.py b/tests/entrypoints/anthropic/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/tests/entrypoints/anthropic/test_messages.py b/tests/entrypoints/openai/test_messages.py similarity index 68% rename from tests/entrypoints/anthropic/test_messages.py rename to tests/entrypoints/openai/test_messages.py index 4e35554b4e330..3e390ad496428 100644 --- a/tests/entrypoints/anthropic/test_messages.py +++ b/tests/entrypoints/openai/test_messages.py @@ -5,7 +5,7 @@ import anthropic import pytest import pytest_asyncio -from ...utils import RemoteAnthropicServer +from ...utils import RemoteOpenAIServer MODEL_NAME = "Qwen/Qwen3-0.6B" @@ -23,13 +23,13 @@ def server(): # noqa: F811 "claude-3-7-sonnet-latest", ] - with RemoteAnthropicServer(MODEL_NAME, args) as remote_server: + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server @pytest_asyncio.fixture async def client(server): - async with 
server.get_async_client() as async_client: + async with server.get_async_client_anthropic() as async_client: yield async_client @@ -105,37 +105,37 @@ async def test_anthropic_tool_call(client: anthropic.AsyncAnthropic): print(f"Anthropic response: {resp.model_dump_json()}") - @pytest.mark.asyncio - async def test_anthropic_tool_call_streaming(client: anthropic.AsyncAnthropic): - resp = await client.messages.create( - model="claude-3-7-sonnet-latest", - max_tokens=1024, - messages=[ - { - "role": "user", - "content": "What's the weather like in New York today?", - } - ], - tools=[ - { - "name": "get_current_weather", - "description": "Useful for querying the weather " - "in a specified city.", - "input_schema": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "City or region, for example: " - "New York, London, Tokyo, etc.", - } - }, - "required": ["location"], - }, - } - ], - stream=True, - ) - async for chunk in resp: - print(chunk.model_dump_json()) +@pytest.mark.asyncio +async def test_anthropic_tool_call_streaming(client: anthropic.AsyncAnthropic): + resp = await client.messages.create( + model="claude-3-7-sonnet-latest", + max_tokens=1024, + messages=[ + { + "role": "user", + "content": "What's the weather like in New York today?", + } + ], + tools=[ + { + "name": "get_current_weather", + "description": "Useful for querying the weather in a specified city.", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City or region, for example: " + "New York, London, Tokyo, etc.", + } + }, + "required": ["location"], + }, + } + ], + stream=True, + ) + + async for chunk in resp: + print(chunk.model_dump_json()) diff --git a/tests/utils.py b/tests/utils.py index af4ce6ebaeda2..c8f18384c5114 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -247,6 +247,23 @@ class RemoteOpenAIServer: **kwargs, ) + def get_client_anthropic(self, **kwargs): + if "timeout" not in 
kwargs: + kwargs["timeout"] = 600 + return anthropic.Anthropic( + base_url=self.url_for(), + api_key=self.DUMMY_API_KEY, + max_retries=0, + **kwargs, + ) + + def get_async_client_anthropic(self, **kwargs): + if "timeout" not in kwargs: + kwargs["timeout"] = 600 + return anthropic.AsyncAnthropic( + base_url=self.url_for(), api_key=self.DUMMY_API_KEY, max_retries=0, **kwargs + ) + class RemoteOpenAIServerCustom(RemoteOpenAIServer): """Launch test server with custom child process""" @@ -293,131 +310,6 @@ class RemoteOpenAIServerCustom(RemoteOpenAIServer): self.proc.kill() -class RemoteAnthropicServer: - DUMMY_API_KEY = "token-abc123" # vLLM's Anthropic server does not need API key - - def __init__( - self, - model: str, - vllm_serve_args: list[str], - *, - env_dict: dict[str, str] | None = None, - seed: int | None = 0, - auto_port: bool = True, - max_wait_seconds: float | None = None, - ) -> None: - if auto_port: - if "-p" in vllm_serve_args or "--port" in vllm_serve_args: - raise ValueError( - "You have manually specified the port when `auto_port=True`." - ) - - # Don't mutate the input args - vllm_serve_args = vllm_serve_args + ["--port", str(get_open_port())] - if seed is not None: - if "--seed" in vllm_serve_args: - raise ValueError( - f"You have manually specified the seed when `seed={seed}`." 
- ) - - vllm_serve_args = vllm_serve_args + ["--seed", str(seed)] - - parser = FlexibleArgumentParser(description="vLLM's remote Anthropic server.") - subparsers = parser.add_subparsers(required=False, dest="subparser") - parser = ServeSubcommand().subparser_init(subparsers) - args = parser.parse_args(["--model", model, *vllm_serve_args]) - self.host = str(args.host or "localhost") - self.port = int(args.port) - - self.show_hidden_metrics = args.show_hidden_metrics_for_version is not None - - # download the model before starting the server to avoid timeout - is_local = os.path.isdir(model) - if not is_local: - engine_args = AsyncEngineArgs.from_cli_args(args) - model_config = engine_args.create_model_config() - load_config = engine_args.create_load_config() - - model_loader = get_model_loader(load_config) - model_loader.download_model(model_config) - - env = os.environ.copy() - # the current process might initialize cuda, - # to be safe, we should use spawn method - env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" - if env_dict is not None: - env.update(env_dict) - self.proc = subprocess.Popen( - [ - sys.executable, - "-m", - "vllm.entrypoints.anthropic.api_server", - model, - *vllm_serve_args, - ], - env=env, - stdout=sys.stdout, - stderr=sys.stderr, - ) - max_wait_seconds = max_wait_seconds or 240 - self._wait_for_server(url=self.url_for("health"), timeout=max_wait_seconds) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.proc.terminate() - try: - self.proc.wait(8) - except subprocess.TimeoutExpired: - # force kill if needed - self.proc.kill() - - def _wait_for_server(self, *, url: str, timeout: float): - # run health check - start = time.time() - while True: - try: - if requests.get(url).status_code == 200: - break - except Exception: - # this exception can only be raised by requests.get, - # which means the server is not ready yet. 
- # the stack trace is not useful, so we suppress it - # by using `raise from None`. - result = self.proc.poll() - if result is not None and result != 0: - raise RuntimeError("Server exited unexpectedly.") from None - - time.sleep(0.5) - if time.time() - start > timeout: - raise RuntimeError("Server failed to start in time.") from None - - @property - def url_root(self) -> str: - return f"http://{self.host}:{self.port}" - - def url_for(self, *parts: str) -> str: - return self.url_root + "/" + "/".join(parts) - - def get_client(self, **kwargs): - if "timeout" not in kwargs: - kwargs["timeout"] = 600 - return anthropic.Anthropic( - base_url=self.url_for(), - api_key=self.DUMMY_API_KEY, - max_retries=0, - **kwargs, - ) - - def get_async_client(self, **kwargs): - if "timeout" not in kwargs: - kwargs["timeout"] = 600 - return anthropic.AsyncAnthropic( - base_url=self.url_for(), api_key=self.DUMMY_API_KEY, max_retries=0, **kwargs - ) - - def _test_completion( client: openai.OpenAI, model: str, diff --git a/vllm/entrypoints/anthropic/api_server.py b/vllm/entrypoints/anthropic/api_server.py deleted file mode 100644 index df877f99b084f..0000000000000 --- a/vllm/entrypoints/anthropic/api_server.py +++ /dev/null @@ -1,301 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# Adapted from: -# https://github.com/vllm/vllm/entrypoints/openai/api_server.py - -import asyncio -import signal -import tempfile -from argparse import Namespace -from http import HTTPStatus - -import uvloop -from fastapi import APIRouter, Depends, FastAPI, Request -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse, Response, StreamingResponse -from starlette.datastructures import State - -import vllm.envs as envs -from vllm.engine.protocol import EngineClient -from vllm.entrypoints.anthropic.protocol import ( - AnthropicErrorResponse, - AnthropicMessagesRequest, - AnthropicMessagesResponse, -) 
-from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages -from vllm.entrypoints.launcher import serve_http -from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.api_server import ( - build_async_engine_client, - create_server_socket, - lifespan, - load_log_config, - validate_api_server_args, - validate_json_request, -) -from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args -from vllm.entrypoints.openai.protocol import ErrorResponse -from vllm.entrypoints.openai.serving_models import ( - BaseModelPath, - OpenAIServingModels, -) - -# -# yapf: enable -from vllm.entrypoints.openai.tool_parsers import ToolParserManager -from vllm.entrypoints.utils import ( - cli_env_setup, - load_aware_call, - process_chat_template, - process_lora_modules, - with_cancellation, -) -from vllm.logger import init_logger -from vllm.utils.argparse_utils import FlexibleArgumentParser -from vllm.utils.network_utils import is_valid_ipv6_address -from vllm.utils.system_utils import set_ulimit -from vllm.version import __version__ as VLLM_VERSION - -prometheus_multiproc_dir: tempfile.TemporaryDirectory - -# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) -logger = init_logger("vllm.entrypoints.anthropic.api_server") - -_running_tasks: set[asyncio.Task] = set() - -router = APIRouter() - - -def messages(request: Request) -> AnthropicServingMessages: - return request.app.state.anthropic_serving_messages - - -def engine_client(request: Request) -> EngineClient: - return request.app.state.engine_client - - -@router.get("/health", response_class=Response) -async def health(raw_request: Request) -> Response: - """Health check.""" - await engine_client(raw_request).check_health() - return Response(status_code=200) - - -@router.get("/ping", response_class=Response) -@router.post("/ping", response_class=Response) -async def ping(raw_request: Request) -> Response: - """Ping check. 
Endpoint required for SageMaker""" - return await health(raw_request) - - -@router.post( - "/v1/messages", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, - HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse}, - HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse}, - HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse}, - }, -) -@with_cancellation -@load_aware_call -async def create_messages(request: AnthropicMessagesRequest, raw_request: Request): - handler = messages(raw_request) - if handler is None: - return messages(raw_request).create_error_response( - message="The model does not support Messages API" - ) - - generator = await handler.create_messages(request, raw_request) - - if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump()) - - elif isinstance(generator, AnthropicMessagesResponse): - logger.debug( - "Anthropic Messages Response: %s", generator.model_dump(exclude_none=True) - ) - return JSONResponse(content=generator.model_dump(exclude_none=True)) - - return StreamingResponse(content=generator, media_type="text/event-stream") - - -async def init_app_state( - engine_client: EngineClient, - state: State, - args: Namespace, -) -> None: - vllm_config = engine_client.vllm_config - - if args.served_model_name is not None: - served_model_names = args.served_model_name - else: - served_model_names = [args.model] - - if args.disable_log_requests: - request_logger = None - else: - request_logger = RequestLogger(max_log_len=args.max_log_len) - - base_model_paths = [ - BaseModelPath(name=name, model_path=args.model) for name in served_model_names - ] - - state.engine_client = engine_client - state.log_stats = not args.disable_log_stats - state.vllm_config = vllm_config - model_config = vllm_config.model_config - - default_mm_loras = ( - vllm_config.lora_config.default_mm_loras - if vllm_config.lora_config is 
not None - else {} - ) - lora_modules = process_lora_modules(args.lora_modules, default_mm_loras) - - resolved_chat_template = await process_chat_template( - args.chat_template, engine_client, model_config - ) - - state.openai_serving_models = OpenAIServingModels( - engine_client=engine_client, - base_model_paths=base_model_paths, - lora_modules=lora_modules, - ) - await state.openai_serving_models.init_static_loras() - state.anthropic_serving_messages = AnthropicServingMessages( - engine_client, - state.openai_serving_models, - args.response_role, - request_logger=request_logger, - chat_template=resolved_chat_template, - chat_template_content_format=args.chat_template_content_format, - return_tokens_as_token_ids=args.return_tokens_as_token_ids, - enable_auto_tools=args.enable_auto_tool_choice, - tool_parser=args.tool_call_parser, - reasoning_parser=args.reasoning_parser, - enable_prompt_tokens_details=args.enable_prompt_tokens_details, - enable_force_include_usage=args.enable_force_include_usage, - ) - - -def setup_server(args): - """Validate API server args, set up signal handler, create socket - ready to serve.""" - - logger.info("vLLM API server version %s", VLLM_VERSION) - - if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: - ToolParserManager.import_tool_parser(args.tool_parser_plugin) - - validate_api_server_args(args) - - # workaround to make sure that we bind the port before the engine is set up. - # This avoids race conditions with ray. 
- # see https://github.com/vllm-project/vllm/issues/8204 - sock_addr = (args.host or "", args.port) - sock = create_server_socket(sock_addr) - - # workaround to avoid footguns where uvicorn drops requests with too - # many concurrent requests active - set_ulimit() - - def signal_handler(*_) -> None: - # Interrupt server on sigterm while initializing - raise KeyboardInterrupt("terminated") - - signal.signal(signal.SIGTERM, signal_handler) - - addr, port = sock_addr - is_ssl = args.ssl_keyfile and args.ssl_certfile - host_part = f"[{addr}]" if is_valid_ipv6_address(addr) else addr or "0.0.0.0" - listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}" - - return listen_address, sock - - -async def run_server(args, **uvicorn_kwargs) -> None: - """Run a single-worker API server.""" - listen_address, sock = setup_server(args) - await run_server_worker(listen_address, sock, args, **uvicorn_kwargs) - - -def build_app(args: Namespace) -> FastAPI: - app = FastAPI(lifespan=lifespan) - app.include_router(router) - app.root_path = args.root_path - - app.add_middleware( - CORSMiddleware, - allow_origins=args.allowed_origins, - allow_credentials=args.allow_credentials, - allow_methods=args.allowed_methods, - allow_headers=args.allowed_headers, - ) - - return app - - -async def run_server_worker( - listen_address, sock, args, client_config=None, **uvicorn_kwargs -) -> None: - """Run a single API server worker.""" - - if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: - ToolParserManager.import_tool_parser(args.tool_parser_plugin) - - server_index = client_config.get("client_index", 0) if client_config else 0 - - # Load logging config for uvicorn if specified - log_config = load_log_config(args.log_config_file) - if log_config is not None: - uvicorn_kwargs["log_config"] = log_config - - async with build_async_engine_client( - args, - client_config=client_config, - ) as engine_client: - app = build_app(args) - - await init_app_state(engine_client, 
app.state, args) - - logger.info("Starting vLLM API server %d on %s", server_index, listen_address) - shutdown_task = await serve_http( - app, - sock=sock, - enable_ssl_refresh=args.enable_ssl_refresh, - host=args.host, - port=args.port, - log_level=args.uvicorn_log_level, - # NOTE: When the 'disable_uvicorn_access_log' value is True, - # no access log will be output. - access_log=not args.disable_uvicorn_access_log, - timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE, - ssl_keyfile=args.ssl_keyfile, - ssl_certfile=args.ssl_certfile, - ssl_ca_certs=args.ssl_ca_certs, - ssl_cert_reqs=args.ssl_cert_reqs, - **uvicorn_kwargs, - ) - - # NB: Await server shutdown only after the backend context is exited - try: - await shutdown_task - finally: - sock.close() - - -if __name__ == "__main__": - # NOTE(simon): - # This section should be in sync with vllm/entrypoints/cli/main.py for CLI - # entrypoints. - cli_env_setup() - parser = FlexibleArgumentParser( - description="vLLM Anthropic-Compatible RESTful API server." 
- ) - parser = make_arg_parser(parser) - args = parser.parse_args() - validate_parsed_serve_args(args) - - uvloop.run(run_server(args)) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 8fa71855f8f66..22b5584749ae7 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -41,6 +41,13 @@ import vllm.envs as envs from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.protocol import Device, EngineClient +from vllm.entrypoints.anthropic.protocol import ( + AnthropicError, + AnthropicErrorResponse, + AnthropicMessagesRequest, + AnthropicMessagesResponse, +) +from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args @@ -308,6 +315,10 @@ def responses(request: Request) -> OpenAIServingResponses | None: return request.app.state.openai_serving_responses +def messages(request: Request) -> AnthropicServingMessages: + return request.app.state.anthropic_serving_messages + + def chat(request: Request) -> OpenAIServingChat | None: return request.app.state.openai_serving_chat @@ -591,6 +602,63 @@ async def cancel_responses(response_id: str, raw_request: Request): return JSONResponse(content=response.model_dump()) +@router.post( + "/v1/messages", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, + HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse}, + }, +) +@with_cancellation +@load_aware_call +async def create_messages(request: AnthropicMessagesRequest, raw_request: Request): + def 
translate_error_response(response: ErrorResponse) -> JSONResponse: + anthropic_error = AnthropicErrorResponse( + error=AnthropicError( + type=response.error.type, + message=response.error.message, + ) + ) + return JSONResponse( + status_code=response.error.code, content=anthropic_error.model_dump() + ) + + handler = messages(raw_request) + if handler is None: + error = base(raw_request).create_error_response( + message="The model does not support Messages API" + ) + return translate_error_response(error) + + try: + generator = await handler.create_messages(request, raw_request) + except Exception as e: + logger.exception("Error in create_messages: %s", e) + return JSONResponse( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + content=AnthropicErrorResponse( + error=AnthropicError( + type="internal_error", + message=str(e), + ) + ).model_dump(), + ) + + if isinstance(generator, ErrorResponse): + return translate_error_response(generator) + + elif isinstance(generator, AnthropicMessagesResponse): + logger.debug( + "Anthropic Messages Response: %s", generator.model_dump(exclude_none=True) + ) + return JSONResponse(content=generator.model_dump(exclude_none=True)) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + @router.post( "/v1/chat/completions", dependencies=[Depends(validate_json_request)], @@ -1817,6 +1885,24 @@ async def init_app_state( if "transcription" in supported_tasks else None ) + state.anthropic_serving_messages = ( + AnthropicServingMessages( + engine_client, + state.openai_serving_models, + args.response_role, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_auto_tools=args.enable_auto_tool_choice, + tool_parser=args.tool_call_parser, + reasoning_parser=args.structured_outputs_config.reasoning_parser, + 
enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, + ) + if "generate" in supported_tasks + else None + ) state.enable_server_load_tracking = args.enable_server_load_tracking state.server_load_metrics = 0 From 685c99ee77b4818dcdd15b30fe0e0eff0d5d22ec Mon Sep 17 00:00:00 2001 From: Yue Zhang <81500899+KevinCheung2259@users.noreply.github.com> Date: Sun, 2 Nov 2025 05:08:56 +0800 Subject: [PATCH 008/231] [KV offload] Offloading connector async scheduling support (#27648) Signed-off-by: KevinCheung2259 <2651309292@qq.com> Co-authored-by: Nick Hill --- .../kv_transfer/kv_connector/v1/offloading_connector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py index 19344e5784c23..7567c7fae5789 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py @@ -274,8 +274,8 @@ class OffloadingConnectorScheduler: if num_new_blocks <= 0: continue - num_gpu_blocks = num_blocks * self.block_size_factor - assert len(req.block_hashes) >= num_gpu_blocks + # NOTE: In async scheduling, placeholders may temporarily make + # len(req.block_hashes) < num_blocks * self.block_size_factor. 
new_block_hashes = self._get_block_hashes( req, start_idx=start_block_idx, end_idx=num_blocks From 758ea2e980a1eeacec6097bfd98bd0a7c8fb864a Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Sat, 1 Nov 2025 23:45:02 -0400 Subject: [PATCH 009/231] [CI/Build] Fix flaky test_transcription_validation.py::test_basic_audio_gemma (#27924) Signed-off-by: Ben Browning --- tests/entrypoints/openai/test_transcription_validation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py index 6ef932392d095..f6133d4387b26 100644 --- a/tests/entrypoints/openai/test_transcription_validation.py +++ b/tests/entrypoints/openai/test_transcription_validation.py @@ -72,7 +72,9 @@ async def test_basic_audio_gemma(foscolo): model_name = "google/gemma-3n-E2B-it" server_args = ["--enforce-eager"] - with RemoteOpenAIServer(model_name, server_args) as remote_server: + with RemoteOpenAIServer( + model_name, server_args, max_wait_seconds=480 + ) as remote_server: client = remote_server.get_async_client() transcription = await client.audio.transcriptions.create( model=model_name, From 853a8eb53b89f9f3468ab553e86a964cb4e6cd1e Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 2 Nov 2025 13:06:05 +0800 Subject: [PATCH 010/231] [Bugfix] Fix Qwen Omni audio inference (#27920) Signed-off-by: DarkLight1337 --- vllm/model_executor/models/qwen2_5_omni_thinker.py | 9 ++------- vllm/model_executor/models/qwen3_omni_moe_thinker.py | 3 --- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index 677d34dea39b3..7e970ebbe2bbc 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -130,6 +130,8 @@ class Qwen2_5OmniAudioFeatureInputs(TensorSchema): TensorShape("nmb", "tsl", dynamic_dims={"tsl"}), ] 
+ audio_feature_lengths: Annotated[torch.Tensor, TensorShape("na")] + feature_attention_mask: Annotated[ torch.Tensor | list[torch.Tensor], TensorShape("na", "msl", dynamic_dims={"msl"}), @@ -732,13 +734,6 @@ class Qwen2_5OmniConditionalGenerationMixin: input_features = audio_input["input_features"] audio_feature_lengths = audio_input["audio_feature_lengths"] - if audio_feature_lengths.shape[0] == 1: - audio_feature_lengths = audio_feature_lengths.squeeze(0) - elif audio_feature_lengths.shape[1] == 1: - audio_feature_lengths = audio_feature_lengths.squeeze(1) - else: - raise AssertionError(audio_feature_lengths.shape) - audio_feat_lengths, audio_output_lengths = ( self.audio_tower._get_feat_extract_output_lengths(audio_feature_lengths) ) diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index efcd003fbbda7..f20e679027214 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -99,7 +99,6 @@ from .utils import ( AutoWeightsLoader, WeightsMapper, _merge_multimodal_embeddings, - flatten_bn, maybe_prefix, ) from .vision import ( @@ -1065,8 +1064,6 @@ class Qwen3OmniMoeConditionalGenerationMixin(Qwen2_5OmniConditionalGenerationMix input_features = audio_input["input_features"] audio_feature_lengths = audio_input["audio_feature_lengths"] - audio_feature_lengths = flatten_bn(audio_feature_lengths, concat=True) - audio_feat_lengths, audio_output_lengths = _get_feat_extract_output_lengths( audio_feature_lengths ) From 73444b7b5623f5bc569277c8c7dc809843312d11 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Sun, 2 Nov 2025 09:48:33 +0100 Subject: [PATCH 011/231] Performance fix MistralTokenizer: cache special ids and tokens (#27925) Signed-off-by: Julien Denize Co-authored-by: Patrick von Platen --- vllm/transformers_utils/tokenizers/mistral.py | 66 +++++++++---------- 1 file changed, 
32 insertions(+), 34 deletions(-) diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 6f710bf23360f..7033523224c51 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -191,6 +191,12 @@ class MistralTokenizer(TokenizerBase): # Sort the dict for convenience self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1])) + # Cache special tokens for faster access. + self._special_token_ids = self._get_special_token_ids() + self._special_token_ids_set = set(self._special_token_ids) + self._special_tokens = self._get_special_tokens(self._special_token_ids) + self._special_tokens_set = set(self._special_tokens) + # Vocab sorted by token id. self._vocab = self.tokenizer._vocab self._max_token_id = self.vocab_size - 1 @@ -210,23 +216,7 @@ class MistralTokenizer(TokenizerBase): ) ) - # the following attributes are set to fit vLLM's design and are used - # by the structured output backends. 
- @property - def all_special_tokens_extended(self) -> list[str]: - return self.all_special_tokens - - @property - def all_special_tokens(self) -> list[str]: - from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy - - return [ - self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP) - for i in self.all_special_ids - ] - - @property - def all_special_ids(self) -> list[int]: + def _get_special_token_ids(self) -> list[int]: from mistral_common.tokens.tokenizers.sentencepiece import ( SentencePieceTokenizer, ) @@ -244,6 +234,28 @@ class MistralTokenizer(TokenizerBase): raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}") return sorted(special_ids) + def _get_special_tokens(self, all_special_ids: list[int]) -> list[str]: + from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy + + return [ + self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP) + for i in all_special_ids + ] + + # the following attributes are set to fit vLLM's design and are used + # by the structured output backends. 
+ @property + def all_special_tokens_extended(self) -> list[str]: + return self.all_special_tokens + + @property + def all_special_tokens(self) -> list[str]: + return self._special_tokens + + @property + def all_special_ids(self) -> list[int]: + return self._special_token_ids + @property def bos_token_id(self) -> int: return self.tokenizer.bos_id @@ -277,21 +289,7 @@ class MistralTokenizer(TokenizerBase): raise NotImplementedError() def _is_special_token_id(self, token_id: int) -> bool: - from mistral_common.tokens.tokenizers.sentencepiece import ( - SentencePieceTokenizer, - ) - from mistral_common.tokens.tokenizers.tekken import Tekkenizer - - if self.is_spm: - assert isinstance(self.tokenizer, SentencePieceTokenizer), type( - self.tokenizer - ) - return token_id in self.tokenizer._control_tokens - if self.is_tekken: - assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer) - return token_id < self.tokenizer.num_special_tokens - else: - raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}") + return token_id in self._special_token_ids_set def __len__(self) -> int: return self.vocab_size @@ -405,7 +403,7 @@ class MistralTokenizer(TokenizerBase): tokens = [ t for t in tokens - if (t in to_decode_special_tokens or t not in self.all_special_tokens) + if (t in to_decode_special_tokens or t not in self._special_tokens_set) ] if any(isinstance(t, bytes) for t in tokens): @@ -489,7 +487,7 @@ class MistralTokenizer(TokenizerBase): # We filtered unwanted special tokens so we can decode the rest. 
tokens = [ self.tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP) - if token_id not in self.all_special_ids + if token_id not in self._special_token_ids_set else self.tokenizer.decode([token_id], SpecialTokenPolicy.KEEP) for token_id in ids_kept ] From 00b31a36a2d0de6d197a473280b2304d482714af Mon Sep 17 00:00:00 2001 From: Asaf Joseph Gardin <39553475+Josephasafg@users.noreply.github.com> Date: Sun, 2 Nov 2025 14:16:23 +0200 Subject: [PATCH 012/231] [V1] [Hybrid] Mamba1 Automatic Prefix Caching (#26377) Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com> --- csrc/mamba/mamba_ssm/selective_scan.h | 8 +- csrc/mamba/mamba_ssm/selective_scan_fwd.cu | 134 +++++++++++++++--- csrc/ops.h | 24 ++-- csrc/torch_bindings.cpp | 6 +- tests/kernels/mamba/test_mamba_ssm.py | 15 ++ .../models/language/generation/test_hybrid.py | 34 ++--- vllm/_custom_ops.py | 8 ++ vllm/config/model.py | 6 + .../layers/mamba/mamba_mixer.py | 91 ++++++++---- .../layers/mamba/ops/mamba_ssm.py | 24 +++- vllm/model_executor/models/config.py | 2 +- vllm/model_executor/models/jamba.py | 21 ++- vllm/model_executor/models/mamba.py | 9 +- vllm/v1/attention/backends/mamba1_attn.py | 111 ++++++++++++--- vllm/v1/attention/backends/mamba2_attn.py | 40 +----- vllm/v1/attention/backends/mamba_attn.py | 62 +++++++- 16 files changed, 442 insertions(+), 153 deletions(-) diff --git a/csrc/mamba/mamba_ssm/selective_scan.h b/csrc/mamba/mamba_ssm/selective_scan.h index 13c6178941cf8..7d22dd8b84a39 100644 --- a/csrc/mamba/mamba_ssm/selective_scan.h +++ b/csrc/mamba/mamba_ssm/selective_scan.h @@ -24,6 +24,8 @@ struct SSMParamsBase { int64_t pad_slot_id; bool delta_softplus; + bool cache_enabled; + int block_size; index_t A_d_stride; index_t A_dstate_stride; @@ -46,8 +48,9 @@ struct SSMParamsBase { index_t out_z_batch_stride; index_t out_z_d_stride; index_t ssm_states_batch_stride; - index_t ssm_states_dim_stride; + index_t ssm_states_dim_stride; index_t ssm_states_dstate_stride; + index_t 
cache_indices_stride; // Common data pointers. void *__restrict__ A_ptr; @@ -66,6 +69,9 @@ struct SSMParamsBase { void *__restrict__ cache_indices_ptr; void *__restrict__ has_initial_state_ptr; + void *__restrict__ block_idx_first_scheduled_token_ptr; // (batch,) - first block to write + void *__restrict__ block_idx_last_scheduled_token_ptr; // (batch,) - last block to write + void *__restrict__ initial_state_idx_ptr; // (batch,) - index of the initial state to use }; diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index d534e138d26d6..fb2a2e5789999 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -119,7 +119,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr : reinterpret_cast(params.cache_indices_ptr); - const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id]; + const int cache_index = cache_indices == nullptr ? 
batch_id : cache_indices[batch_id]; // cache_index == params.pad_slot_id is defined as padding, so we exit early if (cache_index == params.pad_slot_id){ return; @@ -133,9 +133,18 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { input_t *Bvar = reinterpret_cast(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride; weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; input_t *Cvar = reinterpret_cast(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride; - typename Ktraits::state_t *ssm_states = reinterpret_cast(params.ssm_states_ptr) + - cache_index * params.ssm_states_batch_stride + - dim_id * kNRows * params.ssm_states_dim_stride; + + typename Ktraits::state_t *ssm_states; + if (params.cache_enabled) { + // APC mode: ssm_states points to the base, we'll use absolute cache slots later + ssm_states = reinterpret_cast(params.ssm_states_ptr) + + dim_id * kNRows * params.ssm_states_dim_stride; + } else { + // Non-APC mode: offset by cache_index as before + ssm_states = reinterpret_cast(params.ssm_states_ptr) + + cache_index * params.ssm_states_batch_stride + + dim_id * kNRows * params.ssm_states_dim_stride; + } float D_val[kNRows] = {0}; if (params.D_ptr != nullptr) { @@ -159,7 +168,22 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { // } constexpr int kChunkSize = kNThreads * kNItems; - const int n_chunks = (seqlen + 2048 - 1) / 2048; + + // Use block_size for chunking when APC is enabled, otherwise use 2048 for backwards compatibility + const int iteration_chunk_size = params.cache_enabled ? params.block_size : 2048; + const int n_chunks = (seqlen + iteration_chunk_size - 1) / iteration_chunk_size; + + const int* batch_cache_indices = cache_indices != nullptr ? + cache_indices + batch_id * params.cache_indices_stride : nullptr; + const int* block_idx_first_scheduled = params.block_idx_first_scheduled_token_ptr != nullptr ? 
+ reinterpret_cast(params.block_idx_first_scheduled_token_ptr) : nullptr; + const int* block_idx_last_scheduled = params.block_idx_last_scheduled_token_ptr != nullptr ? + reinterpret_cast(params.block_idx_last_scheduled_token_ptr) : nullptr; + const int* initial_state_idx = params.initial_state_idx_ptr != nullptr ? + reinterpret_cast(params.initial_state_idx_ptr) : nullptr; + + const size_t load_cache_slot = params.cache_enabled && batch_cache_indices != nullptr ? batch_cache_indices[initial_state_idx[batch_id]] : cache_index; + for (int chunk = 0; chunk < n_chunks; ++chunk) { input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; @@ -219,7 +243,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { if constexpr (kIsVariableC) { auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, - smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1 )); + smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1)); if constexpr (!kIsVariableB) { #pragma unroll for (int r = 0; r < kNRows; ++r) { @@ -242,7 +266,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { for (int i = 0; i < kNItems; ++i) { thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); - if (seqlen % (kNItems * kNThreads) != 0) { // So that the last state is correct if (threadIdx.x * kNItems + i >= seqlen - chunk * kChunkSize) { thread_data[i] = make_float2(1.f, 0.f); @@ -250,8 +273,24 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { } } // Initialize running total - - scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? 
float(ssm_states[state_idx * params.ssm_states_dstate_stride]): 0.0); + scan_t running_prefix; + if (chunk > 0) { + running_prefix = smem_running_prefix[state_idx + r * MAX_DSTATE]; + } else { + // Load initial state + if (params.cache_enabled && has_initial_state && batch_cache_indices != nullptr) { + size_t state_offset = load_cache_slot * params.ssm_states_batch_stride + + r * params.ssm_states_dim_stride + + state_idx * params.ssm_states_dstate_stride; + running_prefix = make_float2(1.0, float(ssm_states[state_offset])); + } else if (has_initial_state) { + // Non-APC mode: load from current batch position + running_prefix = make_float2(1.0, float(ssm_states[state_idx * params.ssm_states_dstate_stride])); + } else { + // No initial state + running_prefix = make_float2(1.0, 0.0); + } + } SSMScanPrefixCallbackOp prefix_op(running_prefix); typename Ktraits::BlockScanT(smem_scan).InclusiveScan( @@ -260,8 +299,25 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { // There's a syncthreads in the scan op, so we don't need to sync here. // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. 
if (threadIdx.x == 0) { - smem_running_prefix[state_idx] = prefix_op.running_prefix; - if (chunk == n_chunks - 1) { + smem_running_prefix[state_idx + r * MAX_DSTATE] = prefix_op.running_prefix; + + // Store state at the end of each chunk when cache is enabled + if (params.cache_enabled && batch_cache_indices != nullptr) { + + size_t cache_slot; + if (chunk == n_chunks - 1) { + cache_slot = batch_cache_indices[block_idx_last_scheduled[batch_id]]; + } else { + cache_slot = batch_cache_indices[block_idx_first_scheduled[batch_id] + chunk]; + } + + size_t state_offset = cache_slot * params.ssm_states_batch_stride + + r * params.ssm_states_dim_stride + + state_idx * params.ssm_states_dstate_stride; + + ssm_states[state_offset] = typename Ktraits::state_t(prefix_op.running_prefix.y); + } else if (!params.cache_enabled && chunk == n_chunks - 1) { + // Non-APC mode: store only final state at current batch position ssm_states[state_idx * params.ssm_states_dstate_stride] = typename Ktraits::state_t(prefix_op.running_prefix.y); } } @@ -274,7 +330,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { } } } - input_t *out = reinterpret_cast(params.out_ptr) + sequence_start_index * params.out_batch_stride + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; __syncthreads(); @@ -346,7 +401,9 @@ template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { #ifndef USE_ROCM - if (params.seqlen <= 128) { + if (params.cache_enabled && params.block_size == 1024) { + selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream); + } else if (params.seqlen <= 128) { selective_scan_fwd_launch<32, 4, input_t, weight_t, state_t>(params, stream); } else if (params.seqlen <= 256) { selective_scan_fwd_launch<32, 8, input_t, weight_t, state_t>(params, stream); @@ -358,7 +415,9 @@ void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { selective_scan_fwd_launch<128, 16, input_t, weight_t, state_t>(params, stream); } #else - if 
(params.seqlen <= 256) { + if (params.cache_enabled && params.block_size == 1024) { + selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream); + } else if (params.seqlen <= 256) { selective_scan_fwd_launch<64, 4, input_t, weight_t, state_t>(params, stream); } else if (params.seqlen <= 512) { selective_scan_fwd_launch<64, 8, input_t, weight_t, state_t>(params, stream); @@ -437,13 +496,17 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, const std::optional& D, const std::optional& delta_bias, const torch::Tensor ssm_states, - bool has_z, + bool has_z, bool delta_softplus, const std::optional& query_start_loc, const std::optional& cache_indices, const std::optional& has_initial_state, bool varlen, - int64_t pad_slot_id) { + int64_t pad_slot_id, + int64_t block_size, + const std::optional &block_idx_first_scheduled_token, + const std::optional &block_idx_last_scheduled_token, + const std::optional &initial_state_idx) { // Reset the parameters memset(¶ms, 0, sizeof(params)); @@ -477,6 +540,14 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr; params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr; + // Set cache parameters - cache is enabled if we have direct cache writing params + params.cache_enabled = block_idx_first_scheduled_token.has_value(); + params.block_size = static_cast(block_size); + + // Set direct cache writing pointers + params.block_idx_first_scheduled_token_ptr = block_idx_first_scheduled_token.has_value() ? block_idx_first_scheduled_token.value().data_ptr() : nullptr; + params.block_idx_last_scheduled_token_ptr = block_idx_last_scheduled_token.has_value() ? block_idx_last_scheduled_token.value().data_ptr() : nullptr; + params.initial_state_idx_ptr = initial_state_idx.has_value() ? initial_state_idx.value().data_ptr() : nullptr; // All stride are in elements, not bytes. 
params.A_d_stride = A.stride(0); @@ -504,9 +575,11 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, params.out_d_stride = out.stride(0); params.ssm_states_batch_stride = ssm_states.stride(0); - params.ssm_states_dim_stride = ssm_states.stride(1); + params.ssm_states_dim_stride = ssm_states.stride(1); params.ssm_states_dstate_stride = ssm_states.stride(2); + params.cache_indices_stride = cache_indices.has_value() ? cache_indices.value().stride(0) : 0; + } else{ if (!is_variable_B) { @@ -537,8 +610,10 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, params.out_d_stride = out.stride(1); params.ssm_states_batch_stride = ssm_states.stride(0); - params.ssm_states_dim_stride = ssm_states.stride(1); + params.ssm_states_dim_stride = ssm_states.stride(1); params.ssm_states_dstate_stride = ssm_states.stride(2); + + params.cache_indices_stride = cache_indices.has_value() ? cache_indices.value().stride(0) : 0; } } @@ -554,7 +629,11 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, const torch::Tensor &ssm_states, // used to identify padding entries if cache_indices provided // in case of padding, the kernel will return early - int64_t pad_slot_id) { + int64_t pad_slot_id, + int64_t block_size, + const std::optional &block_idx_first_scheduled_token, + const std::optional &block_idx_last_scheduled_token, + const std::optional &initial_state_idx) { auto input_type = u.scalar_type(); auto weight_type = A.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -646,7 +725,16 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, auto cache_indices_ = cache_indices.value(); TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int); TORCH_CHECK(cache_indices_.is_cuda()); - CHECK_SHAPE(cache_indices_, batch_size); + + // cache_indices can be either 1D (batch_size,) for non-APC mode + // or 2D (batch_size, max_positions) for APC mode + const 
bool is_apc_mode = block_idx_first_scheduled_token.has_value(); + if (is_apc_mode) { + TORCH_CHECK(cache_indices_.dim() == 2, "cache_indices must be 2D for APC mode"); + TORCH_CHECK(cache_indices_.size(0) == batch_size, "cache_indices first dimension must match batch_size"); + } else { + CHECK_SHAPE(cache_indices_, batch_size); + } } @@ -686,7 +774,11 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, cache_indices, has_initial_state, varlen, - pad_slot_id + pad_slot_id, + block_size, + block_idx_first_scheduled_token, + block_idx_last_scheduled_token, + initial_state_idx ); diff --git a/csrc/ops.h b/csrc/ops.h index 0bed7492f6616..3f5cb799b774c 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -321,17 +321,19 @@ void dynamic_per_token_scaled_fp8_quant( torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale, std::optional const& scale_ub); -void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, - const torch::Tensor& A, const torch::Tensor& B, - const torch::Tensor& C, - const std::optional& D_, - const std::optional& z_, - const std::optional& delta_bias_, - bool delta_softplus, - const std::optional& query_start_loc, - const std::optional& cache_indices, - const std::optional& has_initial_state, - const torch::Tensor& ssm_states, int64_t pad_slot_id); +void selective_scan_fwd( + const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A, + const torch::Tensor& B, const torch::Tensor& C, + const std::optional& D_, + const std::optional& z_, + const std::optional& delta_bias_, bool delta_softplus, + const std::optional& query_start_loc, + const std::optional& cache_indices, + const std::optional& has_initial_state, + const torch::Tensor& ssm_states, int64_t pad_slot_id, int64_t block_size, + const std::optional& block_idx_first_scheduled_token, + const std::optional& block_idx_last_scheduled_token, + const std::optional& initial_state_idx); torch::Tensor dynamic_4bit_int_moe_cpu( torch::Tensor 
x, torch::Tensor topk_ids, torch::Tensor topk_weights, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 8f091a429fbef..9c0f524dcab11 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -611,7 +611,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor? cache_indices," "Tensor? has_initial_state," "Tensor! ssm_states," - "int pad_slot_id) -> ()"); + "int pad_slot_id," + "int block_size," + "Tensor? block_idx_first_scheduled_token," + "Tensor? block_idx_last_scheduled_token," + "Tensor? initial_state_idx) -> ()"); ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); // Hadamard transforms diff --git a/tests/kernels/mamba/test_mamba_ssm.py b/tests/kernels/mamba/test_mamba_ssm.py index c59fc7af0c897..98edc959957d0 100644 --- a/tests/kernels/mamba/test_mamba_ssm.py +++ b/tests/kernels/mamba/test_mamba_ssm.py @@ -179,6 +179,10 @@ def selective_scan_opcheck_fn( has_initial_state=None, ssm_states=None, pad_slot_id=PAD_SLOT_ID, + block_size=2048, + block_idx_first_scheduled_token=None, + block_idx_last_scheduled_token=None, + initial_state_idx=None, ): """if return_last_state is True, returns (out, last_state) last_state has shape (batch, dim, dstate). 
@@ -223,6 +227,10 @@ def selective_scan_opcheck_fn( has_initial_state, ssm_states, pad_slot_id, + block_size, + block_idx_first_scheduled_token, + block_idx_last_scheduled_token, + initial_state_idx, ), test_utils=["test_schema", "test_faketensor"], ) @@ -338,6 +346,11 @@ def test_selective_scan( has_initial_state=torch.ones(batch_size, device=u.device, dtype=torch.bool) if c > 0 else None, + pad_slot_id=PAD_SLOT_ID, + block_size=2048, + block_idx_first_scheduled_token=None, + block_idx_last_scheduled_token=None, + initial_state_idx=None, ) outs.append(out) if len(outs) > 1: @@ -372,6 +385,7 @@ def test_selective_scan( delta_bias=delta_bias, delta_softplus=delta_softplus, ssm_states=state, + block_size=2048, ) @@ -586,6 +600,7 @@ def test_selective_scan_varlen( padded_state_indices, has_initial_state, prev_state, + block_size=2048, ) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index fd2df329f17f9..681b380e6a155 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -19,6 +19,8 @@ pytestmark = pytest.mark.hybrid_model # meaning that it will be used in all tests in this file # The rest of the models will only be tested by test_models +APC_MULTIPLY_BY = 300 + SSM_MODELS = [ "state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev", @@ -380,7 +382,7 @@ def _get_vLLM_output( return outs, vllm_model -@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("n_repetitions", [2]) # If num_logprobs is set to -1, then the stringent version @@ -410,10 +412,8 @@ def test_apc_single_prompt( check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore ) - MULTIPLE = 300 - # Sample prompts. 
- generated_prompts = [MULTIPLE * example_prompts[0]] + generated_prompts = [APC_MULTIPLY_BY * example_prompts[0]] max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) vllm_runner_kwargs = _get_vllm_runner_params( @@ -446,7 +446,7 @@ def test_apc_single_prompt( ) -@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("n_repetitions", [2]) # If num_logprobs is set to -1, then the stringent version @@ -476,10 +476,8 @@ def test_apc_single_prompt_block_align_alignment( check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore ) - MULTIPLE = 300 - # Sample prompts. This custom prompt is used, as it causes the most issues - generated_prompts = ["The president of the United States is " * MULTIPLE] + generated_prompts = ["The president of the United States is " * APC_MULTIPLY_BY] max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) vllm_runner_kwargs = _get_vllm_runner_params( @@ -528,7 +526,7 @@ def test_apc_single_prompt_block_align_alignment( ) -@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("n_repetitions", [2]) # If num_logprobs is set to -1, then the stringent version @@ -558,10 +556,8 @@ def test_apc_multiple_prompts_all_cached_outputs( check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore ) - MULTIPLE = 300 - # Sample prompts. 
- generated_prompts = [MULTIPLE * prompt for prompt in example_prompts] + generated_prompts = [APC_MULTIPLY_BY * prompt for prompt in example_prompts] max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) vllm_runner_kwargs = _get_vllm_runner_params( @@ -595,7 +591,7 @@ def test_apc_multiple_prompts_all_cached_outputs( ) -@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("n_repetitions", [2]) # If num_logprobs is set to -1, then the stringent version @@ -625,12 +621,12 @@ def test_apc_multiple_prompts_block_align_alignment( check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore ) - MULTIPLE = 300 - # Sample prompts. This custom prompt is used, as it causes the most issues prompt_text = "The president of the United States is " prompt_offsets = [0, 3, 7, 13, 17, 22, 25, 31] - generated_prompts = [prompt_text[offset:] * MULTIPLE for offset in prompt_offsets] + generated_prompts = [ + prompt_text[offset:] * APC_MULTIPLY_BY for offset in prompt_offsets + ] max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) vllm_runner_kwargs = _get_vllm_runner_params( @@ -679,7 +675,7 @@ def test_apc_multiple_prompts_block_align_alignment( ) -@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("n_repetitions", [2]) # If num_logprobs is set to -1, then the stringent version @@ -709,10 +705,8 @@ def test_apc_multiple_prompts_partial_cached_outputs( check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore ) - MULTIPLE = 300 - # Sample prompts. 
- generated_prompts = [MULTIPLE * prompt for prompt in example_prompts] + generated_prompts = [APC_MULTIPLY_BY * prompt for prompt in example_prompts] max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) vllm_runner_kwargs = _get_vllm_runner_params( diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 9110b0573fc92..61cf54fcfa39a 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1719,6 +1719,10 @@ def selective_scan_fwd( has_initial_state: torch.Tensor | None, ssm_states: torch.Tensor, pad_slot_id: int, + block_size: int = 1024, + block_idx_first_scheduled_token: torch.Tensor | None = None, + block_idx_last_scheduled_token: torch.Tensor | None = None, + initial_state_idx: torch.Tensor | None = None, ): torch.ops._C.selective_scan_fwd( u, @@ -1735,6 +1739,10 @@ def selective_scan_fwd( has_initial_state, ssm_states, pad_slot_id, + block_size, + block_idx_first_scheduled_token, + block_idx_last_scheduled_token, + initial_state_idx, ) diff --git a/vllm/config/model.py b/vllm/config/model.py index 082f90653f5af..2e80df4311035 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -1483,6 +1483,12 @@ class ModelConfig: if chunk_size is None: # used by e.g. Mamba2, NemotronH, Zamba chunk_size = getattr(self.hf_text_config, "chunk_size", None) + + # Since Mamba1 does not have a chunk notion + # we use a default chunk size of 2048. 
+ if chunk_size is None: + chunk_size = 2048 + return chunk_size def get_multimodal_config(self) -> MultiModalConfig: diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index a9a0c216474bc..b6345b8af7f0a 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -241,18 +241,21 @@ class MambaMixer(MambaBase, CustomOp): forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata + assert self.cache_config is not None + mamba_block_size = self.cache_config.mamba_block_size + prefix_caching_enabled = self.cache_config.enable_prefix_caching + if attn_metadata is not None: assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] - mamba1_metadata = attn_metadata - assert isinstance(mamba1_metadata, Mamba1AttentionMetadata) - query_start_loc = mamba1_metadata.query_start_loc - state_indices_tensor = mamba1_metadata.state_indices_tensor + assert isinstance(attn_metadata, Mamba1AttentionMetadata) + query_start_loc_p = attn_metadata.query_start_loc_p + state_indices_tensor = attn_metadata.state_indices_tensor self_kv_cache = self.kv_cache[forward_context.virtual_engine] conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] - has_initial_states = mamba1_metadata.has_initial_states - num_padded_decodes = mamba1_metadata.num_padded_decodes + has_initial_states_p = attn_metadata.has_initial_states_p + num_padded_decodes = attn_metadata.num_padded_decodes # 1. 
Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) @@ -279,12 +282,8 @@ class MambaMixer(MambaBase, CustomOp): hidden_states_BC, gate, state_indices_tensor, - query_start_loc, - has_initial_states, num_prefill_tokens, - num_decode_tokens, num_prefills, - num_decodes, num_padded_decodes, ) hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p @@ -293,8 +292,34 @@ class MambaMixer(MambaBase, CustomOp): gate_d = prefill_decode_split.gate_d state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d - query_start_loc_p = prefill_decode_split.query_start_loc_p - has_initial_states_p = prefill_decode_split.has_initial_states_p + + if prefix_caching_enabled: + block_idx_last_computed_token_d, block_idx_last_computed_token_p = ( + torch.split( + attn_metadata.block_idx_last_computed_token, + [num_decodes, num_prefills], + dim=0, + ) + ) + block_idx_last_scheduled_token_d, block_idx_last_scheduled_token_p = ( + torch.split( + attn_metadata.block_idx_last_scheduled_token, + [num_decodes, num_prefills], + dim=0, + ) + ) + + block_idx_first_scheduled_token_p = ( + attn_metadata.block_idx_first_scheduled_token_p + ) + num_computed_tokens_p = attn_metadata.num_computed_tokens_p + else: + block_idx_last_computed_token_d = None + block_idx_last_computed_token_p = None + block_idx_last_scheduled_token_d = None + block_idx_last_scheduled_token_p = None + block_idx_first_scheduled_token_p = None + num_computed_tokens_p = None ssm_outputs = [] @@ -309,6 +334,11 @@ class MambaMixer(MambaBase, CustomOp): has_initial_state=has_initial_states_p, cache_indices=state_indices_tensor_p, query_start_loc=query_start_loc_p, + block_idx_first_scheduled_token=block_idx_first_scheduled_token_p, + block_idx_last_scheduled_token=block_idx_last_scheduled_token_p, + initial_state_idx=block_idx_last_computed_token_p, + num_computed_tokens=num_computed_tokens_p, + 
block_size_to_align=mamba_block_size, ) # 3. State Space Model sequence transformations. discrete_time_step_p, B_p, C_p = self._ssm_transform( @@ -331,10 +361,24 @@ class MambaMixer(MambaBase, CustomOp): cache_indices=state_indices_tensor_p, has_initial_state=has_initial_states_p, query_start_loc=query_start_loc_p, + block_size=mamba_block_size, + block_idx_first_scheduled_token=block_idx_first_scheduled_token_p, + block_idx_last_scheduled_token=block_idx_last_scheduled_token_p, + initial_state_idx=block_idx_last_computed_token_p, ) ssm_outputs.append(scan_out_p) if has_decode: + if prefix_caching_enabled: + state_indices_tensor_d_input = state_indices_tensor_d.gather( + 1, block_idx_last_computed_token_d.unsqueeze(1) + ).squeeze(1) + state_indices_tensor_d_output = state_indices_tensor_d.gather( + 1, block_idx_last_scheduled_token_d.unsqueeze(1) + ).squeeze(1) + else: + state_indices_tensor_d_input = state_indices_tensor_d + state_indices_tensor_d_output = state_indices_tensor_d # 2. Convolution sequence transformation conv_out_d = causal_conv1d_update( hidden_states_BC_d.transpose(0, 1), @@ -343,6 +387,8 @@ class MambaMixer(MambaBase, CustomOp): self.conv1d.bias, self.activation, conv_state_indices=state_indices_tensor_d, + block_idx_last_scheduled_token=block_idx_last_scheduled_token_d, + initial_state_idx=block_idx_last_computed_token_d, ).transpose(0, 1) # 3. State Space Model sequence transformation. 
@@ -364,7 +410,8 @@ class MambaMixer(MambaBase, CustomOp): gate_d.transpose(0, 1), time_proj_bias, dt_softplus=True, - state_batch_indices=state_indices_tensor_d, + state_batch_indices=state_indices_tensor_d_input, + dst_state_batch_indices=state_indices_tensor_d_output, out=scan_outputs_d, ) scan_outputs_d = scan_outputs_d.transpose(0, 1) @@ -423,20 +470,14 @@ class PrefillDecodeSplit(NamedTuple): gate_d: torch.Tensor state_indices_tensor_p: torch.Tensor state_indices_tensor_d: torch.Tensor - query_start_loc_p: torch.Tensor | None - has_initial_states_p: torch.Tensor | None def split_batch_to_prefill_and_decode( hidden_states_BC: torch.Tensor, gate: torch.Tensor, state_indices_tensor: torch.Tensor, - query_start_loc: torch.Tensor, - has_initial_states: torch.Tensor | None, num_prefill_tokens: int, - num_decode_tokens: int, num_prefills: int, - num_decodes: int, num_padded_decodes: int, ) -> PrefillDecodeSplit: num_actual_tokens = num_prefill_tokens + num_padded_decodes @@ -457,16 +498,6 @@ def split_batch_to_prefill_and_decode( [num_padded_decodes, num_prefills], dim=0, ) - query_start_loc_p = ( - query_start_loc[-num_prefills - 1 :] - num_padded_decodes - if num_prefills > 0 - else None - ) - has_initial_states_p = ( - has_initial_states[-num_prefills:] - if (has_initial_states is not None and num_prefills > 0) - else None - ) return PrefillDecodeSplit( hidden_states_BC_p=hidden_states_BC_p, @@ -475,8 +506,6 @@ def split_batch_to_prefill_and_decode( gate_d=gate_d, state_indices_tensor_p=state_indices_tensor_p, state_indices_tensor_d=state_indices_tensor_d, - query_start_loc_p=query_start_loc_p, - has_initial_states_p=has_initial_states_p, ) diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 8722eb9a7b22f..53fd5d5458b09 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -375,6 +375,10 @@ def selective_scan_fn( 
cache_indices=None, has_initial_state=None, pad_slot_id=PAD_SLOT_ID, + block_size=1024, + block_idx_first_scheduled_token=None, + block_idx_last_scheduled_token=None, + initial_state_idx=None, ) -> torch.Tensor: """ u: (dim, total_length) for varlen or (batch, dim, seqlen) @@ -397,7 +401,10 @@ def selective_scan_fn( x.shape=(dim,17) cache_indices: (batch) int32 A tensor with each cell is a correspondent - input and output ssm_state index + input and output ssm_state indices + - Without APC: (batch,) - single state index per batch item + - With APC: (batch, max_positions) - cache block indices for read/write + Each non-zero value indicates a cache block to load from and/or write to. has_initial_state: (batch) bool A tensor populated with ones and zeros, indicate if the ssm_state at the corresponding index should be @@ -408,6 +415,17 @@ def selective_scan_fn( that will not be processed, for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] in this case, the kernel will not process entries at indices 0 and 3 + block_size: int + The block size to align the cached states to + block_idx_first_scheduled_token: (batch,), dtype int32 + The pointer into cache_indices, where the first + cache block to be filled is located. + block_idx_last_scheduled_token: (batch,), dtype int32 + The pointer into cache_indices, where the last cache block + to be filled is located. + initial_state_idx: (batch,), dtype int32 + The pointer into cache_indices, where the cache block + containing the initial state is located. 
returns output: (dim, total_length) for varlen or (batch, dim, seqlen) supports inplace replacement @@ -448,6 +466,10 @@ def selective_scan_fn( has_initial_state, ssm_states, pad_slot_id, + block_size, + block_idx_first_scheduled_token, + block_idx_last_scheduled_token, + initial_state_idx, ) if z is None: diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 7150977e9266b..5dda2ec97875f 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -299,7 +299,7 @@ class MambaModelConfig(VerifyAndUpdateConfig): if model_config.supports_mamba_prefix_caching: logger.info( "Warning: Prefix caching is currently enabled. " - "Its support for Mamba2 layers is experimental. " + "Its support for Mamba layers is experimental. " "Please report any issues you may observe." ) else: diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index f8a87cf6965f8..ba95021b0b542 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -38,7 +38,13 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaMLP as JambaMLP from vllm.sequence import IntermediateTensors -from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP +from .interfaces import ( + HasInnerState, + IsHybrid, + SupportsLoRA, + SupportsMambaPrefixCaching, + SupportsPP, +) from .utils import ( AutoWeightsLoader, WeightsMapper, @@ -454,7 +460,14 @@ class JambaModel(nn.Module): return loaded_params -class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid): +class JambaForCausalLM( + nn.Module, + HasInnerState, + SupportsLoRA, + SupportsPP, + IsHybrid, + SupportsMambaPrefixCaching, +): hf_to_vllm_mapper = WeightsMapper( orig_to_new_substr={".self_attn.": ".", ".A_log": ".A"}, ) @@ -477,12 +490,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, 
SupportsPP, IsHyb def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, ( - "Jamba currently does not support prefix caching" - ) super().__init__() self.config = config diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index fb145289fbfe9..f684203f6d35e 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -29,6 +29,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import ( HasInnerState, IsAttentionFree, + SupportsMambaPrefixCaching, SupportsPP, ) from vllm.sequence import IntermediateTensors @@ -193,15 +194,13 @@ class MambaModel(nn.Module): return loaded_params -class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): +class MambaForCausalLM( + nn.Module, HasInnerState, IsAttentionFree, SupportsPP, SupportsMambaPrefixCaching +): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config self.scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, ( - "Mamba does not support prefix caching" - ) super().__init__() self.config = config diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index 30c63e0ded8e7..909af09be255a 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -7,11 +7,13 @@ import torch from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.config import VllmConfig from vllm.v1.attention.backends.mamba_attn 
import BaseMambaAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, split_decodes_and_prefills, ) +from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec class Mamba1AttentionBackend(AttentionBackend): @@ -22,32 +24,41 @@ class Mamba1AttentionBackend(AttentionBackend): @dataclass class Mamba1AttentionMetadata: - query_start_loc: torch.Tensor - context_lens_tensor: torch.Tensor + query_start_loc_p: torch.Tensor state_indices_tensor: torch.Tensor - has_initial_states: torch.Tensor | None + has_initial_states_p: torch.Tensor | None num_prefills: int num_prefill_tokens: int num_decodes: int num_decode_tokens: int num_padded_decodes: int + block_idx_last_scheduled_token: torch.Tensor # shape: [batch,] + block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,] + block_idx_last_computed_token: torch.Tensor # shape: [batch,] + num_computed_tokens_p: torch.Tensor # shape: [batch,] + class Mamba1AttentionMetadataBuilder( BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata] ): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + assert isinstance(kv_cache_spec, MambaSpec) + def build( self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False, ) -> Mamba1AttentionMetadata: - query_start_loc = common_attn_metadata.query_start_loc - - state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] - context_lens_tensor = common_attn_metadata.num_computed_tokens_cpu.to( - query_start_loc.device - ) + num_reqs = common_attn_metadata.num_reqs num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( @@ -55,32 +66,100 @@ class Mamba1AttentionMetadataBuilder( ) ) - has_initial_states = None + has_initial_states_p = None + query_start_loc_p = None padded_decodes = num_decodes + 
num_computed_tokens, num_computed_tokens_p = None, None + block_idx_first_scheduled_token = None + block_idx_first_scheduled_token_p = None + + # TODO(@Josephasafg) Mamba1 and Mamba2 have a lot of code in common here. + # We should consolidate this code + if self.vllm_config.cache_config.enable_prefix_caching: + # Return a tensor of shape (#requests, #max blocks) + state_indices_tensor = common_attn_metadata.block_table_tensor + mamba_block_size = self.kv_cache_spec.block_size + num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to( + self.device + ) + ( + block_idx_last_computed_token, + block_idx_first_scheduled_token, + block_idx_last_scheduled_token, + ) = self._compute_prefix_caching_block_indices( + common_attn_metadata, mamba_block_size + ) + else: + # Always return just a single block per each request: + state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + block_idx_last_scheduled_token = None + block_idx_last_computed_token = None if num_prefills > 0: - has_initial_states = context_lens_tensor > 0 + query_start_loc_p = ( + common_attn_metadata.query_start_loc[-num_prefills - 1 :] + - num_decode_tokens + ) + has_initial_states_cpu = ( + common_attn_metadata.num_computed_tokens_cpu[ + num_reqs - num_prefills : num_reqs + ] + > 0 + ) + has_initial_states_p = has_initial_states_cpu.to( + common_attn_metadata.query_start_loc.device + ) + + if self.vllm_config.cache_config.enable_prefix_caching: + assert num_computed_tokens is not None + num_computed_tokens_p = num_computed_tokens[ + num_reqs - num_prefills : num_reqs + ] + assert block_idx_first_scheduled_token is not None + block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[ + num_reqs - num_prefills : num_reqs + ] + elif ( num_decodes > 0 and num_decodes <= self.decode_cudagraph_max_bs and self.compilation_config.full_cuda_graph ): - state_indices_for_decode = state_indices_tensor[:num_decodes] padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes) 
self.state_indices_tensor[:num_decodes].copy_( - state_indices_for_decode, non_blocking=True + state_indices_tensor, non_blocking=True ) state_indices_tensor = self.state_indices_tensor[:padded_decodes] state_indices_tensor[num_decodes:] = PAD_SLOT_ID + if self.vllm_config.cache_config.enable_prefix_caching: + self.block_idx_last_scheduled_token[:num_decodes].copy_( + block_idx_last_scheduled_token, non_blocking=True + ) + block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[ + :padded_decodes + ] + block_idx_last_scheduled_token[num_decodes:] = 0 + + self.block_idx_last_computed_token[:num_decodes].copy_( + block_idx_last_computed_token, non_blocking=True + ) + block_idx_last_computed_token = self.block_idx_last_computed_token[ + :padded_decodes + ] + block_idx_last_computed_token[num_decodes:] = 0 + return Mamba1AttentionMetadata( - query_start_loc=query_start_loc, - context_lens_tensor=context_lens_tensor, - has_initial_states=has_initial_states, + query_start_loc_p=query_start_loc_p, + has_initial_states_p=has_initial_states_p, state_indices_tensor=state_indices_tensor, num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, num_decodes=num_decodes, num_decode_tokens=num_decode_tokens, num_padded_decodes=padded_decodes, + block_idx_last_scheduled_token=block_idx_last_scheduled_token, + block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p, + block_idx_last_computed_token=block_idx_last_computed_token, + num_computed_tokens_p=num_computed_tokens_p, ) diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index f9d2426eaf632..4bc1057333a50 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -147,27 +147,6 @@ class Mamba2AttentionMetadataBuilder( assert self.chunk_size is not None, ( "chunk_size needs to be set in the model config for Mamba2 models" ) - if self.vllm_config.cache_config.enable_prefix_caching: - 
self.state_indices_tensor = torch.empty( - ( - self.decode_cudagraph_max_bs, - cdiv( - vllm_config.model_config.max_model_len, kv_cache_spec.block_size - ), - ), - dtype=torch.int32, - device=device, - ) - self.block_idx_last_scheduled_token = torch.empty( - (self.decode_cudagraph_max_bs,), - dtype=torch.int32, - device=device, - ) - self.block_idx_last_computed_token = torch.empty( - (self.decode_cudagraph_max_bs,), - dtype=torch.int32, - device=device, - ) def build( self, @@ -202,20 +181,13 @@ class Mamba2AttentionMetadataBuilder( num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to( self.device ) - # Block index of the last computed token - block_idx_last_computed_token = ( - cdiv(num_computed_tokens, mamba_block_size) - 1 + ( + block_idx_last_computed_token, + block_idx_first_scheduled_token, + block_idx_last_scheduled_token, + ) = self._compute_prefix_caching_block_indices( + common_attn_metadata, mamba_block_size ) - # which is <= block index for the first scheduled token - block_idx_first_scheduled_token = ( - cdiv(num_computed_tokens + 1, mamba_block_size) - 1 - ) - # which is <= block index of the last scheduled token - block_idx_last_scheduled_token = ( - cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1 - ) - # -1 in case it's non-computed and causes later issues with indexing - block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0) else: # Always return just a single block per each request: state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 52f26a9e61cab..49d7d6c31b9a0 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -7,6 +7,7 @@ from typing import ClassVar, TypeVar import torch from vllm.config import VllmConfig +from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.utils import ( AttentionCGSupport, 
AttentionMetadataBuilder, @@ -38,11 +39,35 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): self.vllm_config.scheduler_config.max_num_seqs, self.compilation_config.max_cudagraph_capture_size, ) - self.state_indices_tensor = torch.empty( - (self.decode_cudagraph_max_bs,), - dtype=torch.int32, - device=device, - ) + + if self.vllm_config.cache_config.enable_prefix_caching: + self.state_indices_tensor = torch.empty( + ( + self.decode_cudagraph_max_bs, + cdiv( + self.vllm_config.model_config.max_model_len, + self.kv_cache_spec.block_size, + ), + ), + dtype=torch.int32, + device=device, + ) + self.block_idx_last_scheduled_token = torch.empty( + (self.decode_cudagraph_max_bs,), + dtype=torch.int32, + device=device, + ) + self.block_idx_last_computed_token = torch.empty( + (self.decode_cudagraph_max_bs,), + dtype=torch.int32, + device=device, + ) + else: + self.state_indices_tensor = torch.empty( + (self.decode_cudagraph_max_bs,), + dtype=torch.int32, + device=device, + ) def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata @@ -61,3 +86,30 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): m.max_query_len = 1 # decode-only return self.build(0, m) + + def _compute_prefix_caching_block_indices( + self, + common_attn_metadata: CommonAttentionMetadata, + mamba_block_size: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to( + self.device + ) + # Block index of the last computed token + block_idx_last_computed_token = cdiv(num_computed_tokens, mamba_block_size) - 1 + # which is <= block index for the first scheduled token + block_idx_first_scheduled_token = ( + cdiv(num_computed_tokens + 1, mamba_block_size) - 1 + ) + # which is <= block index of the last scheduled token + block_idx_last_scheduled_token = ( + cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1 + ) + # -1 in case it's non-computed 
and causes later issues with indexing + block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0) + + return ( + block_idx_last_computed_token, + block_idx_first_scheduled_token, + block_idx_last_scheduled_token, + ) From 6c317a656eb09a641d85be05aa8498ff160bf0c1 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 2 Nov 2025 21:42:38 +0800 Subject: [PATCH 013/231] [Misc] Provide Siglip2 chat template (#27939) Signed-off-by: DarkLight1337 --- vllm/transformers_utils/chat_templates/registry.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/transformers_utils/chat_templates/registry.py b/vllm/transformers_utils/chat_templates/registry.py index 3bdbe1d0a67b6..fe84b6c152eef 100644 --- a/vllm/transformers_utils/chat_templates/registry.py +++ b/vllm/transformers_utils/chat_templates/registry.py @@ -40,6 +40,7 @@ _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = { "paligemma": CHAT_TEMPLATES_DIR / "template_basic.jinja", "qwen": _get_qwen_chat_template_fallback, "siglip": CHAT_TEMPLATES_DIR / "template_basic.jinja", + "siglip2": CHAT_TEMPLATES_DIR / "template_basic.jinja", } From 0ce743f4e1879ffa250e471f6894633ef125418e Mon Sep 17 00:00:00 2001 From: Vensen Date: Mon, 3 Nov 2025 00:24:01 +0800 Subject: [PATCH 014/231] Fix(llm): Abort orphaned requests when llm.chat() batch fails Fixes #26081 (#27420) Signed-off-by: vensenmu --- tests/entrypoints/llm/test_chat.py | 53 ++++++++++++++++++++++++++++++ vllm/entrypoints/llm.py | 36 ++++++++++++-------- 2 files changed, 75 insertions(+), 14 deletions(-) diff --git a/tests/entrypoints/llm/test_chat.py b/tests/entrypoints/llm/test_chat.py index b2a958a992a62..a9698632b82e0 100644 --- a/tests/entrypoints/llm/test_chat.py +++ b/tests/entrypoints/llm/test_chat.py @@ -6,6 +6,7 @@ import pytest from vllm import LLM from vllm.distributed import cleanup_dist_env_and_memory +from vllm.sampling_params import SamplingParams from ..openai.test_vision import TEST_IMAGE_ASSETS @@ -23,6 +24,29 @@ def 
text_llm(): cleanup_dist_env_and_memory() +@pytest.fixture(scope="function") +def llm_for_failure_test(): + """ + Fixture for testing issue #26081. + Uses a small max_model_len to easily trigger length errors. + """ + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM( + model="meta-llama/Llama-3.2-1B-Instruct", + enforce_eager=True, + seed=0, + max_model_len=128, + disable_log_stats=True, + ) + + yield weakref.proxy(llm) + + del llm + + cleanup_dist_env_and_memory() + + def test_chat(text_llm): prompt1 = "Explain the concept of entropy." messages = [ @@ -157,3 +181,32 @@ def test_chat_extra_kwargs(thinking_llm, enable_thinking): else: # The chat template includes dummy thinking process assert think_id in prompt_token_ids + + +def test_chat_batch_failure_cleanup(llm_for_failure_test): + """ + Tests that if a batch call to llm.chat() fails mid-way + (e.g., due to one invalid prompt), the requests that + were already enqueued are properly aborted and do not + pollute the queue for subsequent calls. 
+ (Fixes Issue #26081) + """ + llm = llm_for_failure_test + valid_msg = [{"role": "user", "content": "Hello"}] + long_text = "This is a very long text to test the error " * 50 + invalid_msg = [{"role": "user", "content": long_text}] + batch_1 = [ + valid_msg, + valid_msg, + invalid_msg, + ] + batch_2 = [ + valid_msg, + valid_msg, + ] + sampling_params = SamplingParams(temperature=0, max_tokens=10) + with pytest.raises(ValueError, match="longer than the maximum model length"): + llm.chat(batch_1, sampling_params=sampling_params) + outputs_2 = llm.chat(batch_2, sampling_params=sampling_params) + assert len(outputs_2) == len(batch_2) + assert llm.llm_engine.get_num_unfinished_requests() == 0 diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b0b996ab2fec5..22fe2ae9280aa 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1588,20 +1588,27 @@ class LLM: tqdm_func = use_tqdm if callable(use_tqdm) else tqdm it = tqdm_func(it, desc="Adding requests") - for i, prompt in enumerate(it): - if isinstance(prompt, dict): - self._validate_mm_data_and_uuids( - prompt.get("multi_modal_data"), prompt.get("multi_modal_uuids") - ) + added_request_ids: list[str] = [] - self._add_request( - prompt, - params[i] if isinstance(params, Sequence) else params, - lora_request=lora_request[i] - if isinstance(lora_request, Sequence) - else lora_request, - priority=priority[i] if priority else 0, - ) + try: + for i, prompt in enumerate(it): + if isinstance(prompt, dict): + self._validate_mm_data_and_uuids( + prompt.get("multi_modal_data"), prompt.get("multi_modal_uuids") + ) + request_id = self._add_request( + prompt, + params[i] if isinstance(params, Sequence) else params, + lora_request=lora_request[i] + if isinstance(lora_request, Sequence) + else lora_request, + priority=priority[i] if priority else 0, + ) + added_request_ids.append(request_id) + except Exception as e: + if added_request_ids: + self.llm_engine.abort_request(added_request_ids) + raise e def 
_validate_mm_data_and_uuids( self, @@ -1684,7 +1691,7 @@ class LLM: params: SamplingParams | PoolingParams, lora_request: LoRARequest | None = None, priority: int = 0, - ) -> None: + ) -> str: prompt_text, _, _ = get_prompt_components(prompt) request_id = str(next(self.request_counter)) @@ -1705,6 +1712,7 @@ class LLM: priority=priority, prompt_text=prompt_text, ) + return request_id def _run_engine( self, *, use_tqdm: bool | Callable[..., tqdm] = True From 1bf43ae35d7f6a83cc2025b8c0a2332456f4afe9 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Sun, 2 Nov 2025 18:08:08 -0800 Subject: [PATCH 015/231] [BugFix][LoRA] use adapter_id instead of id field of lora_request (#27728) Signed-off-by: Biswa Panda --- tests/v1/core/test_prefix_caching.py | 63 +++++++++++++++++++++++++++- vllm/v1/core/block_pool.py | 4 +- 2 files changed, 64 insertions(+), 3 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 837a513cb75e1..2291f363731f2 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -9,7 +9,8 @@ import pytest import torch import vllm.v1.core.kv_cache_utils as kv_cache_utils -from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved +from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved, BlockStored +from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import ( MultiModalFeatureSpec, MultiModalKwargsItem, @@ -59,6 +60,7 @@ def make_request( mm_hashes: list[str] | None = None, prompt_logprobs: int | None = None, cache_salt: str | None = None, + lora_request: LoRARequest | None = None, ): mm_features = [] if mm_positions is not None: @@ -79,7 +81,7 @@ def make_request( sampling_params=SamplingParams(max_tokens=17, prompt_logprobs=prompt_logprobs), pooling_params=None, eos_token_id=100, - lora_request=None, + lora_request=lora_request, cache_salt=cache_salt, block_hasher=get_request_block_hasher(block_size, hash_fn), ) @@ -1337,6 
+1339,63 @@ def test_kv_cache_events(blocks_to_cache: int): assert len(manager.block_pool.cached_block_hash_to_block) == 0 +@pytest.mark.parametrize("blocks_to_cache", [2, 3, 10]) +def test_kv_cache_events_with_lora(blocks_to_cache: int): + """Test BlockStored events contain correct lora_id when using LoRA requests.""" + block_size = 16 + num_blocks = blocks_to_cache + 1 + + # Create KVCacheManager with events enabled + manager = KVCacheManager( + make_kv_cache_config(block_size, num_blocks), + max_model_len=8192, + enable_caching=True, + enable_kv_cache_events=True, + ) + + # Test with LoRA request + lora_request = LoRARequest( + lora_name="test_lora", lora_int_id=42, lora_path="/test/path" + ) + + num_tokens = block_size * blocks_to_cache + req_with_lora = make_request( + "lora_req", + list(range(num_tokens)), + block_size, + sha256, + lora_request=lora_request, + ) + + # Allocate slots and get events + _ = manager.allocate_slots(req_with_lora, num_tokens) + events = manager.take_events() + + # Verify BlockStored event contains correct lora_id + block_stored_event = events[-1] + assert isinstance(block_stored_event, BlockStored) + assert block_stored_event.lora_id == 42 # Should match lora_request.adapter_id + assert len(block_stored_event.block_hashes) == blocks_to_cache + assert block_stored_event.block_size == block_size + + # Clean up + manager.free(req_with_lora) + + # Test without LoRA request (should have lora_id=None) + req_without_lora = make_request( + "no_lora_req", list(range(num_tokens)), block_size, sha256 + ) + + _ = manager.allocate_slots(req_without_lora, num_tokens) + events = manager.take_events() + + block_stored_event = events[-1] + assert isinstance(block_stored_event, BlockStored) + assert block_stored_event.lora_id is None # Should be None when no LoRA request + assert len(block_stored_event.block_hashes) == blocks_to_cache + assert block_stored_event.block_size == block_size + + def test_eagle_enabled_removes_last_block(): """Verify Eagle 
does NOT remove blocks when request length is divisible by block size.""" diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 15c06a0b107d8..55710ad5cc693 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -259,7 +259,9 @@ class BlockPool: num_cached_blocks * block_size : num_full_blocks * block_size ], block_size=block_size, - lora_id=request.lora_request.id if request.lora_request else None, + lora_id=request.lora_request.adapter_id + if request.lora_request + else None, medium=MEDIUM_GPU, ) ) From 470ad118b6238e66094c9a508dea0aaaaf864093 Mon Sep 17 00:00:00 2001 From: Sungyoon Jeong <157349761+n0gu-furiosa@users.noreply.github.com> Date: Mon, 3 Nov 2025 13:21:18 +0900 Subject: [PATCH 016/231] [Frontend] Align finish_reason when tool is called with OpenAI (#25054) Signed-off-by: Sungyoon Jeong Co-authored-by: Chauncey --- vllm/entrypoints/openai/serving_chat.py | 26 +++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index bb770ecf03383..25979d5502b07 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1170,9 +1170,13 @@ class OpenAIServingChat(OpenAIServing): ) # Send the finish response for each request.n only once + # In OpenAI's API, when a tool is called, the + # finish_reason is: + # "tool_calls" for "auto" or "required" tool calls, + # and "stop" for named tool calls. 
if ( auto_tools_called - or tools_streamed[i] + or (tools_streamed[i] and not tool_choice_function_name) or (self.use_harmony and harmony_tools_streamed[i]) ): finish_reason_ = "tool_calls" @@ -1523,18 +1527,24 @@ class OpenAIServingChat(OpenAIServing): message = ChatMessage( role=role, reasoning_content=reasoning_content, content=content ) + # In OpenAI's API, when a tool is called, the finish_reason is: + # "tool_calls" for "auto" or "required" tool calls, + # and "stop" for named tool calls. + is_finish_reason_tool_calls = auto_tools_called or ( + request.tool_choice + and request.tool_choice == "required" + and output.finish_reason == "stop" + ) choice_data = ChatCompletionResponseChoice( index=output.index, message=message, logprobs=logprobs, - finish_reason=( - "tool_calls" - if auto_tools_called - else output.finish_reason - if output.finish_reason - else "stop" - ), + finish_reason="tool_calls" + if is_finish_reason_tool_calls + else output.finish_reason + if output.finish_reason + else "stop", stop_reason=output.stop_reason, token_ids=( as_list(output.token_ids) if request.return_token_ids else None From 18961c5ea62976efc50525b72e40337993c5e4f9 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Mon, 3 Nov 2025 06:48:03 +0100 Subject: [PATCH 017/231] [Hybrid] Pass kernel block size to builders (#27753) Signed-off-by: Thomas Parnell --- vllm/v1/attention/backends/flash_attn.py | 6 +++- vllm/v1/kv_cache_interface.py | 8 ++++- vllm/v1/worker/gpu_model_runner.py | 31 +++++++++++++---- vllm/v1/worker/utils.py | 44 ++++++++++++++---------- 4 files changed, 62 insertions(+), 27 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 1eac94940e781..07f9ef173b4e3 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -62,7 +62,11 @@ class FlashAttentionBackend(AttentionBackend): @staticmethod def get_supported_kernel_block_size() -> list[int | MultipleOf]: - 
return [MultipleOf(16)] + # NOTE(tdoublep): while in principle, FA supports + # MultipleOf(16), these are the block sizes that do not + # suffer from the NaN propagation problem described here: + # https://github.com/Dao-AILab/flash-attention/issues/1974 + return [16, 32, 64] @classmethod def validate_head_size(cls, head_size: int) -> None: diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 0f564fdb3b080..7f33eb7e699c7 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy -from dataclasses import dataclass, fields +from dataclasses import dataclass, fields, replace from math import prod import torch @@ -44,6 +44,12 @@ class KVCacheSpec: """ raise NotImplementedError + def copy_with_new_block_size(self, block_size: int) -> Self: + """ + Create a new KVCacheSpec from self but replacing the block size. + """ + return replace(self, block_size=block_size) + @classmethod def merge(cls, specs: list[Self]) -> Self: """ diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 66a9d72912618..9212221bb6009 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4039,16 +4039,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) -> list[AttentionGroup]: attn_groups: list[AttentionGroup] = [] for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items(): - attn_group = AttentionGroup.create_with_metadata_builders( + attn_group = AttentionGroup( attn_backend, layer_names, kv_cache_spec, - self.vllm_config, - self.device, kv_cache_group_id, - num_metadata_builders=1 - if not self.parallel_config.enable_dbo - else 2, ) attn_groups.append(attn_group) @@ -4067,7 +4062,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for i, attn_backend_map in enumerate(attention_backend_maps): 
self.attn_groups.append(create_attn_groups(attn_backend_map, i)) + def initialize_metadata_builders( + self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int] + ) -> None: + """ + Create the metadata builders for all KV cache groups and attn groups. + """ + for kv_cache_group_id in range(len(kv_cache_config.kv_cache_groups)): + for attn_group in self.attn_groups[kv_cache_group_id]: + attn_group.create_metadata_builders( + self.vllm_config, + self.device, + kernel_block_sizes[kv_cache_group_id] + if kv_cache_group_id < len(kernel_block_sizes) + else None, + num_metadata_builders=1 + if not self.parallel_config.enable_dbo + else 2, + ) # Calculate reorder batch threshold (if needed) + # Note (tdoublep): do this *after* constructing builders, + # because some of them change the threshold at init time. self.calculate_reorder_batch_threshold() def _check_and_update_cudagraph_mode( @@ -4633,6 +4648,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # kernel_block_size 64 and split the 256-token-block to 4 blocks with 64 # tokens each. 
kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config) + + # create metadata builders + self.initialize_metadata_builders(kv_cache_config, kernel_block_sizes) + # Reinitialize need to after initialize_attn_backend self.may_reinitialize_input_batch(kv_cache_config, kernel_block_sizes) kv_caches = self.initialize_kv_cache_tensors( diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 396adbcfb289f..0ca7e81a5c7b8 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections import defaultdict -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING import torch @@ -134,31 +134,37 @@ class MultiModalBudget: @dataclass class AttentionGroup: backend: type[AttentionBackend] - # When ubatching is enabled we will have a metadata builder for each ubatch - # so that if they use internal persistant buffers for cudagraphs, and they - # won't have to worry about conflicting with the other ubatches. - metadata_builders: list[AttentionMetadataBuilder] layer_names: list[str] kv_cache_spec: KVCacheSpec kv_cache_group_id: int + # When ubatching is enabled we will have a metadata builder for each ubatch + # so that if they use internal persistant buffers for cudagraphs, and they + # won't have to worry about conflicting with the other ubatches. 
+ metadata_builders: list[AttentionMetadataBuilder] = field( + default_factory=lambda: [] + ) - @staticmethod - def create_with_metadata_builders( - backend: type[AttentionBackend], - layer_names: list[str], - kv_cache_spec: KVCacheSpec, - vllm_config: VllmConfig, - device: torch.device, - kv_cache_group_id: int, + def create_metadata_builders( + self, + vllm_config, + device, + kernel_block_size: int | None, num_metadata_builders: int = 1, - ) -> "AttentionGroup": - metadata_builders = [ - backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, device) + ): + kv_cache_spec_builder = ( + self.kv_cache_spec.copy_with_new_block_size(kernel_block_size) + if kernel_block_size is not None + else self.kv_cache_spec + ) + self.metadata_builders = [ + self.backend.get_builder_cls()( + kv_cache_spec_builder, + self.layer_names, + vllm_config, + device, + ) for _ in range(num_metadata_builders) ] - return AttentionGroup( - backend, metadata_builders, layer_names, kv_cache_spec, kv_cache_group_id - ) def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder: assert len(self.metadata_builders) > ubatch_id From cec7c288333339028f6fe8e0ac3222e3924da90b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Delacourt?= <54138269+Flechman@users.noreply.github.com> Date: Mon, 3 Nov 2025 08:22:46 +0100 Subject: [PATCH 018/231] [Bugfix] Padded Eagle Specdec with Chunked Prefill (#26263) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Rémi Delacourt Signed-off-by: Rémi Delacourt <54138269+Flechman@users.noreply.github.com> Signed-off-by: remi Co-authored-by: Benjamin Chislett --- tests/v1/e2e/test_spec_decode.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 45b48e5858934..ea7fcdf3174ec 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ 
-202,9 +202,9 @@ def test_speculators_model_integration( @pytest.mark.parametrize( - ["model_setup", "mm_enabled"], + ["model_setup", "mm_enabled", "chunked_prefill_enabled"], [ - (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False), + (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False, False), pytest.param( ( "eagle3", @@ -213,11 +213,12 @@ def test_speculators_model_integration( 1, ), False, + False, marks=pytest.mark.skip( reason="Skipping due to its head_dim not being a a multiple of 32" ), ), - ( + pytest.param( ( "eagle", "meta-llama/Llama-3.1-8B-Instruct", @@ -225,7 +226,9 @@ def test_speculators_model_integration( 1, ), False, - ), + True, + marks=large_gpu_mark(min_gb=40), + ), # works on 4x H100 ( ( "eagle3", @@ -234,6 +237,7 @@ def test_speculators_model_integration( 1, ), False, + False, ), pytest.param( ( @@ -243,6 +247,7 @@ def test_speculators_model_integration( 4, ), False, + False, marks=large_gpu_mark(min_gb=80), ), # works on 4x H100 pytest.param( @@ -253,6 +258,7 @@ def test_speculators_model_integration( 4, ), True, + True, marks=large_gpu_mark(min_gb=80), ), # works on 4x H100 ( @@ -263,6 +269,7 @@ def test_speculators_model_integration( 1, ), False, + False, ), ], ids=[ @@ -281,6 +288,7 @@ def test_eagle_correctness( sampling_config: SamplingParams, model_setup: tuple[str, str, str, int], mm_enabled: bool, + chunked_prefill_enabled: bool, attn_backend: str, ): if attn_backend == "TREE_ATTN": @@ -317,9 +325,13 @@ def test_eagle_correctness( m.setenv("VLLM_ROCM_USE_AITER", "1") method, model_name, spec_model_name, tp_size = model_setup + max_model_len = 2048 + max_num_batched_tokens = max_model_len + if chunked_prefill_enabled: + max_num_batched_tokens = 128 ref_llm = LLM( - model=model_name, max_model_len=2048, tensor_parallel_size=tp_size + model=model_name, max_model_len=max_model_len, tensor_parallel_size=tp_size ) ref_outputs = ref_llm.chat(test_prompts, sampling_config) del ref_llm @@ -334,9 +346,11 @@ def 
test_eagle_correctness( "method": method, "model": spec_model_name, "num_speculative_tokens": 3, - "max_model_len": 2048, + "max_model_len": max_model_len, }, - max_model_len=2048, + max_model_len=max_model_len, + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=chunked_prefill_enabled, ) spec_outputs = spec_llm.chat(test_prompts, sampling_config) matches = 0 From 7f4bdadb926936a11a88a619f56634061e824798 Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Mon, 3 Nov 2025 15:36:59 +0800 Subject: [PATCH 019/231] [XPU]Refine Dockerfile.xpu, avoid oneccl dependency issue (#27964) Signed-off-by: Kunshang Ji --- docker/Dockerfile.xpu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu index 49ea39cad5128..4e6ef8f5ca13c 100644 --- a/docker/Dockerfile.xpu +++ b/docker/Dockerfile.xpu @@ -54,7 +54,7 @@ ENV VLLM_WORKER_MULTIPROC_METHOD=spawn RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=.git,target=.git \ - python3 setup.py install + pip install --no-build-isolation . 
CMD ["/bin/bash"] @@ -64,9 +64,6 @@ FROM vllm-base AS vllm-openai RUN --mount=type=cache,target=/root/.cache/pip \ pip install accelerate hf_transfer pytest pytest_asyncio lm_eval[api] modelscope -RUN --mount=type=cache,target=/root/.cache/pip \ - pip uninstall oneccl oneccl-devel -y - # install development dependencies (for testing) RUN python3 -m pip install -e tests/vllm_test_utils @@ -74,4 +71,7 @@ RUN python3 -m pip install -e tests/vllm_test_utils RUN python3 /workspace/vllm/tools/install_nixl_from_source_ubuntu.py ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages/.nixl.mesonpy.libs/plugins/" +RUN --mount=type=cache,target=/root/.cache/pip \ + pip uninstall oneccl oneccl-devel -y + ENTRYPOINT ["vllm", "serve"] From ba464e6ae24857b2db7c82f4123342b9ab90049e Mon Sep 17 00:00:00 2001 From: Misha Efimov Date: Mon, 3 Nov 2025 03:21:31 -0500 Subject: [PATCH 020/231] Add ORCA endpoint load metrics support (#24905) Signed-off-by: Misha Efimov --- tests/entrypoints/openai/test_orca_metrics.py | 128 ++++++++++++++++++ vllm/entrypoints/openai/api_server.py | 19 ++- vllm/entrypoints/openai/orca_metrics.py | 120 ++++++++++++++++ 3 files changed, 265 insertions(+), 2 deletions(-) create mode 100644 tests/entrypoints/openai/test_orca_metrics.py create mode 100644 vllm/entrypoints/openai/orca_metrics.py diff --git a/tests/entrypoints/openai/test_orca_metrics.py b/tests/entrypoints/openai/test_orca_metrics.py new file mode 100644 index 0000000000000..d32cfde07c21e --- /dev/null +++ b/tests/entrypoints/openai/test_orca_metrics.py @@ -0,0 +1,128 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import openai +import pytest +import pytest_asyncio + +from ...utils import RemoteOpenAIServer + +# any model with a chat template should work here +MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" + + +@pytest.fixture(scope="module") +def monkeypatch_module(): + from _pytest.monkeypatch import 
MonkeyPatch + + mpatch = MonkeyPatch() + yield mpatch + mpatch.undo() + + +@pytest.fixture(scope="module", params=[True]) +def server(request, monkeypatch_module): + use_v1 = request.param + monkeypatch_module.setenv("VLLM_USE_V1", "1" if use_v1 else "0") + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--enforce-eager", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +async def test_chat_completion_with_orca_header(server: RemoteOpenAIServer): + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ] + + client = openai.OpenAI( + api_key="EMPTY", + base_url=f"http://localhost:{server.port}/v1", + default_headers={"endpoint-load-metrics-format": "TEXT"}, + ) + + # 1. Use raw client to get response headers. + raw_client = client.with_raw_response + + # 2. Make the API call using the raw_client + response_with_raw = raw_client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + extra_headers={"endpoint-load-metrics-format": "TEXT"}, + ) + + # 3. Access the raw httpx.Response object + raw_http_response = response_with_raw.http_response + + # 4. Get the headers from the httpx.Response object + response_headers = raw_http_response.headers + + assert "endpoint-load-metrics" in response_headers + + +@pytest.mark.asyncio +async def test_completion_with_orca_header(client: openai.AsyncOpenAI): + # 1. Use raw client to get response headers. + raw_client = client.with_raw_response + + # 2. Make the API call using the raw_client + completion = await raw_client.completions.create( + model=MODEL_NAME, + prompt="Hello, my name is", + max_tokens=5, + extra_headers={"endpoint-load-metrics-format": "JSON"}, + ) + + # 3. 
Access the raw httpx.Response object + raw_http_response = completion.http_response + + # 4. Get the headers from the httpx.Response object + response_headers = raw_http_response.headers + + assert "endpoint-load-metrics" in response_headers + + +@pytest.mark.asyncio +async def test_single_completion(client: openai.AsyncOpenAI): + completion = await client.completions.create( + model=MODEL_NAME, + prompt="Hello, my name is", + max_tokens=5, + extra_headers={"endpoint-load-metrics-format": "JSON"}, + temperature=0.0, + ) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + + choice = completion.choices[0] + assert len(choice.text) >= 5 + assert choice.finish_reason == "length" + assert completion.usage == openai.types.CompletionUsage( + completion_tokens=5, prompt_tokens=6, total_tokens=11 + ) + + # test using token IDs + completion = await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + assert len(completion.choices[0].text) >= 1 + assert completion.choices[0].prompt_logprobs is None diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 22b5584749ae7..c37aba2776aeb 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -51,6 +51,7 @@ from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args +from vllm.entrypoints.openai.orca_metrics import metrics_header from vllm.entrypoints.openai.protocol import ( ChatCompletionRequest, ChatCompletionResponse, @@ -128,6 +129,8 @@ prometheus_multiproc_dir: tempfile.TemporaryDirectory # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) logger = init_logger("vllm.entrypoints.openai.api_server") 
+ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL = "endpoint-load-metrics-format" + _running_tasks: set[asyncio.Task] = set() @@ -672,6 +675,9 @@ async def create_messages(request: AnthropicMessagesRequest, raw_request: Reques @with_cancellation @load_aware_call async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): + metrics_header_format = raw_request.headers.get( + ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL, "" + ) handler = chat(raw_request) if handler is None: return base(raw_request).create_error_response( @@ -689,7 +695,10 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re ) elif isinstance(generator, ChatCompletionResponse): - return JSONResponse(content=generator.model_dump()) + return JSONResponse( + content=generator.model_dump(), + headers=metrics_header(metrics_header_format), + ) return StreamingResponse(content=generator, media_type="text/event-stream") @@ -707,6 +716,9 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re @with_cancellation @load_aware_call async def create_completion(request: CompletionRequest, raw_request: Request): + metrics_header_format = raw_request.headers.get( + ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL, "" + ) handler = completion(raw_request) if handler is None: return base(raw_request).create_error_response( @@ -729,7 +741,10 @@ async def create_completion(request: CompletionRequest, raw_request: Request): content=generator.model_dump(), status_code=generator.error.code ) elif isinstance(generator, CompletionResponse): - return JSONResponse(content=generator.model_dump()) + return JSONResponse( + content=generator.model_dump(), + headers=metrics_header(metrics_header_format), + ) return StreamingResponse(content=generator, media_type="text/event-stream") diff --git a/vllm/entrypoints/openai/orca_metrics.py b/vllm/entrypoints/openai/orca_metrics.py new file mode 100644 index 0000000000000..3808262bf31f2 --- /dev/null +++ 
b/vllm/entrypoints/openai/orca_metrics.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Utility functions that create ORCA endpoint load report response headers. +""" + +import json +from collections.abc import Mapping + +from vllm.logger import init_logger +from vllm.v1.metrics.reader import Gauge, get_metrics_snapshot + +logger = init_logger(__name__) + + +def create_orca_header( + metrics_format: str, named_metrics: list[tuple[str, float]] +) -> Mapping[str, str] | None: + """ + Creates ORCA headers named 'endpoint-load-metrics' in the specified format + and adds custom metrics to named_metrics. + ORCA headers format description: https://docs.google.com/document/d/1C1ybMmDKJIVlrbOLbywhu9iRYo4rilR-cT50OTtOFTs/edit?tab=t.0 + ORCA proto https://github.com/cncf/xds/blob/main/xds/data/orca/v3/orca_load_report.proto + + Parameters: + - metrics_format (str): The format of the header ('TEXT', 'JSON'). + - named_metrics (List[Tuple[str, float]]): List of tuples with metric names + and their corresponding double values. + + Returns: + - Optional[Mapping[str,str]]: A dictionary with header key as + 'endpoint-load-metrics' and values as the ORCA header strings with + format prefix and data in with named_metrics in. 
+ """ + + if metrics_format.lower() not in ["text", "json"]: + logger.warning( + "Warning: `%s` format is not supported in the ORCA response header", + format, + ) + return None + + header = {} + orca_report = { + "named_metrics": { + metric_name: value + for metric_name, value in named_metrics + if isinstance(metric_name, str) and isinstance(value, float) + } + } + # output example: + # endpoint-load-metrics: TEXT named_metrics.kv_cache_utilization=0.4 + if metrics_format.lower() == "text": + native_http_header = ", ".join( + [ + f"named_metrics.{metric_name}={value}" + for metric_name, value in named_metrics + if isinstance(metric_name, str) and isinstance(value, float) + ] + ) + header["endpoint-load-metrics"] = f"TEXT {native_http_header}" + + # output example: + # endpoint-load-metrics: JSON “named_metrics”: {“custom-metric-util”: 0.4} + elif metrics_format.lower() == "json": + header["endpoint-load-metrics"] = f"JSON {json.dumps(orca_report)}" + + logger.info("Created ORCA header %s", header) + + return header + + +def get_named_metrics_from_prometheus() -> list[tuple[str, float]]: + """ + Collects current metrics from Prometheus and returns some of them + in the form of the `named_metrics` list for `create_orca_header()`. + + Parameters: + - None + + Returns: + - list[tuple[str, float]]: List of tuples of metric names and their values. + """ + named_metrics: list[tuple[str, float]] = [] + # Map from prometheus metric names to ORCA named metrics. + prometheus_to_orca_metrics = { + "vllm:kv_cache_usage_perc": "kv_cache_usage_perc", + "vllm:num_requests_waiting": "num_requests_waiting", + } + metrics = get_metrics_snapshot() + for metric in metrics: + orca_name = prometheus_to_orca_metrics.get(metric.name) + # If this metric is mapped into ORCA, then add it to the report. + # Note: Only Gauge metrics are currently supported. 
+ if orca_name is not None and isinstance(metric, Gauge): + named_metrics.append((str(orca_name), float(metric.value))) + return named_metrics + + +def metrics_header(metrics_format: str) -> Mapping[str, str] | None: + """ + Creates ORCA headers named 'endpoint-load-metrics' in the specified format. + Metrics are collected from Prometheus using `get_named_metrics_from_prometheus()`. + + ORCA headers format description: https://docs.google.com/document/d/1C1ybMmDKJIVlrbOLbywhu9iRYo4rilR-cT50OTtOFTs/edit?tab=t.0 + ORCA proto https://github.com/cncf/xds/blob/main/xds/data/orca/v3/orca_load_report.proto + + Parameters: + - metrics_format (str): The format of the header ('TEXT', 'JSON'). + + Returns: + - Optional[Mapping[str,str]]: A dictionary with header key as + 'endpoint-load-metrics' and values as the ORCA header strings with + format prefix and data in with named_metrics in. + """ + if not metrics_format: + return None + # Get named metrics from prometheus. + named_metrics = get_named_metrics_from_prometheus() + return create_orca_header(metrics_format, named_metrics) From 32257297dd4dcb996a0fb4641c2018289d20396b Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Mon, 3 Nov 2025 16:50:06 +0800 Subject: [PATCH 021/231] [CI/Build] Remove the flaky gpt-oss lora test (#27966) Signed-off-by: Jee Jee Li --- tests/lora/test_gptoss_tp.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/lora/test_gptoss_tp.py b/tests/lora/test_gptoss_tp.py index db4b7ca5ef499..711d514a39eb3 100644 --- a/tests/lora/test_gptoss_tp.py +++ b/tests/lora/test_gptoss_tp.py @@ -32,7 +32,6 @@ The Competition_ID of competition_record is the foreign key of Competition_ID of ###Response:<|end|><|start|>assistant<|channel|>final<|message|>""" # noqa: E501 EXPECTED_LORA_OUTPUT = [ - "SELECT AVG(Working_Horses) FROM farm WHERE Total_Horses > 5000;", "SELECT AVG(Working_Horses) FROM farm WHERE Total_Horses > 5000;", "SELECT MAX(Cows) AS Max_Cows, MIN(Cows) AS Min_Cows FROM farm;", "SELECT 
MAX(Cows) AS Max_Cows, MIN(Cows) AS Min_Cows FROM farm;", @@ -41,9 +40,6 @@ EXPECTED_LORA_OUTPUT = [ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None: prompts = [ - PROMPT_TEMPLATE.format( - context="What is the average number of working horses of farms with more than 5000 total number of horses?" # noqa: E501 - ), # noqa: E501 PROMPT_TEMPLATE.format( context="Give the average number of working horses on farms with more than 5000 total horses." # noqa: E501 ), # noqa: E501 @@ -67,7 +63,6 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None: generated_text = output.outputs[0].text.strip() generated_texts.append(generated_text) print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - for i in range(len(EXPECTED_LORA_OUTPUT)): assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i]) From 40b69e33e796efdc75e774a1c38cc73397ea6e17 Mon Sep 17 00:00:00 2001 From: zhang-prog <69562787+zhang-prog@users.noreply.github.com> Date: Mon, 3 Nov 2025 19:04:22 +0800 Subject: [PATCH 022/231] [Model] Add PaddleOCR-VL Model Support (#27758) Signed-off-by: zhangyue Signed-off-by: Roger Wang Signed-off-by: Isotr0py Signed-off-by: zhangyue66 Co-authored-by: Roger Wang Co-authored-by: Isotr0py --- docs/models/supported_models.md | 1 + examples/offline_inference/vision_language.py | 27 + .../vision_language_multi_image.py | 22 + tests/models/registry.py | 4 + vllm/model_executor/models/ernie45.py | 10 + vllm/model_executor/models/paddleocr_vl.py | 1407 +++++++++++++++++ vllm/model_executor/models/registry.py | 4 + 7 files changed, 1475 insertions(+) create mode 100644 vllm/model_executor/models/paddleocr_vl.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index fd25647dce54b..21235e305db4b 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -675,6 +675,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `NVLM_D_Model` | 
NVLM-D 1.0 | T + I+ | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | | `Ovis` | Ovis2, Ovis1.6 | T + I+ | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | | `Ovis2_5` | Ovis2.5 | T + I+ + V | `AIDC-AI/Ovis2.5-9B`, etc. | | | +| `PaddleOCRVLForConditionalGeneration` | Paddle-OCR | T + I+ | `PaddlePaddle/PaddleOCR-VL`, etc. | | | | `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + IE | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | | `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + IE+ | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | | `Phi4MMForCausalLM` | Phi-4-multimodal | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index c1ea95f8d0644..371cf6309a678 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -1242,6 +1242,32 @@ def run_ovis2_5(questions: list[str], modality: str) -> ModelRequestData: ) +# PaddleOCR-VL +def run_paddleocr_vl(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + + model_name = "PaddlePaddle/PaddleOCR-VL" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=2, + limit_mm_per_prompt={modality: 1}, + trust_remote_code=True, + ) + + placeholder = "<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>" + prompts = [ + (f"<|begin_of_sentence|>User: {question}{placeholder}\nAssistant: ") + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # PaliGemma def run_paligemma(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1817,6 +1843,7 @@ model_example_map = { "NVLM_D": run_nvlm_d, "ovis": run_ovis, "ovis2_5": run_ovis2_5, + "paddleocr_vl": 
run_paddleocr_vl, "paligemma": run_paligemma, "paligemma2": run_paligemma2, "phi3_v": run_phi3v, diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 5cb47c15038e8..80c7fc4431229 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -801,6 +801,27 @@ def load_ovis2_5(question: str, image_urls: list[str]) -> ModelRequestData: ) +def load_paddleocr_vl(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "PaddlePaddle/PaddleOCR-VL" + + engine_args = EngineArgs( + model=model_name, + trust_remote_code=True, + max_model_len=8192, + max_num_seqs=2, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = "<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>" * len(image_urls) + prompt = f"<|begin_of_sentence|>User: {question}{placeholders}\nAssistant: " + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=[fetch_image(url) for url in image_urls], + ) + + def load_pixtral_hf(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "mistral-community/pixtral-12b" @@ -1312,6 +1333,7 @@ model_example_map = { "NVLM_D": load_nvlm_d, "ovis": load_ovis, "ovis2_5": load_ovis2_5, + "paddleocr_vl": load_paddleocr_vl, "phi3_v": load_phi3v, "phi4_mm": load_phi4mm, "phi4_multimodal": load_phi4_multimodal, diff --git a/tests/models/registry.py b/tests/models/registry.py index 8e1dd4ba91f1d..00fe999805003 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -712,6 +712,10 @@ _MULTIMODAL_EXAMPLE_MODELS = { }, ), "Ovis2_5": _HfExamplesInfo("AIDC-AI/Ovis2.5-2B", trust_remote_code=True), + "PaddleOCRVLForConditionalGeneration": _HfExamplesInfo( + "PaddlePaddle/PaddleOCR-VL", + trust_remote_code=True, + ), "PaliGemmaForConditionalGeneration": _HfExamplesInfo( "google/paligemma-3b-mix-224", extras={"v2": 
"google/paligemma2-3b-ft-docci-448"}, diff --git a/vllm/model_executor/models/ernie45.py b/vllm/model_executor/models/ernie45.py index b1d26cddcc5eb..c1a4737e1f326 100644 --- a/vllm/model_executor/models/ernie45.py +++ b/vllm/model_executor/models/ernie45.py @@ -23,12 +23,22 @@ # limitations under the License. """Inference-only Erine model compatible with HuggingFace weights.""" +from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.model_executor.models.llama import LlamaForCausalLM from .utils import PPMissingLayer +@support_torch_compile( + # set dynamic_arg_dims to support mrope + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + } +) class Ernie4_5ForCausalLM(LlamaForCausalLM): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py new file mode 100644 index 0000000000000..377b41a355782 --- /dev/null +++ b/vllm/model_executor/models/paddleocr_vl.py @@ -0,0 +1,1407 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from collections.abc import Iterable, Mapping, Sequence +from functools import partial +from typing import Annotated, Literal + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from transformers import BatchFeature, PretrainedConfig +from transformers.activations import GELUActivation +from transformers.modeling_outputs import ( + BaseModelOutputWithPooling, +) +from transformers.utils import torch_int + +from vllm.attention.backends.registry import _Backend +from vllm.attention.layer import ( + check_upstream_fa_availability, + maybe_get_vit_flash_attn_backend, +) +from vllm.attention.ops.vit_attn_wrappers import ( + vit_flash_attn_wrapper, + vit_xformers_attn_wrapper, +) +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.distributed import parallel_state +from vllm.distributed import utils as dist_utils +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding.common import ( + dispatch_rotary_emb_function, +) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargs, +) +from vllm.multimodal.parse import ( + ImageProcessorItems, + ImageSize, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, +) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape + 
+from .ernie45 import Ernie4_5ForCausalLM +from .interfaces import MultiModalEmbeddings, SupportsMRoPE, SupportsMultiModal +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + WeightsMapper, + is_pp_missing_parameter, + maybe_prefix, +) +from .vision import get_vit_attn_backend + + +def smart_resize( + height: int, + width: int, + factor: int = 28, + min_pixels: int = 28 * 28 * 130, + max_pixels: int = 28 * 28 * 1280, +): + """Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + + 3. The aspect ratio of the image is maintained as closely as possible. + + """ + + if height < factor: + width = round((width * factor) / height) + height = factor + + if width < factor: + height = round((height * factor) / width) + width = factor + + if max(height, width) / min(height, width) > 200: + raise ValueError( + f"absolute aspect ratio must be smaller than 200, " + f"got {max(height, width) / min(height, width)}" + ) + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = math.floor(height / beta / factor) * factor + w_bar = math.floor(width / beta / factor) * factor + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + return h_bar, w_bar + + +def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... 
(d two)", two=2) + + +def apply_rotary_emb_torch( + x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False +) -> torch.Tensor: + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat( + cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" + ) + sin = repeat( + sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" + ) + return torch.cat( + [ + x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, + x[..., ro_dim:], + ], + dim=-1, + ) + + +def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + rotary_emb_function = dispatch_rotary_emb_function(default=apply_rotary_emb_torch) + t_ = t.float() + cos = freqs.cos() + sin = freqs.sin() + output = rotary_emb_function(t_, cos, sin).type_as(t) + return output + + +class PaddleOCRVLProcessingInfo(BaseProcessingInfo): + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(**kwargs) + + def get_image_processor(self, **kwargs: object): + return self.get_hf_processor(**kwargs).image_processor + + def get_supported_mm_limits(self): + return {"image": None} + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + image_processor, + ) -> int: + if image_processor is None: + image_processor = self.get_image_processor() + + do_resize = True + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + patch_size = vision_config.patch_size + merge_size = vision_config.spatial_merge_size + + if do_resize: + resized_height, resized_width = smart_resize( + height=image_height, + width=image_width, + factor=patch_size * merge_size, + min_pixels=image_processor.min_pixels, + max_pixels=image_processor.max_pixels, + ) + preprocessed_size = 
ImageSize(width=resized_width, height=resized_height) + else: + preprocessed_size = ImageSize(width=image_width, height=image_height) + + grid_t = 1 + grid_h = preprocessed_size.height // patch_size + grid_w = preprocessed_size.width // patch_size + + num_patches = grid_t * grid_h * grid_w + num_image_tokens = num_patches // (merge_size**2) + + return num_image_tokens + + def get_image_size_with_most_features(self) -> ImageSize: + hf_config = self.get_hf_config() + image_size = hf_config.vision_config.image_size + return ImageSize(height=image_size, width=image_size) + + +class PaddleOCRVLDummyInputsBuilder(BaseDummyInputsBuilder[PaddleOCRVLProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token = processor.image_token + + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + max_image_size = self.info.get_image_size_with_most_features() + image_overrides = mm_options.get("image") if mm_options else None + + return { + "image": self._get_dummy_images( + width=max_image_size.width, + height=max_image_size.height, + num_images=num_images, + overrides=image_overrides, + ) + } + + +class PaddleOCRVLMultiModalProcessor( + BaseMultiModalProcessor[PaddleOCRVLProcessingInfo] +): + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + if mm_data: + processed_outputs = self.info.ctx.call_hf_processor( + self.info.get_hf_processor(**mm_kwargs), + dict(text=prompt, **mm_data), + dict(**mm_kwargs, **tok_kwargs), + ) + num_patches_per_image = processed_outputs["image_grid_thw"].prod(-1) + processed_outputs["pixel_values"] = 
processed_outputs["pixel_values"].split( + num_patches_per_image.tolist() + ) + else: + tokenizer = self.info.get_tokenizer() + processed_outputs = tokenizer( + prompt, add_special_tokens=True, return_tensors="pt" + ) + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_grid_thw=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) + hf_config = self.info.get_hf_config() + image_token_id = hf_config.image_token_id + + def get_replacement(item_idx: int, image_processor): + images = mm_items.get_items("image", ImageProcessorItems) + + image_size = images.get_image_size(item_idx) + num_image_tokens = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + image_processor=image_processor, + ) + + return [image_token_id] * num_image_tokens + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=partial(get_replacement, image_processor=image_processor), + ), + ] + + +class Projector(nn.Module): + def __init__( + self, + text_config: PretrainedConfig, + vision_config: PretrainedConfig, + prefix: str = "", + ): + super().__init__() + self.text_config = text_config + self.vision_config = vision_config + self.merge_kernel_size = (2, 2) + + self.hidden_size = ( + self.vision_config.hidden_size + * self.merge_kernel_size[0] + * self.merge_kernel_size[1] + ) + + self.pre_norm = torch.nn.LayerNorm(self.vision_config.hidden_size, eps=1e-05) + self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) + self.act = GELUActivation() + 
self.linear_2 = nn.Linear( + self.hidden_size, self.text_config.hidden_size, bias=True + ) + + def forward( + self, + image_features: torch.Tensor, + image_grid_thw: torch.Tensor, + ) -> torch.Tensor: + m1, m2 = self.merge_kernel_size + if isinstance(image_features, (list, tuple)): + processed_features = list() + for image_feature, image_grid in zip(image_features, image_grid_thw): + image_feature = self.pre_norm(image_feature) + t, h, w = image_grid + + image_feature = rearrange( + image_feature, + "(t h p1 w p2) d -> (t h w) (p1 p2 d)", + t=t, + h=h // m1, + p1=m1, + w=w // m2, + p2=m2, + ) + hidden_states = self.linear_1(image_feature) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + processed_features.append(hidden_states) + + return processed_features + + dims = image_features.shape[:-1] + dim = image_features.shape[-1] + image_features = image_features.view(np.prod(dims), dim) + hidden_states = self.pre_norm(image_features).view(-1, self.hidden_size) + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + + return hidden_states.view(*dims, -1) + + +class PaddleOCRImagePixelInputs(TensorSchema): + type: Literal["pixel_values"] + pixel_values: Annotated[ + torch.Tensor, + TensorShape("bn", "p", 3, "patch_size", "patch_size", dynamic_dims={"p"}), + ] + image_grid_thw: Annotated[ + torch.Tensor, + TensorShape("bn", 3), + ] + + +class SiglipVisionEmbeddings(nn.Module): + def __init__(self, config: PretrainedConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = 
self.num_patches + self.cache_position_embedding = dict() + self.cache_position_count = dict() + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.packing_position_embedding = nn.Embedding(32768, self.embed_dim) + + self.register_buffer( + "position_ids", + torch.arange(self.num_positions).expand((1, -1)), + persistent=False, + ) + + def interpolate_pos_encoding( + self, + embeddings: torch.Tensor, + height: int, + width: int, + is_after_patchify: bool = False, + ) -> torch.Tensor: + num_positions = self.position_embedding.weight.shape[0] + + patch_pos_embed = self.position_embedding.weight.unsqueeze(0) + + dim = embeddings.shape[-1] + + if is_after_patchify: + new_height = height + new_width = width + else: + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape( + 1, sqrt_num_positions, sqrt_num_positions, dim + ) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bilinear", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed + + def fetch_position_embedding_lfu_cache( + self, embeddings: torch.Tensor, h: int, w: int, max_cache: int = 20 + ): + grid = (h, w) + if grid in self.cache_position_embedding: + self.cache_position_count[grid] += 1 + return self.cache_position_embedding[grid] + + if len(self.cache_position_embedding) >= max_cache: + min_hit_grid = min( + self.cache_position_count, + key=self.cache_position_count.get, + ) + self.cache_position_count.pop(min_hit_grid) + self.cache_position_embedding.pop(min_hit_grid) + + position_embedding = self.interpolate_pos_encoding(embeddings, h, w, True) + self.cache_position_count[grid] = 1 + self.cache_position_embedding[grid] = position_embedding + return position_embedding + + 
def forward( + self, + pixel_values: torch.FloatTensor, + position_ids: torch.Tensor | None = None, + image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] + | None = None, + interpolate_pos_encoding=False, + ) -> torch.Tensor: + if pixel_values.dim() == 4: + pixel_values = pixel_values.unsqueeze(0) + if pixel_values.dim() == 5: + if position_ids is None: + raise ValueError( + "position_ids cannot be None when pixel_values.dim() is 5." + ) + ( + batch_size, + squence_len, + channel, + height, + width, + ) = pixel_values.shape + target_dtype = self.patch_embedding.weight.dtype + pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w") + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) + embeddings = patch_embeds.flatten(-2).squeeze(-1) + + if interpolate_pos_encoding and image_grid_thw is not None: + start = 0 + tmp_embeddings = list() + for image_grid in image_grid_thw: + t, h, w = image_grid + end = start + t * h * w + image_embeddings = embeddings[start:end, :] + position_embedding = ( + self.interpolate_pos_encoding(image_embeddings, h, w, True) + .squeeze(0) + .repeat(t, 1) + ) + image_embeddings = image_embeddings + position_embedding + tmp_embeddings.append(image_embeddings) + start = end + embeddings = torch.concat(tmp_embeddings, dim=0).unsqueeze(0) + else: + embeddings = embeddings + self.packing_position_embedding(position_ids) + return embeddings + else: + raise ValueError( + "Unsupported pixel_values dimension:" + f" {pixel_values.dim()}. Expected 4 or 5." 
+ ) + + +def all_gather_interleave(local_tensor: torch.Tensor, hidden_size: int, tp_size: int): + """All-gather the input tensor interleavely across model parallel group.""" + import torch.distributed as dist + + gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)] + dist.all_gather( + gathered_tensors, local_tensor, group=parallel_state.get_tp_group().device_group + ) + + gathered_tensors_split = [ + torch.split(tensor, hidden_size // tp_size, -1) for tensor in gathered_tensors + ] + ordered_tensors = [ + tensor for pair in zip(*gathered_tensors_split) for tensor in pair + ] + result_tensor = torch.cat(ordered_tensors, dim=-1) + return result_tensor + + +class SiglipAttention(nn.Module): + """SigLIP vision attention adapted from Qwen2.5-VisionAttention.""" + + def __init__( + self, + *, + embed_dim: int, + num_heads: int, + projection_size: int, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + attn_backend: _Backend = _Backend.TORCH_SDPA, + attn_backend_override: _Backend | None = None, + use_upstream_fa: bool = False, + ) -> None: + super().__init__() + + self.tp_size = parallel_state.get_tensor_model_parallel_world_size() + self.tp_rank = parallel_state.get_tensor_model_parallel_rank() + self.hidden_size_per_attention_head = dist_utils.divide( + projection_size, num_heads + ) + self.num_attention_heads_per_partition = dist_utils.divide( + num_heads, self.tp_size + ) + + self.qkv_proj = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.hidden_size_per_attention_head, + total_num_heads=num_heads, + total_num_kv_heads=num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.out_proj = RowParallelLinear( + input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + + self.attn_backend = attn_backend + self.use_upstream_fa = use_upstream_fa + self.attn_backend, self.flash_attn_varlen_func = ( + 
maybe_get_vit_flash_attn_backend( + self.attn_backend, + self.use_upstream_fa, + attn_backend_override=attn_backend_override, + ) + ) + self.is_flash_attn_backend = self.attn_backend in { + _Backend.FLASH_ATTN, + _Backend.ROCM_AITER_FA, + } + + def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: + seq_len, bs, _ = qkv.shape + if self.tp_size > 1: + qkv = all_gather_interleave(qkv, self.qkv_proj.hidden_size, self.tp_size) + + q, k, v = qkv.chunk(3, dim=2) + + if self.tp_size > 1: + splitter = partial( + dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size + ) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + v = splitter(v)[self.tp_rank] + + new_shape = ( + seq_len, + bs, + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + q, k, v = (x.view(*new_shape) for x in (q, k, v)) + return q, k, v + + def forward( + self, + hidden_states: torch.Tensor, + *, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor | None, + max_seqlen: torch.Tensor | None, + seqlens: torch.Tensor | None, + ) -> torch.Tensor: + batch_size, _, _ = hidden_states.shape + + x = rearrange(hidden_states, "b s d -> s b d") + x, _ = self.qkv_proj(x) + q, k, v = self.split_qkv(x) + q, k, v = (rearrange(t, "s b h d -> b s h d") for t in (q, k, v)) + + if rotary_pos_emb is not None: + qk_concat = torch.cat([q, k], dim=0) + qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) + q, k = torch.chunk(qk_rotated, 2, dim=0) + + if self.is_flash_attn_backend: + if max_seqlen is None: + raise ValueError("Flash attention backend requires max_seqlen.") + context_layer = vit_flash_attn_wrapper( + q, + k, + v, + cu_seqlens, + max_seqlen, + batch_size, + self.attn_backend == _Backend.ROCM_AITER_FA, + self.use_upstream_fa, + ) + elif self.attn_backend == _Backend.TORCH_SDPA: + outputs = [] + for i in range(1, len(cu_seqlens)): + start_idx = cu_seqlens[i - 1] + end_idx = cu_seqlens[i] + q_i = q[:, start_idx:end_idx] + k_i 
= k[:, start_idx:end_idx] + v_i = v[:, start_idx:end_idx] + q_i, k_i, v_i = ( + rearrange(tensor, "b s h d -> b h s d") + for tensor in (q_i, k_i, v_i) + ) + output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) + output_i = rearrange(output_i, "b h s d -> b s h d") + outputs.append(output_i) + context_layer = torch.cat(outputs, dim=1) + context_layer = rearrange( + context_layer, "b s h d -> s b (h d)" + ).contiguous() + elif self.attn_backend == _Backend.XFORMERS: + if seqlens is None: + raise ValueError("xFormers attention backend requires seqlens tensor.") + context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens) + else: + raise RuntimeError( + f"PaddleOCR-VL does not support {self.attn_backend} backend now." + ) + + output, _ = self.out_proj(context_layer) + output = rearrange(output, "s b d -> b s d") + return output + + +class SigLIPRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + self.dim = dim + self.theta = theta + self.rope_init() + + def rope_init(self): + inv_freq = 1.0 / ( + self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange( + seqlen, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype, + ) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class SiglipMLP(nn.Module): + def __init__( + self, + config: PretrainedConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + # Special handling for BNB and torchao quantization + if quant_config and quant_config.get_name() in ["bitsandbytes", "torchao"]: + quantizable = True + else: + # For other quantization, we require the hidden size to be a + # multiple of 64 + quantizable = ( + config.hidden_size % 64 == 0 
and config.intermediate_size % 64 == 0 + ) + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + quant_config=quant_config if quantizable else None, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + quant_config=quant_config if quantizable else None, + prefix=f"{prefix}.fc2", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class SiglipEncoderLayer(nn.Module): + def __init__( + self, + config: PretrainedConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + *, + attn_backend: _Backend = _Backend.TORCH_SDPA, + attn_backend_override: _Backend | None = None, + use_upstream_fa: bool = False, + ): + super().__init__() + self.embed_dim = config.hidden_size + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.self_attn = SiglipAttention( + embed_dim=config.hidden_size, + num_heads=config.num_attention_heads, + projection_size=config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + attn_backend=attn_backend, + attn_backend_override=attn_backend_override, + use_upstream_fa=use_upstream_fa, + ) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = SiglipMLP( + config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + + def forward( + self, + hidden_states: torch.Tensor, + *, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor | None, + max_seqlen: torch.Tensor | None, + seqlens: torch.Tensor | None, + ) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + 
seqlens=seqlens, + ) + + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + + return hidden_states + + +class SiglipEncoder(nn.Module): + def __init__( + self, + config: PretrainedConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + attn_backend_override: _Backend | None = None, + ): + super().__init__() + self.config = config + embed_dim = config.hidden_size + num_heads = config.num_attention_heads + head_dim = embed_dim // num_heads + self.attn_backend = get_vit_attn_backend( + head_size=head_dim, + dtype=torch.get_default_dtype(), + attn_backend_override=attn_backend_override, + ) + self.use_upstream_fa = False + if self.attn_backend not in { + _Backend.FLASH_ATTN, + _Backend.ROCM_AITER_FA, + } and check_upstream_fa_availability(torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN + self.use_upstream_fa = True + if self.attn_backend not in { + _Backend.FLASH_ATTN, + _Backend.TORCH_SDPA, + _Backend.XFORMERS, + _Backend.ROCM_AITER_FA, + }: + raise RuntimeError( + f"PaddleOCR-VL does not support {self.attn_backend} backend now." 
+ ) + self.layers = nn.ModuleList( + [ + SiglipEncoderLayer( + config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + attn_backend=self.attn_backend, + attn_backend_override=attn_backend_override, + use_upstream_fa=self.use_upstream_fa, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.rotary_pos_emb = SigLIPRotaryEmbedding(head_dim // 2) + + @staticmethod + def flatten_list(image_grid_thw): + tmp_image_grid_thw = list() + for image_grid in image_grid_thw: + if isinstance(image_grid, list): + tmp_image_grid_thw.extend(image_grid) + else: + tmp_image_grid_thw.append(image_grid) + return tmp_image_grid_thw + + def forward( + self, + inputs_embeds, + cu_seqlens: torch.Tensor | None = None, + image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] + | None = None, + height_position_ids: torch.Tensor | None = None, + width_position_ids: torch.Tensor | None = None, + ) -> torch.Tensor: + device = inputs_embeds.device + hidden_states = inputs_embeds + + flatten_image_grid_thw = self.flatten_list(image_grid_thw) + + if width_position_ids is None or height_position_ids is None: + split_hids = list() + split_wids = list() + for t, h, w in flatten_image_grid_thw: + image_pids = torch.arange(t * h * w, device=device) % (h * w) + sample_hids = image_pids // w + sample_wids = image_pids % w + split_hids.append(sample_hids) + split_wids.append(sample_wids) + width_position_ids = torch.concat(split_wids, dim=0) + height_position_ids = torch.concat(split_hids, dim=0) + + pids = torch.stack( + [height_position_ids, width_position_ids], + dim=-1, + ) + max_grid_size = pids.max() + 1 + rope_emb_max_grid = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rope_emb_max_grid[pids].flatten(1) + + if cu_seqlens is None: + raise ValueError("cu_seqlens cannot be None for SiglipEncoder.") + if not isinstance(cu_seqlens, torch.Tensor): + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + else: + cu_seqlens 
= cu_seqlens.to(device=device) + + max_seqlen = None + seqlens = None + if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + elif self.attn_backend == _Backend.XFORMERS: + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer( + hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + return hidden_states + + +class SiglipVisionTransformer(nn.Module): + def __init__( + self, + config: PretrainedConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + attn_backend_override: _Backend | None = None, + ): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipVisionEmbeddings(config) + self.encoder = SiglipEncoder( + config, + quant_config=quant_config, + prefix=f"{prefix}.encoder", + attn_backend_override=attn_backend_override, + ) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + pixel_values: torch.Tensor, + interpolate_pos_encoding: bool | None = False, + position_ids: torch.Tensor | None = None, + height_position_ids: torch.Tensor | None = None, + width_position_ids: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, + image_grid_thw: torch.Tensor | None = None, + ) -> torch.Tensor: + hidden_states = self.embeddings( + pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + position_ids=position_ids, + image_grid_thw=image_grid_thw, + ) + + last_hidden_state = self.encoder( + inputs_embeds=hidden_states, + cu_seqlens=cu_seqlens, + image_grid_thw=image_grid_thw, + height_position_ids=height_position_ids, + width_position_ids=width_position_ids, + ) + + last_hidden_state = self.post_layernorm(last_hidden_state) + return last_hidden_state + + +class SiglipVisionModel(nn.Module): + def 
__init__( + self, + config, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + attn_backend_override: _Backend | None = None, + ): + super().__init__() + + self.vision_model = SiglipVisionTransformer( + config, + quant_config=quant_config, + prefix=f"{prefix}.vision_model", + attn_backend_override=attn_backend_override, + ) + self.quant_config = quant_config + + @property + def dtype(self) -> torch.dtype: + return self.vision_model.embeddings.patch_embedding.weight.dtype + + @property + def device(self) -> torch.device: + return self.vision_model.embeddings.patch_embedding.weight.device + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + def forward( + self, + pixel_values, + interpolate_pos_encoding: bool = False, + position_ids: torch.Tensor | None = None, + image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] + | None = None, + cu_seqlens: torch.Tensor | None = None, + ) -> BaseModelOutputWithPooling: + return self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + position_ids=position_ids, + image_grid_thw=image_grid_thw, + cu_seqlens=cu_seqlens, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if "head.attention" in name or "head.layernorm" in name: + continue + if "head.mlp" in name or "head.probe" in name: + continue + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): + param = params_dict[scale_name] + weight_loader = getattr( + param, + "weight_loader", + default_weight_loader, + ) + loaded_weight = ( + loaded_weight 
if loaded_weight.dim() == 0 else loaded_weight[0] + ) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for ( + param_name, + weight_name, + shard_id, + ) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr( + param, + "weight_loader", + default_weight_loader, + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +@MULTIMODAL_REGISTRY.register_processor( + PaddleOCRVLMultiModalProcessor, + info=PaddleOCRVLProcessingInfo, + dummy_inputs=PaddleOCRVLDummyInputsBuilder, +) +class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsMRoPE): + merge_by_field_config = True + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.": "language_model.model.", + "lm_head.": "language_model.lm_head.", + } + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend + if multimodal_config is not None + else None + ) + + self.visual = SiglipVisionModel( + config=config.vision_config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + 
attn_backend_override=attn_backend_override, + ) + self.mlp_AR = Projector(config, config.vision_config) + + self.language_model = Ernie4_5ForCausalLM( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + + for layer in self.language_model.model.layers: + if not isinstance(layer, PPMissingLayer): + layer.self_attn.rotary_emb.is_neox_style = True + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + return self.language_model.compute_logits(hidden_states) + + def get_mrope_input_positions( + self, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: list[list[int]] | torch.Tensor, + video_grid_thw: list[list[int]] | torch.Tensor, + second_per_grid_ts: list[float], + context_len: int = 0, + seq_len: int | None = None, + audio_feature_lengths: torch.Tensor | None = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value.""" + + image_token_id = hf_config.image_token_id + video_token_id = hf_config.video_token_id + vision_start_token_id = hf_config.vision_start_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0) + + input_tokens_tensor = torch.tensor(input_tokens) + vision_start_indices = torch.argwhere( + input_tokens_tensor == vision_start_token_id + ).squeeze(1) + vision_tokens = input_tokens_tensor[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_videos = image_nums, video_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + video_second_per_grid_t = 0.0 + if remain_images > 0: + try: + ed_image = 
input_tokens.index(image_token_id, st) + except ValueError: + ed_image = len(input_tokens) + 1 + else: + ed_image = len(input_tokens) + 1 + if remain_videos > 0: + try: + ed_video = input_tokens.index(video_token_id, st) + except ValueError: + ed_video = len(input_tokens) + 1 + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_second_per_grid_t = 1.0 + if second_per_grid_ts: + video_second_per_grid_t = second_per_grid_ts[video_index] + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + t_index = ( + ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + * video_second_per_grid_t + * tokens_per_second + ) + .long() + .flatten() + ) + + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + 
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta + + def get_language_model(self) -> nn.Module: + return self.language_model + + def _parse_and_validate_image_input( + self, **kwargs: object + ) -> PaddleOCRImagePixelInputs | None: + pixel_values = kwargs.pop("pixel_values", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None: + return None + + return PaddleOCRImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs, + ): + if intermediate_tensors is not None: + inputs_embeds = None + + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + is_multimodal = kwargs.pop("is_multimodal", None) + handle_oov_mm_token = kwargs.pop("handle_oov_mm_token", False) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) + input_ids = None + + return self.language_model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + if modality.startswith("image"): + return "<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>" + + raise ValueError("Only image modality is supported") + + def encode_image( + self, pixel_values: torch.Tensor, image_grid_thw: torch.Tensor + ) -> torch.Tensor: + pixel_values = pixel_values.type(self.visual.dtype) + siglip_position_ids = list() + image_grid_hws = list() + cu_seqlens = [0] + + thw_tuple = tuple(image_grid_thw.tolist()) + numel = np.prod(thw_tuple) + image_grid_hws.append(thw_tuple) + image_position_ids = 
torch.arange(numel) % np.prod(thw_tuple[1:]) + siglip_position_ids.append(image_position_ids) + cu_seqlens.append(cu_seqlens[-1] + numel) + + siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to( + pixel_values.device + ) + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(pixel_values.device) + + vision_outputs = self.visual( + pixel_values=pixel_values, + image_grid_thw=image_grid_hws, + position_ids=siglip_position_ids, + interpolate_pos_encoding=True, + cu_seqlens=cu_seqlens, + ) + return vision_outputs + + def _process_image_input( + self, image_input: PaddleOCRImagePixelInputs + ) -> MultiModalEmbeddings: + pixel_values = image_input.pixel_values + image_grid_thw = image_input.image_grid_thw + vision_outputs = tuple( + self.encode_image(pixel, grid).squeeze(0) + for pixel, grid in zip(pixel_values, image_grid_thw) + ) + image_embeds = self.mlp_AR(vision_outputs, image_grid_thw) + return image_embeds + + def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return () + + multimodal_embeddings: tuple[torch.Tensor, ...] 
= () + image_embeds = self._process_image_input(image_input) + multimodal_embeddings += tuple(image_embeds) + + return multimodal_embeddings + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + autoloaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + return autoloaded_weights diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 7eca1a09e5365..d9299697fcb03 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -340,6 +340,10 @@ _MULTIMODAL_MODELS = { "NVLM_D": ("nvlm_d", "NVLM_D_Model"), "Ovis": ("ovis", "Ovis"), "Ovis2_5": ("ovis2_5", "Ovis2_5"), + "PaddleOCRVLForConditionalGeneration": ( + "paddleocr_vl", + "PaddleOCRVLForConditionalGeneration", + ), "PaliGemmaForConditionalGeneration": ( "paligemma", "PaliGemmaForConditionalGeneration", From 294c805f1df9ddf62c2290989710da9d48ab4973 Mon Sep 17 00:00:00 2001 From: gnovack Date: Mon, 3 Nov 2025 04:22:17 -0800 Subject: [PATCH 023/231] Early exit for MoE LoRA kernels (#27131) Signed-off-by: gnovack Co-authored-by: Jee Jee Li --- csrc/moe/moe_lora_align_sum_kernels.cu | 27 ++++++---- csrc/moe/moe_ops.h | 15 +++--- csrc/moe/torch_bindings.cpp | 4 +- tests/lora/test_fused_moe_lora_kernel.py | 6 +++ tests/lora/test_moe_lora_align_sum.py | 4 ++ tests/lora/test_olmoe_tp.py | 50 ++++++++++++++++--- vllm/_custom_ops.py | 4 ++ vllm/lora/layers/fused_moe.py | 11 +++- vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 25 ++++++++-- vllm/lora/punica_wrapper/punica_base.py | 2 + vllm/lora/punica_wrapper/punica_gpu.py | 9 +++- 11 files changed, 123 insertions(+), 34 deletions(-) diff --git a/csrc/moe/moe_lora_align_sum_kernels.cu b/csrc/moe/moe_lora_align_sum_kernels.cu index e76d1c3667853..360f1312cf579 100644 --- a/csrc/moe/moe_lora_align_sum_kernels.cu +++ b/csrc/moe/moe_lora_align_sum_kernels.cu @@ -28,11 +28,16 @@ __global__ void 
moe_lora_align_sum_kernel( int64_t block_size, int num_experts, int max_loras, size_t numel, int max_num_tokens_padded, int max_num_m_blocks, int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, - int topk_num, int32_t* total_tokens_post_pad) { + int topk_num, int32_t* total_tokens_post_pad, int32_t* adapter_enabled, + int32_t* lora_ids) { const size_t tokens_per_thread = div_ceil(numel, blockDim.x); const size_t start_idx = threadIdx.x * tokens_per_thread; - int lora_id = blockIdx.x; + int lora_idx = blockIdx.x; + int lora_id = lora_ids[lora_idx]; + if (lora_id == -1 || adapter_enabled[lora_id] == 0) { + return; + } extern __shared__ int32_t shared_mem[]; int32_t* cumsum = shared_mem; token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1); @@ -121,14 +126,13 @@ __global__ void moe_lora_align_sum_kernel( } } -void moe_lora_align_block_size(torch::Tensor topk_ids, - torch::Tensor token_lora_mapping, - int64_t num_experts, int64_t block_size, - int64_t max_loras, int64_t max_num_tokens_padded, - int64_t max_num_m_blocks, - torch::Tensor sorted_token_ids, - torch::Tensor expert_ids, - torch::Tensor num_tokens_post_pad) { +void moe_lora_align_block_size( + torch::Tensor topk_ids, torch::Tensor token_lora_mapping, + int64_t num_experts, int64_t block_size, int64_t max_loras, + int64_t max_num_tokens_padded, int64_t max_num_m_blocks, + torch::Tensor sorted_token_ids, torch::Tensor expert_ids, + torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled, + torch::Tensor lora_ids) { const int topk_num = topk_ids.size(1); TORCH_CHECK(block_size > 0, "block_size should be greater than 0. 
"); @@ -164,6 +168,7 @@ void moe_lora_align_block_size(torch::Tensor topk_ids, max_loras, topk_ids.numel(), max_num_tokens_padded, max_num_m_blocks, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), topk_num, - num_tokens_post_pad.data_ptr()); + num_tokens_post_pad.data_ptr(), + adapter_enabled.data_ptr(), lora_ids.data_ptr()); }); } \ No newline at end of file diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index e4bf0aa99421b..0adf745689b2f 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -20,14 +20,13 @@ void batched_moe_align_block_size(int64_t max_tokens_per_batch, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad); -void moe_lora_align_block_size(torch::Tensor topk_ids, - torch::Tensor token_lora_mapping, - int64_t num_experts, int64_t block_size, - int64_t max_loras, int64_t max_num_tokens_padded, - int64_t max_num_m_blocks, - torch::Tensor sorted_token_ids, - torch::Tensor expert_ids, - torch::Tensor num_tokens_post_pad); +void moe_lora_align_block_size( + torch::Tensor topk_ids, torch::Tensor token_lora_mapping, + int64_t num_experts, int64_t block_size, int64_t max_loras, + int64_t max_num_tokens_padded, int64_t max_num_m_blocks, + torch::Tensor sorted_token_ids, torch::Tensor expert_ids, + torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled, + torch::Tensor lora_ids); #ifndef USE_ROCM torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index c08a543908ef0..ace72fad71e86 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -44,7 +44,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { " int max_num_m_blocks, " " Tensor !sorted_token_ids," " Tensor !experts_ids," - " Tensor !num_tokens_post_pad) -> () "); + " Tensor !num_tokens_post_pad," + " Tensor !adapter_enabled," + " Tensor !lora_ids) -> () "); m.impl("moe_lora_align_block_size", torch::kCUDA, 
&moe_lora_align_block_size); #ifndef USE_ROCM diff --git a/tests/lora/test_fused_moe_lora_kernel.py b/tests/lora/test_fused_moe_lora_kernel.py index b724e112b9dd3..318a0e58805d3 100644 --- a/tests/lora/test_fused_moe_lora_kernel.py +++ b/tests/lora/test_fused_moe_lora_kernel.py @@ -134,6 +134,8 @@ def use_fused_moe_lora_kernel( ) expert_ids = torch.empty((max_loras * max_num_m_blocks,), dtype=torch.int32) num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32) + adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32) + lora_ids = torch.arange(max_loras + 2, dtype=torch.int32) # call kernel ops.moe_lora_align_block_size( @@ -147,6 +149,8 @@ def use_fused_moe_lora_kernel( sorted_token_ids, expert_ids, num_tokens_post_padded, + adapter_enabled, + lora_ids, ) config = { @@ -172,6 +176,8 @@ def use_fused_moe_lora_kernel( num_tokens_post_padded, max_lora_rank, top_k_num, + lora_ids, + adapter_enabled, config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"], config["BLOCK_SIZE_K"], diff --git a/tests/lora/test_moe_lora_align_sum.py b/tests/lora/test_moe_lora_align_sum.py index 6cd1281c36328..72f1d759f1e7a 100644 --- a/tests/lora/test_moe_lora_align_sum.py +++ b/tests/lora/test_moe_lora_align_sum.py @@ -60,6 +60,8 @@ def test_moe_lora_align_block_size( (max_loras * max_num_m_blocks,), num_experts, dtype=torch.int32, device="cuda" ) num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda") + adapter_enabled = torch.ones((max_loras + 1,), dtype=torch.int32, device="cuda") + lora_ids = torch.arange(max_loras + 2, dtype=torch.int32, device="cuda") # call kernel ops.moe_lora_align_block_size( @@ -73,6 +75,8 @@ def test_moe_lora_align_block_size( sorted_token_ids, expert_ids, num_tokens_post_pad, + adapter_enabled, + lora_ids, ) # verify values diff --git a/tests/lora/test_olmoe_tp.py b/tests/lora/test_olmoe_tp.py index b954e0776ca4a..e659c1e1a9a07 100644 --- a/tests/lora/test_olmoe_tp.py +++ b/tests/lora/test_olmoe_tp.py @@ -1,6 +1,7 @@ # 
SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + import vllm from vllm.lora.request import LoRARequest @@ -28,8 +29,17 @@ EXPECTED_LORA_OUTPUT = [ "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501 ] +EXPECTED_BASE_MODEL_OUTPUT = [ + "SELECT COUNT(Candidate_ID) FROM candidate", + "SELECT COUNT(Candidate_ID) FROM candidate", + "SELECT Candidate_ID, COUNT(*) as Total_Candidates\nFROM candidate\nINNER JOIN people ON candidate.People_ID = people.People_ID", # noqa: E501 + "SELECT Candidate_ID, Poll_Source FROM candidate WHERE People_ID IN (SELECT People_ID FROM people) ORDER BY COUNT(*) DESC LIMIT 1", # noqa: E501 +] -def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None: + +def generate_and_test( + llm: vllm.LLM, lora_path: str, lora_id: list[int | None] | int | None +) -> None: prompts = [ PROMPT_TEMPLATE.format(context="How many candidates are there?"), PROMPT_TEMPLATE.format(context="Count the number of candidates."), @@ -40,12 +50,18 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None: context="Return the poll resource associated with the most candidates." ), ] + + lora_request = None + if isinstance(lora_id, int): + lora_request = LoRARequest(str(lora_id), lora_id, lora_path) + elif isinstance(lora_id, list): + lora_request = [ + LoRARequest(str(i), i, lora_path) if i is not None else None + for i in lora_id + ] + sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64) - outputs = llm.generate( - prompts, - sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None, - ) + outputs = llm.generate(prompts, sampling_params, lora_request=lora_request) # Print the outputs. 
generated_texts: list[str] = [] for output in outputs: @@ -55,7 +71,13 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None: print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") for i in range(len(EXPECTED_LORA_OUTPUT)): - assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i]) + req_lora_id = lora_id[i] if isinstance(lora_id, list) else lora_id + expected_output = ( + EXPECTED_LORA_OUTPUT[i] + if req_lora_id is not None + else EXPECTED_BASE_MODEL_OUTPUT[i] + ) + assert generated_texts[i].startswith(expected_output) def test_olmoe_lora(olmoe_lora_files): @@ -75,6 +97,20 @@ def test_olmoe_lora(olmoe_lora_files): generate_and_test(llm, olmoe_lora_files, lora_id=2) +def test_olmoe_lora_mixed(olmoe_lora_files): + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + enforce_eager=True, + trust_remote_code=True, + enable_chunked_prefill=True, + ) + + generate_and_test(llm, olmoe_lora_files, lora_id=[1, None, 3, None]) + + @multi_gpu_test(num_gpus=2) def test_olmoe_lora_tp2(olmoe_lora_files): llm = vllm.LLM( diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 61cf54fcfa39a..657b11046809d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1823,6 +1823,8 @@ def moe_lora_align_block_size( sorted_token_ids: torch.Tensor, experts_ids: torch.Tensor, num_tokens_post_pad: torch.Tensor, + adapter_enabled: torch.Tensor, + lora_ids: torch.Tensor, ) -> None: torch.ops._moe_C.moe_lora_align_block_size( topk_ids, @@ -1835,6 +1837,8 @@ def moe_lora_align_block_size( sorted_token_ids, experts_ids, num_tokens_post_pad, + adapter_enabled, + lora_ids, ) diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index 275a2ed0c6813..7711f5c3208bc 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -111,6 +111,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): config["BLOCK_SIZE_M"], self.base_layer.local_num_experts, max_loras, + self.adapter_enabled, 
expert_map, ) @@ -138,6 +139,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): max_lora_rank, top_k, config, + self.adapter_enabled, ) result = func(*args, **kwargs) @@ -196,6 +198,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): max_lora_rank, top_k, config, + self.adapter_enabled, True, ) @@ -227,6 +230,10 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ) -> None: """Initializes lora matrices.""" + self.adapter_enabled = torch.tensor( + [0] * (max_loras + 1), dtype=torch.int, device=self.device + ) + self.w1_lora_a_stacked = torch.zeros( ( max_loras, @@ -313,6 +320,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): self.w3_lora_b_stacked[index] = 0 self.w2_lora_a_stacked[index] = 0 self.w2_lora_b_stacked[index] = 0 + self.adapter_enabled[index] = 0 def set_lora( self, @@ -322,8 +330,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): embeddings_tensor: torch.Tensor | None, bias: torch.Tensor | None = None, ): - self.reset_lora(index) """Overwrites lora tensors at index.""" + self.reset_lora(index) + self.adapter_enabled[index] = 1 for eid in range(len(lora_a) // 3): w1_lora_a = lora_a[eid * 3] w2_lora_a = lora_a[eid * 3 + 1] diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index 15031f5e2f9e8..539605c7c534a 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -54,6 +54,8 @@ def _fused_moe_lora_kernel( EM, num_valid_tokens, num_experts, + lora_ids, + adapter_enabled, # The stride variables represent how much to increase the ptr by when # moving by 1 element in a particular dimension. E.g. 
`stride_am` is # how much to increase `a_ptr` by to get the element one row down @@ -84,6 +86,11 @@ def _fused_moe_lora_kernel( pid = tl.program_id(axis=0) slice_id = tl.program_id(axis=1) lora_idx = tl.program_id(axis=2) + lora_id = tl.load(lora_ids + lora_idx) + moe_enabled = tl.load(adapter_enabled + lora_id) + if lora_id == -1 or moe_enabled == 0: + # Early exit for the no-lora case. + return max_loras = tl.num_programs(axis=2) grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K) @@ -100,12 +107,12 @@ def _fused_moe_lora_kernel( pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m) pid_n = (pid_m_n % num_pid_in_group) // group_size_m - num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_idx) + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_id) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return # get the expert_id to process curr shard - ind = lora_idx * stride_el + pid_m + ind = lora_id * stride_el + pid_m expert_id = tl.load(expert_ids_ptr + ind, ind < max_loras * stride_el, -1) if expert_id == -1: return @@ -119,7 +126,7 @@ def _fused_moe_lora_kernel( offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) - token_ind = stride_tl * lora_idx + offs_token_id + token_ind = stride_tl * lora_id + offs_token_id offs_token = tl.load( sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0 ) @@ -132,7 +139,7 @@ def _fused_moe_lora_kernel( b_ptrs = ( cur_b_ptr - + lora_idx * stride_bl + + lora_id * stride_bl + expert_id * stride_be + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn @@ -184,6 +191,8 @@ def _fused_moe_lora( num_tokens_post_padded: torch.Tensor, # (max_loras, ) max_lora_rank: int, top_k_num: int, + lora_ids: torch.Tensor, + adapter_enabled: torch.Tensor, block_size_m: int, block_size_n: int, block_size_k: int, @@ -234,7 +243,7 @@ def _fused_moe_lora( num_tokens = M * top_k_num w1_output_dim_size = 
w1_lora_b_stacked.shape[2] - lora_intermediate_cache1 = torch.empty( + lora_intermediate_cache1 = torch.zeros( (num_slices * M * top_k_num * (max_lora_rank + w1_output_dim_size)), dtype=output.dtype, device=device, @@ -272,6 +281,8 @@ def _fused_moe_lora( EM, num_tokens, num_experts, + lora_ids, + adapter_enabled, qcurr_hidden_states.stride(0), qcurr_hidden_states.stride(1), w1_lora_a_stacked.stride(0), @@ -319,6 +330,8 @@ def _fused_moe_lora( EM, num_tokens, num_experts, + lora_ids, + adapter_enabled, a_intermediate_cache1.stride(0), a_intermediate_cache1.stride(1), w1_lora_b_stacked.stride(0), @@ -352,6 +365,8 @@ def _fused_moe_lora_fake( num_tokens_post_padded: torch.Tensor, max_lora_rank: int, top_k_num: int, + lora_ids: torch.Tensor, + adapter_enabled: torch.Tensor, block_size_m: int, block_size_n: int, block_size_k: int, diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index 5b4a18cf4789b..c552412cfd62e 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -456,6 +456,7 @@ class PunicaWrapperBase(PunicaWrapperABC): block_size: int, num_experts: int, max_loras: int, + adapter_enabled: torch.Tensor, expert_map: torch.Tensor | None = None, pad_sorted_ids: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -479,6 +480,7 @@ class PunicaWrapperBase(PunicaWrapperABC): max_lora_rank: int, top_k_num: int, config, + adapter_enabled: torch.Tensor, mul_routed_weight=False, ): """ diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index d9590769778ea..30def90380db1 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -305,6 +305,7 @@ class PunicaWrapperGPU(PunicaWrapperBase): block_size: int, num_experts: int, max_loras: int, + adapter_enabled: torch.Tensor, expert_map: torch.Tensor | None = None, pad_sorted_ids: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, 
torch.Tensor]: @@ -331,7 +332,7 @@ class PunicaWrapperGPU(PunicaWrapperBase): (max_loras), dtype=torch.int32, device=topk_ids.device ) - (token_lora_mapping, _, _, _, _, _) = self.token_mapping_meta.meta_args( + (token_lora_mapping, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args( num_tokens ) @@ -346,6 +347,8 @@ class PunicaWrapperGPU(PunicaWrapperBase): sorted_ids, expert_ids, num_tokens_post_pad, + adapter_enabled, + lora_ids, ) if expert_map is not None: expert_ids = expert_map[expert_ids] @@ -365,11 +368,13 @@ class PunicaWrapperGPU(PunicaWrapperBase): max_lora_rank: int, top_k_num: int, config, + adapter_enabled: torch.Tensor, mul_routed_weight=False, ): """ Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer. """ + (_, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(x.size(0)) fused_moe_lora( y, x, @@ -381,6 +386,8 @@ class PunicaWrapperGPU(PunicaWrapperBase): num_tokens_post_padded, max_lora_rank, top_k_num, + lora_ids, + adapter_enabled, config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"], config["BLOCK_SIZE_K"], From f7d2946e996f656b5f831fe2003f3b95a91fb367 Mon Sep 17 00:00:00 2001 From: pwschuurman Date: Mon, 3 Nov 2025 06:31:03 -0800 Subject: [PATCH 024/231] [Bugfix] Skip gs:// model paths for speculator detection (#27846) Signed-off-by: Peter Schuurman --- tests/transformers_utils/test_utils.py | 26 ++++++++++++++++++++++++++ vllm/engine/arg_utils.py | 10 +++++----- vllm/transformers_utils/utils.py | 8 ++++++++ 3 files changed, 39 insertions(+), 5 deletions(-) create mode 100644 tests/transformers_utils/test_utils.py diff --git a/tests/transformers_utils/test_utils.py b/tests/transformers_utils/test_utils.py new file mode 100644 index 0000000000000..beaef04d766bf --- /dev/null +++ b/tests/transformers_utils/test_utils.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +from vllm.transformers_utils.utils import is_cloud_storage, 
is_gcs, is_s3 + + +def test_is_gcs(): + assert is_gcs("gs://model-path") + assert not is_gcs("s3://model-path/path-to-model") + assert not is_gcs("/unix/local/path") + assert not is_gcs("nfs://nfs-fqdn.local") + + +def test_is_s3(): + assert is_s3("s3://model-path/path-to-model") + assert not is_s3("gs://model-path") + assert not is_s3("/unix/local/path") + assert not is_s3("nfs://nfs-fqdn.local") + + +def test_is_cloud_storage(): + assert is_cloud_storage("gs://model-path") + assert is_cloud_storage("s3://model-path/path-to-model") + assert not is_cloud_storage("/unix/local/path") + assert not is_cloud_storage("nfs://nfs-fqdn.local") diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 66c75d944ec8b..14fd4e70ad6c0 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -86,7 +86,7 @@ from vllm.transformers_utils.config import ( is_interleaved, maybe_override_with_speculators, ) -from vllm.transformers_utils.utils import check_gguf_file, is_s3 +from vllm.transformers_utils.utils import check_gguf_file, is_cloud_storage from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.mem_constants import GiB_bytes from vllm.utils.network_utils import get_ip @@ -1310,10 +1310,10 @@ class EngineArgs: # Check if the model is a speculator and override model/tokenizer/config # BEFORE creating ModelConfig, so the config is created with the target model - # Skip speculator detection for S3 models since HuggingFace cannot load - # configs directly from S3 URLs. S3 models can still use speculators with - # explicit --speculative-config. - if not is_s3(self.model): + # Skip speculator detection for cloud storage models (eg: S3, GCS) since + # HuggingFace cannot load configs directly from S3 URLs. S3 models can still + # use speculators with explicit --speculative-config. 
+ if not is_cloud_storage(self.model): (self.model, self.tokenizer, self.speculative_config) = ( maybe_override_with_speculators( model=self.model, diff --git a/vllm/transformers_utils/utils.py b/vllm/transformers_utils/utils.py index af2df195f2958..1ae42ba622dc4 100644 --- a/vllm/transformers_utils/utils.py +++ b/vllm/transformers_utils/utils.py @@ -19,6 +19,14 @@ def is_s3(model_or_path: str) -> bool: return model_or_path.lower().startswith("s3://") +def is_gcs(model_or_path: str) -> bool: + return model_or_path.lower().startswith("gs://") + + +def is_cloud_storage(model_or_path: str) -> bool: + return is_s3(model_or_path) or is_gcs(model_or_path) + + def check_gguf_file(model: str | PathLike) -> bool: """Check if the file is a GGUF model.""" model = Path(model) From cac4c10ef0e3280f045bff32cbb05e9a56e41b1b Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Mon, 3 Nov 2025 08:13:51 -0800 Subject: [PATCH 025/231] [BUG] Make 'binary' default option for saving torch compile artifacts when using standalone_compile (#27616) Signed-off-by: ahao-anyscale --- docs/design/torch_compile.md | 2 ++ vllm/compilation/backends.py | 4 +++- vllm/compilation/compiler_interface.py | 9 ++++++--- vllm/config/compilation.py | 23 ++++++++++++++++++++++- vllm/envs.py | 10 ++++++++++ 5 files changed, 43 insertions(+), 5 deletions(-) diff --git a/docs/design/torch_compile.md b/docs/design/torch_compile.md index 5a3ca2de82194..27edc4f89201d 100644 --- a/docs/design/torch_compile.md +++ b/docs/design/torch_compile.md @@ -27,6 +27,8 @@ With all these factors taken into consideration, usually we can guarantee that t A unique aspect of vLLM's `torch.compile` integration, is that we guarantee all the compilation finishes before we serve any requests. No requests will trigger new compilations. Otherwise, the engine would be blocked on that request, and the response time will have unexpected spikes. +By default, the cache saves compiled artifacts as binary files. 
If you would like to interact with the generated code for debugging purposes, set the field `compile_cache_save_format=unpacked` in the compilation config, or omit this and set the env variable `VLLM_COMPILE_CACHE_SAVE_FORMAT=unpacked`. + ## Python Code Compilation In the very verbose logs, we can see: diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 53fd5e74dc0a8..83d8cdae1ed34 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -51,7 +51,9 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface: and hasattr(torch._inductor, "standalone_compile") ): logger.debug("Using InductorStandaloneAdaptor") - return InductorStandaloneAdaptor() + return InductorStandaloneAdaptor( + compilation_config.compile_cache_save_format + ) else: logger.debug("Using InductorAdaptor") return InductorAdaptor() diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 0a3f0769db941..d15481b3045d6 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -6,7 +6,7 @@ import hashlib import os from collections.abc import Callable from contextlib import ExitStack -from typing import Any +from typing import Any, Literal from unittest.mock import patch import torch @@ -175,6 +175,9 @@ class InductorStandaloneAdaptor(CompilerInterface): name = "inductor_standalone" + def __init__(self, save_format: Literal["binary", "unpacked"]): + self.save_format = save_format + def compute_hash(self, vllm_config: VllmConfig) -> str: factors = get_inductor_factors() hash_str = hashlib.md5( @@ -220,7 +223,7 @@ class InductorStandaloneAdaptor(CompilerInterface): assert key is not None path = os.path.join(self.cache_dir, key) if not envs.VLLM_DISABLE_COMPILE_CACHE: - compiled_graph.save(path=path, format="unpacked") + compiled_graph.save(path=path, format=self.save_format) compilation_counter.num_compiled_artifacts_saved += 1 return compiled_graph, 
(key, path) @@ -237,7 +240,7 @@ class InductorStandaloneAdaptor(CompilerInterface): assert isinstance(handle[1], str) path = handle[1] inductor_compiled_graph = torch._inductor.CompiledArtifact.load( - path=path, format="unpacked" + path=path, format=self.save_format ) from torch._inductor.compile_fx import graph_returns_tuple diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 6a5bd5ef4e07c..00e8cbfd7319a 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -7,11 +7,12 @@ from collections import Counter from collections.abc import Callable from dataclasses import asdict, field from pathlib import Path -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, Literal from pydantic import TypeAdapter, field_validator from pydantic.dataclasses import dataclass +import vllm.envs as envs from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.config.utils import config from vllm.logger import init_logger @@ -208,6 +209,15 @@ class CompilationConfig: """The directory to store the compiled graph, to accelerate Inductor compilation. By default, it will use model-related information to generate a cache directory.""" + compile_cache_save_format: Literal["binary", "unpacked"] = field( + default_factory=lambda: envs.VLLM_COMPILE_CACHE_SAVE_FORMAT + ) + """Format for saving torch compile cache:\n + - "binary": saves as binary file (multiprocess safe)\n + - "unpacked": saves as directory structure for inspection/debugging + (NOT multiprocess safe)\n + Defaults to `VLLM_COMPILE_CACHE_SAVE_FORMAT` if not specified. + """ backend: str = "" """The backend for compilation. 
It needs to be a string: @@ -479,6 +489,7 @@ class CompilationConfig: factors.append(self.inductor_compile_config) factors.append(self.inductor_passes) factors.append(self.pass_config.uuid()) + factors.append(self.compile_cache_save_format) return hashlib.sha256(str(factors).encode()).hexdigest() def __repr__(self) -> str: @@ -520,6 +531,16 @@ class CompilationConfig: return CUDAGraphMode[value.upper()] return value + @field_validator("compile_cache_save_format") + @classmethod + def validate_compile_cache_save_format(cls, value: str) -> str: + if value not in ("binary", "unpacked"): + raise ValueError( + f"compile_cache_save_format must be 'binary' or 'unpacked', " + f"got: {value}" + ) + return value + def __post_init__(self) -> None: if self.level is not None: logger.warning( diff --git a/vllm/envs.py b/vllm/envs.py index 21237c70a45e4..81f189ada9a6f 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -218,6 +218,7 @@ if TYPE_CHECKING: VLLM_USE_FBGEMM: bool = False VLLM_GC_DEBUG: str = "" VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False + VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" def get_default_cache_root(): @@ -1442,6 +1443,15 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: os.getenv( "VLLM_DISABLE_SHARED_EXPERTS_STREAM", False ), + # Format for saving torch.compile cache artifacts + # - "binary": saves as binary file + # Safe for multiple vllm serve processes accessing the same torch compile cache. + # - "unpacked": saves as directory structure (for inspection/debugging) + # NOT multiprocess safe - race conditions may occur with multiple processes. + # Allows viewing and setting breakpoints in Inductor's code output files. 
+ "VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices( + "VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"] + ), } # --8<-- [end:env-vars-definition] From 4bc400f47e33ef27fb69608b9ad7fe992cb8ba76 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 4 Nov 2025 02:00:46 +0900 Subject: [PATCH 026/231] [CI/Testing] Add basic single node dual batch overlap test (#27235) Signed-off-by: Lucas Wilkinson --- .buildkite/test-pipeline.yaml | 2 + tests/v1/distributed/test_dbo.py | 89 ++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+) create mode 100644 tests/v1/distributed/test_dbo.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index a020b0d276be0..07e2bf09d8aa0 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -1223,6 +1223,7 @@ steps: - pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm - pytest -v -s tests/distributed/test_context_parallel.py - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 + - pytest -v -s tests/v1/distributed/test_dbo.py ##### B200 test ##### - label: Distributed Tests (B200) # optional @@ -1233,6 +1234,7 @@ steps: commands: - pytest -v -s tests/distributed/test_context_parallel.py - pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py + - pytest -v -s tests/v1/distributed/test_dbo.py ##### RL Integration Tests ##### - label: Prime-RL Integration Test # 15min diff --git a/tests/v1/distributed/test_dbo.py b/tests/v1/distributed/test_dbo.py new file mode 100644 index 0000000000000..866ae742bf3c0 --- /dev/null +++ b/tests/v1/distributed/test_dbo.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Test Dual Batch Overlap (DBO) with Data 
Parallelism + Expert Parallelism. + +DBO is specifically designed for DP+EP scenarios to hide communication latency +by overlapping computation of two batches. This test validates that DBO works +correctly with the DeepSeek-V2-Lite model using GSM8K evaluation. +""" + +import pytest + +from tests.evals.gsm8k.gsm8k_eval import evaluate_gsm8k +from tests.utils import RemoteOpenAIServer + +MODEL_NAME = "deepseek-ai/DeepSeek-V2-Lite-Chat" +DP_SIZE = 2 + +# GSM8K eval configuration +NUM_QUESTIONS = 256 # Fast eval for CI; but must be large enough to hit dbo thresholds +NUM_SHOTS = 5 # Few-shot examples +MIN_ACCURACY = 0.62 # Expected 0.64 with 2% buffer (based on vLLM test data) + +# Increase max_num_seqs to trigger DBO for decode batches +# With 64 seqs, decode batches should exceed the 32 token threshold +MAX_NUM_SEQS = 64 # Increased from 16 to trigger decode DBO + +# DeepEP backends to test +DEEPEP_BACKENDS = [ + "deepep_low_latency", + "deepep_high_throughput", +] + + +@pytest.mark.parametrize("all2all_backend", DEEPEP_BACKENDS) +def test_dbo_dp_ep_gsm8k(all2all_backend: str, num_gpus_available): + """ + Test DBO with DP+EP using GSM8K evaluation. 
+ """ + required_gpus = DP_SIZE + + if num_gpus_available < required_gpus: + pytest.skip(f"Need at least {required_gpus} GPUs (DP={DP_SIZE})") + + # Server arguments for DBO + DP + EP + server_args = [ + "--max-model-len", + "4096", + "--max-num-seqs", + str(MAX_NUM_SEQS), # Use larger batch to trigger decode DBO + "--trust-remote-code", + # Note: Not using --enforce-eager to test DBO's alternate CUDA graph dispatching + "--data-parallel-size", + str(DP_SIZE), + "--enable-expert-parallel", + "--enable-dbo", + # Fix threshold so we know we trigger DBO + "--dbo-decode-token-threshold", + "16", + "--dbo-prefill-token-threshold", + "256", + "--all2all-backend", + all2all_backend, + ] + + with RemoteOpenAIServer( + MODEL_NAME, + server_args, + max_wait_seconds=600, # Allow time for model loading with DP+EP + ) as remote_server: + # Use host and port directly from RemoteOpenAIServer + host = f"http://{remote_server.host}" + port = remote_server.port + + # Run GSM8K evaluation + results = evaluate_gsm8k( + num_questions=NUM_QUESTIONS, + num_shots=NUM_SHOTS, + host=host, + port=port, + ) + + # Validate accuracy is reasonable + accuracy = results["accuracy"] + assert accuracy >= MIN_ACCURACY, ( + f"DBO+DP+EP accuracy too low ({all2all_backend}): " + f"{accuracy:.3f} < {MIN_ACCURACY:.3f} " + f"(correct: {results['num_correct']}/{results['num_questions']})" + ) From 2c19d96777939dd3473eabfacbe1948a3ea0b4be Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Mon, 3 Nov 2025 09:23:31 -0800 Subject: [PATCH 027/231] [Spec Decode] Integrate Suffix Decoding from Arctic Inference (#25784) Co-authored-by: Aurick Qiao --- docs/features/spec_decode.md | 40 ++++++++++ requirements/test.in | 1 + requirements/test.txt | 2 + tests/v1/e2e/test_spec_decode.py | 85 +++++++++++++++++++-- vllm/config/speculative.py | 66 +++++++++++++++- vllm/utils/import_utils.py | 6 ++ vllm/v1/spec_decode/suffix_decoding.py | 101 +++++++++++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 14 +++- 8 files 
changed, 304 insertions(+), 11 deletions(-) create mode 100644 vllm/v1/spec_decode/suffix_decoding.py diff --git a/docs/features/spec_decode.md b/docs/features/spec_decode.md index ab72c7d97b7a4..6097500cac01f 100644 --- a/docs/features/spec_decode.md +++ b/docs/features/spec_decode.md @@ -130,6 +130,46 @@ matching n-grams in the prompt. For more information read [this thread.](https:/ print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` +## Speculating using Suffix Decoding + +The following code configures vLLM to use speculative decoding where proposals are generated using Suffix Decoding ([technical report](https://arxiv.org/abs/2411.04975)). + +Like n-gram, Suffix Decoding can generate draft tokens by pattern-matching using the last `n` generated tokens. Unlike n-gram, Suffix Decoding (1) can pattern-match against both the prompt and previous generations, (2) uses frequency counts to propose the most likely continuations, and (3) speculates an adaptive number of tokens for each request at each iteration to get better acceptance rates. + +Suffix Decoding can achieve better performance for tasks with high repetition, such as code-editing, agentic loops (e.g. self-reflection, self-consistency), and RL rollouts. + +!!! tip "Install Arctic Inference" + Suffix Decoding requires [Arctic Inference](https://github.com/snowflakedb/ArcticInference). You can install it with `pip install arctic-inference`. + +!!! tip "Suffix Decoding Speculative Tokens" + Suffix Decoding will speculate a dynamic number of tokens for each request at each decoding step, so the `num_speculative_tokens` configuration specifies the *maximum* number of speculative tokens. It is suggested to use a high number such as `16` or `32` (default). + +??? 
code + + ```python + from vllm import LLM, SamplingParams + + prompts = [ + "The future of AI is", + ] + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + + llm = LLM( + model="facebook/opt-6.7b", + tensor_parallel_size=1, + speculative_config={ + "method": "suffix", + "num_speculative_tokens": 32, + }, + ) + outputs = llm.generate(prompts, sampling_params) + + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + ``` + ## Speculating using MLP speculators The following code configures vLLM to use speculative decoding where proposals are generated by diff --git a/requirements/test.in b/requirements/test.in index f57ec31277ce9..ce209fd276628 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -48,6 +48,7 @@ buildkite-test-collector==0.1.9 genai_perf==0.0.8 tritonclient==2.51.0 +arctic-inference == 0.1.0 # Required for suffix decoding test numba == 0.61.2 # Required for N-gram speculative decoding numpy runai-model-streamer[s3,gcs]==0.15.0 diff --git a/requirements/test.txt b/requirements/test.txt index a975f247065da..9d13fa4241152 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -40,6 +40,8 @@ anyio==4.6.2.post1 # via # httpx # starlette +arctic-inference==0.1.0 + # via -r requirements/test.in argcomplete==3.5.1 # via datamodel-code-generator arrow==1.3.0 diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index ea7fcdf3174ec..9b55d2b14b991 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -75,7 +75,23 @@ def model_name(): return "meta-llama/Llama-3.1-8B-Instruct" -def test_ngram_correctness( +@pytest.mark.parametrize( + "speculative_config", + [ + { + "method": "ngram", + "prompt_lookup_max": 5, + "prompt_lookup_min": 3, + "num_speculative_tokens": 3, + }, + { + "method": "suffix", + "suffix_decoding_max_spec_factor": 2.0, + }, + ], +) +def 
test_ngram_and_suffix_correctness( + speculative_config: dict, monkeypatch: pytest.MonkeyPatch, sampling_config: SamplingParams, model_name: str, @@ -94,12 +110,7 @@ def test_ngram_correctness( spec_llm = LLM( model=model_name, - speculative_config={ - "method": "ngram", - "prompt_lookup_max": 5, - "prompt_lookup_min": 3, - "num_speculative_tokens": 3, - }, + speculative_config=speculative_config, max_model_len=1024, ) spec_outputs = spec_llm.chat(test_prompts, sampling_config) @@ -121,6 +132,66 @@ def test_ngram_correctness( cleanup_dist_env_and_memory() +def test_suffix_decoding_acceptance( + monkeypatch: pytest.MonkeyPatch, + sampling_config: SamplingParams, + model_name: str, +): + """ + Check that suffix decoding caching takes effect and improves acceptance + lengths and acceptance rates over multiple runs of the same prompts. + """ + test_prompts = get_test_prompts(mm_enabled=False) + + spec_llm = LLM( + model=model_name, + speculative_config={ + "method": "suffix", + "suffix_decoding_max_spec_factor": 2.0, + "suffix_decoding_max_cached_requests": 1000, + }, + max_model_len=1024, + disable_log_stats=False, + ) + + # Run several times and check that the accepted tokens increase. + spec_llm.chat(test_prompts, sampling_config) + num_draft = [] + num_accept = [] + for i in range(10): # Run multiple times to warm up the cache. + spec_llm.chat(test_prompts, sampling_config) + # Collect draft and acceptance stats. + metrics = spec_llm.get_metrics() + for metric in metrics: + if metric.name == "vllm:spec_decode_num_draft_tokens": + num_draft.append(metric.value) + if metric.name == "vllm:spec_decode_num_accepted_tokens": + num_accept.append(metric.value) + + # Calculate the acceptance rates for the first and last runs. + first_accept_tokens = num_accept[0] + first_draft_tokens = num_draft[0] + first_accept_rate = first_accept_tokens / first_draft_tokens + + # Take the diff since the stats are cumulative. 
+ last_accept_tokens = num_accept[-1] - num_accept[-2] + last_draft_tokens = num_draft[-1] - num_draft[-2] + last_accept_rate = last_accept_tokens / last_draft_tokens + + # Expect the acceptance length to improve. + assert first_accept_tokens < last_accept_tokens + + # Expect the acceptance rate to improve. + assert first_accept_rate < last_accept_rate + + # Heuristic: expect at least 85% acceptance rate at the end. + assert last_accept_rate > 0.85 + + del spec_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() + + @pytest.mark.parametrize( "model_path", [ diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 1f956526dcdc6..af1d640f8accc 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -12,7 +12,7 @@ from typing_extensions import Self from vllm.config.parallel import ParallelConfig from vllm.config.utils import config from vllm.logger import init_logger -from vllm.utils.import_utils import LazyLoader +from vllm.utils.import_utils import LazyLoader, has_arctic_inference if TYPE_CHECKING: from transformers import PretrainedConfig @@ -42,6 +42,7 @@ SpeculativeMethod = Literal[ "mimo_mtp", "longcat_flash_mtp", "mtp", + "suffix", ] MTP_MODEL_TYPES = ( "deepseek_mtp", @@ -129,6 +130,27 @@ class SpeculativeConfig: draft_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore """The parallel configuration for the draft model initialized internal.""" + # Suffix decoding configuration + suffix_decoding_max_tree_depth: int = 24 + """The maximum depth of the suffix decoding global and prompt trees. The + tree depth limits the sum of the prefix match and speculation lengths.""" + + suffix_decoding_max_cached_requests: int = 10000 + """The maximum number of requests to cache in the global suffix tree. If + exceeded, will trigger eviction in FIFO order. 
If set to 0, the global + suffix tree is disabled and past responses are not cached (prompt trees + are still used).""" + + suffix_decoding_max_spec_factor: float = 1.0 + """The maximum spec factor for suffix decoding. The spec factor controls + speculation lengths based on the prefix match length: max_spec_tokens = + max_spec_factor * prefix_match_length.""" + + suffix_decoding_min_token_prob: float = 0.1 + """The minimum token probability for suffix decoding. Will only speculate + tokens with estimated probability (based on frequency counts) greater than + or equal to this value.""" + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -235,6 +257,8 @@ class SpeculativeConfig: self.quantization = self.target_model_config.quantization elif self.method in ("ngram", "[ngram]"): self.model = "ngram" + elif self.method == "suffix": + self.model = "suffix" else: raise ValueError( "num_speculative_tokens was provided but without speculative model." @@ -282,6 +306,8 @@ class SpeculativeConfig: # draft related config as None here. self.draft_model_config = self.target_model_config self.draft_parallel_config = self.target_parallel_config + elif self.method == "suffix": + self._validate_suffix_decoding() else: self.prompt_lookup_max = 0 self.prompt_lookup_min = 0 @@ -430,6 +456,42 @@ class SpeculativeConfig: ) return self + def _validate_suffix_decoding(self): + if not has_arctic_inference(): + raise ImportError( + "Arctic Inference is required for suffix decoding. " + "Install via `pip install arctic-inference==0.1.0`." + ) + if self.num_speculative_tokens is None: + # Suffix decoding decides the actual number of speculative tokens + # dynamically and treats num_speculative_tokens as a maximum limit. 
+ self.num_speculative_tokens = self.suffix_decoding_max_tree_depth + logger.warning( + "Defaulted num_speculative_tokens to %s for suffix decoding.", + self.num_speculative_tokens, + ) + # Validate values + if self.suffix_decoding_max_tree_depth < 1: + raise ValueError( + f"suffix_decoding_max_tree_depth=" + f"{self.suffix_decoding_max_tree_depth} must be >= 1" + ) + if self.suffix_decoding_max_cached_requests < 0: + raise ValueError( + f"suffix_decoding_max_cached_requests=" + f"{self.suffix_decoding_max_cached_requests} must be >= 0" + ) + if self.suffix_decoding_max_spec_factor < 0: + raise ValueError( + f"suffix_decoding_max_spec_factor=" + f"{self.suffix_decoding_max_spec_factor} must be >= 0" + ) + if not 0 <= self.suffix_decoding_min_token_prob <= 1: + raise ValueError( + f"suffix_decoding_min_token_prob=" + f"{self.suffix_decoding_min_token_prob} must be in [0, 1]" + ) + @staticmethod def _maybe_override_draft_max_model_len( speculative_max_model_len: int | None, @@ -582,6 +644,6 @@ class SpeculativeConfig: def __repr__(self) -> str: method = self.method - model = None if method == "ngram" else self.draft_model_config.model + model = None if method in ("ngram", "suffix") else self.draft_model_config.model num_spec_tokens = self.num_speculative_tokens return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})" diff --git a/vllm/utils/import_utils.py b/vllm/utils/import_utils.py index 409a5a6cd302d..f01d2c7a6a33d 100644 --- a/vllm/utils/import_utils.py +++ b/vllm/utils/import_utils.py @@ -403,3 +403,9 @@ def has_triton_kernels() -> bool: def has_tilelang() -> bool: """Whether the optional `tilelang` package is available.""" return _has_module("tilelang") + + +def has_arctic_inference() -> bool: + """Whether the optional `arctic_inference` package is available.""" + + return _has_module("arctic_inference") diff --git a/vllm/v1/spec_decode/suffix_decoding.py b/vllm/v1/spec_decode/suffix_decoding.py new file mode 100644 index 
0000000000000..049e335db3254 --- /dev/null +++ b/vllm/v1/spec_decode/suffix_decoding.py @@ -0,0 +1,101 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.config import VllmConfig +from vllm.v1.worker.gpu_input_batch import InputBatch + + +class SuffixDecodingProposer: + """ + Speculative decoding proposer for Suffix Decoding (https://arxiv.org/pdf/2411.04975). + This class imports and uses the official implementation from Arctic Inference + (https://github.com/snowflakedb/ArcticInference). + """ + + def __init__(self, vllm_config: VllmConfig): + config = vllm_config.speculative_config + self.num_speculative_tokens = config.num_speculative_tokens + self.max_tree_depth = config.suffix_decoding_max_tree_depth + self.max_spec_factor = config.suffix_decoding_max_spec_factor + self.min_token_prob = config.suffix_decoding_min_token_prob + self.max_model_len = vllm_config.model_config.max_model_len + + # Lazy import to avoid error when Suffix Decoding is not used. + from arctic_inference.suffix_decoding import SuffixDecodingCache + + # Initialize and empty cache. This object will take care of caching request + # outputs, evicting old requests, and manages the per-prompt suffix trees. + self.suffix_cache = SuffixDecodingCache( + max_tree_depth=config.suffix_decoding_max_tree_depth, + max_cached_requests=config.suffix_decoding_max_cached_requests, + ) + + def propose( + self, + input_batch: InputBatch, + sampled_token_ids: list[list[int]], + ) -> list[list[int]]: + """ + Propose speculative tokens for each request in the input batch. Suffix Decoding + will speculate a dynamic number of tokens for each request every decoding step, + so each entry in the returned list may have different lengths. + """ + draft_token_ids: list[list[int]] = [] + for i, sampled_ids in enumerate(sampled_token_ids): + if not sampled_ids: + # Skip speculative decoding for partial prefills. 
+ draft_token_ids.append([]) + continue + + # Skip requests that require sampling parameters that are not + # supported with speculative decoding. + req_id = input_batch.req_ids[i] + if req_id in input_batch.spec_decode_unsupported_reqs: + draft_token_ids.append([]) + continue + + num_tokens = input_batch.num_tokens_no_spec[i] + if num_tokens >= self.max_model_len: + # Skip requests that have already reached the max model length. + draft_token_ids.append([]) + continue + + index = input_batch.req_id_to_index[req_id] + if req_id not in self.suffix_cache.active_requests: + if req_id in self.suffix_cache.cached_requests: + # Reset the suffix cache for this request. + self.suffix_cache.evict_cached_response(req_id) + num_prompt_tokens = input_batch.num_prompt_tokens[index] + prompt_token_ids = input_batch.token_ids_cpu[index, :num_prompt_tokens] + # Start a new request, this will build the suffix tree for that prompt. + self.suffix_cache.start_request(req_id, prompt_token_ids) + + # Append the newly sampled ids to the suffix cache for this request. + self.suffix_cache.add_active_response(req_id, sampled_ids) + + # Suffix decoding only uses the most recent tokens up to max_tree_depth, so + # we extract the pattern from the end of the input. + start = max(0, num_tokens - self.max_tree_depth) + pattern = input_batch.token_ids_cpu[i, start:num_tokens] + draft = self.suffix_cache.speculate( + req_id, + pattern, + max_spec_tokens=min( + self.num_speculative_tokens, self.max_model_len - num_tokens - 1 + ), + max_spec_factor=self.max_spec_factor, + min_token_prob=self.min_token_prob, + ) + + draft_token_ids.append(draft.token_ids) + + # Stop requests that were not seen in the input batch. + for req_id in ( + self.suffix_cache.active_requests - input_batch.req_id_to_index.keys() + ): + self.suffix_cache.stop_request(req_id) + + return draft_token_ids + + def load_model(self, *args, **kwargs): + # No model to load. 
+ pass diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9212221bb6009..e700c09038e28 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -125,6 +125,7 @@ from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext from vllm.v1.worker.dp_utils import coordinate_batch_across_dp @@ -336,16 +337,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # the last PP rank. This is not ideal if there are many # layers in the draft model. if self.speculative_config and get_pp_group().is_last_rank: + self.drafter: ( + NgramProposer | SuffixDecodingProposer | EagleProposer | MedusaProposer + ) if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) + elif self.speculative_config.method == "suffix": + self.drafter = SuffixDecodingProposer(self.vllm_config) elif self.speculative_config.use_eagle(): - self.drafter = EagleProposer(self.vllm_config, self.device, self) # type: ignore + self.drafter = EagleProposer(self.vllm_config, self.device, self) if self.speculative_config.method == "eagle3": self.use_aux_hidden_state_outputs = True elif self.speculative_config.method == "medusa": self.drafter = MedusaProposer( vllm_config=self.vllm_config, device=self.device - ) # type: ignore + ) else: raise ValueError( "Unknown speculative decoding method: " @@ -2783,6 +2789,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.input_batch.token_ids_cpu, self.input_batch.spec_decode_unsupported_reqs, ) + elif self.speculative_config.method == "suffix": + 
assert isinstance(sampled_token_ids, list) + assert isinstance(self.drafter, SuffixDecodingProposer) + draft_token_ids = self.drafter.propose(self.input_batch, sampled_token_ids) elif self.speculative_config.method == "medusa": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, MedusaProposer) From a4398fbb5e9fe20c8f0f092da4de30c9a582cce0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophie=20du=20Cou=C3=A9dic?= Date: Mon, 3 Nov 2025 19:33:17 +0100 Subject: [PATCH 028/231] [Feature][Benchmarks] Support `inf` burstiness (#26941) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Sophie du Couédic --- vllm/benchmarks/serve.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index 4b15d8e62913c..b8f44966db7a0 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -189,9 +189,16 @@ async def get_request( total_requests, request_rate, ) + assert current_request_rate > 0.0, ( + f"Obtained non-positive request rate {current_request_rate}." 
+ ) request_rates.append(current_request_rate) if current_request_rate == float("inf"): delay_ts.append(0) + elif burstiness == float("inf"): + # when burstiness tends to infinity, the delay time becomes constant + # and tends to the inverse of the request rate + delay_ts.append(1.0 / current_request_rate) else: theta = 1.0 / (current_request_rate * burstiness) From 55011aef24c2838b05df585822b8fc231eea19b2 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Mon, 3 Nov 2025 11:12:15 -0800 Subject: [PATCH 029/231] [Bugfix][Qwen][Multimodal] Move Qwen2_5_vl sdpa to custom op and reenable compile (#27764) Signed-off-by: Lucas Kabela --- vllm/attention/ops/vit_attn_wrappers.py | 53 ++++++++++++++++++++++++ vllm/model_executor/models/qwen2_5_vl.py | 44 +++++++------------- 2 files changed, 69 insertions(+), 28 deletions(-) diff --git a/vllm/attention/ops/vit_attn_wrappers.py b/vllm/attention/ops/vit_attn_wrappers.py index 6cefe74416685..06a9f7cd82266 100644 --- a/vllm/attention/ops/vit_attn_wrappers.py +++ b/vllm/attention/ops/vit_attn_wrappers.py @@ -14,6 +14,7 @@ To use these ops, you must have a recent version of PyTorch installed (>= 2.4.0) import einops import torch +import torch.nn.functional as F from vllm.utils.torch_utils import direct_register_custom_op @@ -123,3 +124,55 @@ def vit_flash_attn_wrapper( return torch.ops.vllm.flash_attn_maxseqlen_wrapper( q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa ) + + +# TODO: Once we have a torch 2.10, we can use tensor slices +# so we won't need to wrap this in custom ops +def torch_sdpa_wrapper( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens: torch.Tensor, +) -> torch.Tensor: + outputs = [] + for i in range(1, len(cu_seqlens)): + start_idx = cu_seqlens[i - 1] + end_idx = cu_seqlens[i] + q_i = q[:, start_idx:end_idx] + k_i = k[:, start_idx:end_idx] + v_i = v[:, start_idx:end_idx] + q_i, k_i, v_i = ( + einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i] + ) + 
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) + output_i = einops.rearrange(output_i, "b h s d -> b s h d ") + outputs.append(output_i) + context_layer = torch.cat(outputs, dim=1) + context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous() + return context_layer + + +def torch_sdpa_wrapper_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens: torch.Tensor, +) -> torch.Tensor: + b, s, h, d = q.shape + return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device) + + +direct_register_custom_op( + op_name="torch_sdpa_wrapper", + op_func=torch_sdpa_wrapper, + fake_impl=torch_sdpa_wrapper_fake, +) + + +def vit_torch_sdpa_wrapper( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens: torch.Tensor, +) -> torch.Tensor: + return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, cu_seqlens) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 3585783e4ccc3..2b04608dfd03f 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -46,6 +46,7 @@ from vllm.attention.backends.registry import _Backend from vllm.attention.layer import maybe_get_vit_flash_attn_backend from vllm.attention.ops.vit_attn_wrappers import ( vit_flash_attn_wrapper, + vit_torch_sdpa_wrapper, vit_xformers_attn_wrapper, ) from vllm.compilation.decorators import support_torch_compile @@ -442,23 +443,12 @@ class Qwen2_5_VisionAttention(nn.Module): q = q.contiguous() k = k.contiguous() v = v.contiguous() - outputs = [] - for i in range(1, len(cu_seqlens)): - start_idx = cu_seqlens[i - 1] - end_idx = cu_seqlens[i] - q_i = q[:, start_idx:end_idx] - k_i = k[:, start_idx:end_idx] - v_i = v[:, start_idx:end_idx] - q_i, k_i, v_i = ( - einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i] - ) - output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) - output_i = einops.rearrange(output_i, "b h s d -> b s h d 
") - outputs.append(output_i) - context_layer = torch.cat(outputs, dim=1) - context_layer = einops.rearrange( - context_layer, "b s h d -> s b (h d)" - ).contiguous() + context_layer = vit_torch_sdpa_wrapper( + q, + k, + v, + cu_seqlens, + ) elif self.attn_backend == _Backend.XFORMERS: context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens) @@ -466,17 +456,15 @@ class Qwen2_5_VisionAttention(nn.Module): return output -# (FIXME): Enable this after dynamic slicing is fixed -# See https://github.com/vllm-project/vllm/pull/27760 -# @support_torch_compile( -# dynamic_arg_dims={ -# "x": 0, -# "cu_seqlens": 0, -# "rotary_pos_emb": 0, -# "seqlens": 0, -# }, -# mark_unbacked_dims={"seqlens": 0}, -# ) +@support_torch_compile( + dynamic_arg_dims={ + "x": 0, + "cu_seqlens": 0, + "rotary_pos_emb": 0, + "seqlens": 0, + }, + mark_unbacked_dims={"seqlens": 0}, +) class Qwen2_5_VisionBlock(nn.Module): def __init__( self, From 145c00a4d32b7a681f7fb936c9575812c7aa7880 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 3 Nov 2025 15:17:10 -0500 Subject: [PATCH 030/231] [Bugfix] change FlashMLA reorder_batch_threshold (#27777) Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/flashmla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 1f98204031ed5..bc17307532093 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -71,7 +71,7 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM - reorder_batch_threshold: int = 512 # process small prefills with decode pathway + reorder_batch_threshold: int = 128 # process small prefills with decode pathway # ^ 
TODO(matt): tune this def __init__( From 786030721efb2b85a582d65f9bb5d7197de06f83 Mon Sep 17 00:00:00 2001 From: Ning Xie Date: Tue, 4 Nov 2025 04:35:16 +0800 Subject: [PATCH 031/231] [Docs] add runai_streamer_sharded to LoadConfig (#27937) Signed-off-by: Andy Xie --- vllm/config/load.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/config/load.py b/vllm/config/load.py index d625c1ac987e7..e424f8c5edb62 100644 --- a/vllm/config/load.py +++ b/vllm/config/load.py @@ -40,6 +40,8 @@ class LoadConfig: more information.\n - "runai_streamer" will load the Safetensors weights using Run:ai Model Streamer.\n + - "runai_streamer_sharded" will load weights from pre-sharded checkpoint + files using Run:ai Model Streamer.\n - "bitsandbytes" will load the weights using bitsandbytes quantization.\n - "sharded_state" will load weights from pre-sharded checkpoint files, supporting efficient loading of tensor-parallel models.\n From 01baefe674e61d156672d14b11b20055252df662 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 3 Nov 2025 16:04:40 -0500 Subject: [PATCH 032/231] Add TP parameter to attention tests (#27683) Signed-off-by: Matthew Bonanni --- .buildkite/test-pipeline.yaml | 3 +- tests/v1/attention/test_attention_backends.py | 58 +++++++++++++++++-- tests/v1/attention/test_mla_backends.py | 31 +++++++++- .../v1/attention/test_sparse_mla_backends.py | 11 +++- 4 files changed, 92 insertions(+), 11 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 07e2bf09d8aa0..4a898df8f2a34 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -347,8 +347,7 @@ steps: - vllm/v1/attention - tests/v1/attention commands: - - export VLLM_DISABLE_FLASHINFER_PREFILL=1 # TODO: FI prefill is bugged and causes incorrectness, fix this - - pytest -v -s v1/attention + - VLLM_DISABLE_FLASHINFER_PREFILL=1 pytest -v -s v1/attention # TODO: FI prefill is bugged and causes incorrectness, fix this - label: V1 Test others 
(CPU) # 5 mins source_file_dependencies: diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index 6659b3eb1e98f..08aeb6f298f61 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -295,6 +295,7 @@ def _test_backend_correctness( block_size: int = 16, atol: float = 1e-2, rtol: float = 1e-2, + tensor_parallel_size: int = 1, ): """ Test that all backends produce similar outputs to a reference implementation @@ -310,13 +311,38 @@ def _test_backend_correctness( 4. Running each vLLM attention backend with the new queries and the simulated paged KV cache. 5. Comparing the vLLM backend's output to the ground-truth SDPA output. + + Note: When tensor_parallel_size > 1, we simulate the head partitioning + by overriding the model config to use fewer heads, without requiring + multiple GPUs. This tests that backends work correctly with different + head counts. """ current_platform.seed_everything(42) + + hf_config_override = None + if tensor_parallel_size > 1: + from vllm.config import ModelConfig + + temp_config = ModelConfig(model=model, max_model_len=1) + original_num_heads = temp_config.hf_text_config.num_attention_heads + original_num_kv_heads = getattr( + temp_config.hf_text_config, "num_key_value_heads", None + ) + hf_config_override = { + "num_attention_heads": original_num_heads // tensor_parallel_size, + } + if original_num_kv_heads is not None: + hf_config_override["num_key_value_heads"] = max( + 1, original_num_kv_heads // tensor_parallel_size + ) + vllm_config = create_vllm_config( model_name=model, + tensor_parallel_size=1, # Always use TP=1 to avoid multi-GPU requirements max_model_len=max(batch_spec.seq_lens), block_size=block_size, num_gpu_blocks=8192, + hf_config_override=hf_config_override, ) device = torch.device("cuda:0") @@ -503,7 +529,10 @@ def _test_backend_correctness( ], ) @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) 
-def test_causal_backend_correctness(batch_spec_name: str, model: str): +@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4]) +def test_causal_backend_correctness( + batch_spec_name: str, model: str, tensor_parallel_size: int +): """Test backend's correctness with causal attention.""" def causal_mask_mod( @@ -523,12 +552,23 @@ def test_causal_backend_correctness(batch_spec_name: str, model: str): SMALL_BLOCK_BACKENDS = [ x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS ] - _test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS, causal_mask_mod) + _test_backend_correctness( + batch_spec, + model, + SMALL_BLOCK_BACKENDS, + causal_mask_mod, + tensor_parallel_size=tensor_parallel_size, + ) # Fast FlexAttention needs to run with block_size=128 if LARGE_BLOCK_BACKENDS: _test_backend_correctness( - batch_spec, model, LARGE_BLOCK_BACKENDS, causal_mask_mod, block_size=128 + batch_spec, + model, + LARGE_BLOCK_BACKENDS, + causal_mask_mod, + block_size=128, + tensor_parallel_size=tensor_parallel_size, ) @@ -545,7 +585,10 @@ SLIDING_WINDOW_BACKENDS_TO_TEST = [ ["small_decode", "small_prefill", "mixed_medium", "large_decode", "large_prefill"], ) @pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"]) -def test_sliding_window_backend_correctness(batch_spec_name: str, model: str): +@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4]) +def test_sliding_window_backend_correctness( + batch_spec_name: str, model: str, tensor_parallel_size: int +): """Test backend's correctness with sliding window attention.""" def sliding_window_mask_mod( @@ -575,7 +618,11 @@ def test_sliding_window_backend_correctness(batch_spec_name: str, model: str): x for x in SLIDING_WINDOW_BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS ] _test_backend_correctness( - batch_spec, model, SMALL_BLOCK_BACKENDS, sliding_window_mask_mod_fn + batch_spec, + model, + SMALL_BLOCK_BACKENDS, + sliding_window_mask_mod_fn, + tensor_parallel_size=tensor_parallel_size, ) # 
Fast FlexAttention needs to run with block_size=128 @@ -586,4 +633,5 @@ def test_sliding_window_backend_correctness(batch_spec_name: str, model: str): LARGE_BLOCK_BACKENDS, sliding_window_mask_mod_fn, block_size=128, + tensor_parallel_size=tensor_parallel_size, ) diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index cda4fb11c096e..5679fafe63ee8 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -394,8 +394,11 @@ def run_attention_backend( "spec_decode_medium", ], ) -@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-V2-Lite-Chat"]) -def test_backend_correctness(dist_init, batch_spec_name: str, model: str): +@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-R1"]) +@pytest.mark.parametrize("tensor_parallel_size", [1, 4, 8, 16]) +def test_backend_correctness( + dist_init, batch_spec_name: str, model: str, tensor_parallel_size: int +): """ Test that all backends produce similar outputs to a reference implementation using torch.nn.functional.scaled_dot_product_attention. @@ -410,6 +413,11 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): 4. Running each vLLM attention backend with the new queries and the simulated paged KV cache. 5. Comparing the vLLM backend's output to the ground-truth SDPA output. + + Note: When tensor_parallel_size > 1, we simulate the head partitioning + by overriding the model config to use fewer heads, without requiring + multiple GPUs. This tests that backends work correctly with different + head counts. 
""" batch_spec = BATCH_SPECS[batch_spec_name] @@ -423,11 +431,30 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): # Add 1 for null block at index 0, and some buffer num_gpu_blocks = required_blocks + 1 + 100 + hf_config_override = None + if tensor_parallel_size > 1: + from vllm.config import ModelConfig + + temp_config = ModelConfig(model=model, max_model_len=1) + original_num_heads = temp_config.hf_text_config.num_attention_heads + original_num_kv_heads = getattr( + temp_config.hf_text_config, "num_key_value_heads", None + ) + hf_config_override = { + "num_attention_heads": original_num_heads // tensor_parallel_size, + } + if original_num_kv_heads is not None: + hf_config_override["num_key_value_heads"] = max( + 1, original_num_kv_heads // tensor_parallel_size + ) + vllm_config = create_vllm_config( model_name=model, + tensor_parallel_size=1, # Always use TP=1 to avoid multi-GPU requirements max_model_len=max(batch_spec.seq_lens), num_gpu_blocks=num_gpu_blocks, block_size=default_block_size, + hf_config_override=hf_config_override, ) # For spec decode tests, add a speculative_config to set the reorder_batch_threshold diff --git a/tests/v1/attention/test_sparse_mla_backends.py b/tests/v1/attention/test_sparse_mla_backends.py index 02324d2aca6ef..b34d587eb362d 100644 --- a/tests/v1/attention/test_sparse_mla_backends.py +++ b/tests/v1/attention/test_sparse_mla_backends.py @@ -113,7 +113,10 @@ def _quantize_dequantize_fp8_ds_mla( @pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys())) @pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"]) -def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype): +@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4]) +def test_sparse_backend_decode_correctness( + dist_init, batch_name, kv_cache_dtype, tensor_parallel_size +): if not torch.cuda.is_available(): pytest.skip("CUDA is required for sparse MLA decode test") @@ -135,8 +138,11 @@ 
def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype total_cache_tokens = sum(batch_spec.seq_lens) block_size = 64 + # Note: We use TP=1 to avoid multi-GPU requirements in CI. + # The test simulates head partitioning via mocked methods below. vllm_config = create_vllm_config( model_name="deepseek-ai/DeepSeek-V2-Lite-Chat", + tensor_parallel_size=1, max_model_len=max_seqlen, num_gpu_blocks=max(2048, cdiv(total_cache_tokens, block_size) + 1), block_size=block_size, @@ -156,7 +162,8 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype ) model_config.dtype = dtype model_config.get_num_attention_heads = MethodType( - lambda self, parallel_config: num_heads, model_config + lambda self, parallel_config: max(1, num_heads // tensor_parallel_size), + model_config, ) model_config.get_num_kv_heads = MethodType( lambda self, parallel_config: 1, model_config From ccd3e55e51d44bf3a17b2203a304c9609aa5dfe2 Mon Sep 17 00:00:00 2001 From: Hank_ <37239608+ILikeIneine@users.noreply.github.com> Date: Tue, 4 Nov 2025 05:27:03 +0800 Subject: [PATCH 033/231] [Bugfix][plugin] fla crash on plugin (#27322) --- vllm/model_executor/layers/fla/ops/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fla/ops/utils.py b/vllm/model_executor/layers/fla/ops/utils.py index 3a503981a8734..5a48e56a5fbbf 100644 --- a/vllm/model_executor/layers/fla/ops/utils.py +++ b/vllm/model_executor/layers/fla/ops/utils.py @@ -17,6 +17,7 @@ from typing import Any, Literal import torch +from vllm.platforms import current_platform from vllm.triton_utils import triton logger = logging.getLogger(__name__) @@ -137,8 +138,8 @@ def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]: # For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'. # However, the torch backend is 'cuda' for both Nvidia and AMD GPUs. 
# Therefore, we need to check the triton backend to determine the actual GPU vendor. -device = get_available_device() if get_available_device() != "hip" else "cuda" -device_torch_lib = getattr(torch, device) +device = "cuda" if current_platform.is_cuda_alike() else get_available_device() +device_torch_lib = getattr(torch, device, None) device_platform = _check_platform() is_amd = device_platform == "amd" From 3758757377b713b6acc997d0ac2c5dd49c332278 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 3 Nov 2025 17:26:49 -0500 Subject: [PATCH 034/231] [Bugfix] Fix MoE Routing Simulation (#28002) Signed-off-by: Tyler Michael Smith --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- .../layers/fused_moe/routing_simulator.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 46d351b48c5e8..55aa2593193ab 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -2066,7 +2066,7 @@ class FusedMoE(CustomOp): ) # DeepSeekv2 uses grouped_top_k - if use_grouped_topk: + elif use_grouped_topk: assert topk_group is not None assert num_expert_group is not None if is_rocm_aiter_moe_enabled(): diff --git a/vllm/model_executor/layers/fused_moe/routing_simulator.py b/vllm/model_executor/layers/fused_moe/routing_simulator.py index 8b04cf4539e04..a01cdc4908b93 100644 --- a/vllm/model_executor/layers/fused_moe/routing_simulator.py +++ b/vllm/model_executor/layers/fused_moe/routing_simulator.py @@ -14,6 +14,10 @@ from typing import Any import torch +from vllm.logger import init_logger + +logger = init_logger(__name__) + class RoutingStrategy(ABC): """Base class for token-to-expert routing strategies.""" @@ -290,6 +294,12 @@ class RoutingSimulator: f"Available strategies: " f"{list(RoutingSimulator._routing_strategies.keys())}" ) + logger.warning_once( + "Simulating MoE routing using a %s 
strategy. " + "This should only be used for performance testing. " + "Model outputs will not be valid.", + strategy_name, + ) strategy = RoutingSimulator._routing_strategies[strategy_name] return strategy.route_tokens( From 7956b0c0bca8c2b778e6a0b18953b5a08e136c90 Mon Sep 17 00:00:00 2001 From: QiliangCui Date: Mon, 3 Nov 2025 16:35:54 -0800 Subject: [PATCH 035/231] Remove the tpu docker image nightly build. (#27997) Signed-off-by: Qiliang Cui --- .buildkite/release-pipeline.yaml | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 33b7114666fa2..12f730738b8a5 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -116,24 +116,6 @@ steps: commands: - "bash .buildkite/scripts/annotate-release.sh" - - label: "Build and publish TPU release image" - depends_on: ~ - if: build.env("NIGHTLY") == "1" - agents: - queue: tpu_queue_postmerge - commands: - - "yes | docker system prune -a" - - "git fetch --all" - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --tag vllm/vllm-tpu:nightly --tag vllm/vllm-tpu:$BUILDKITE_COMMIT --progress plain -f docker/Dockerfile.tpu ." 
- - "docker push vllm/vllm-tpu:nightly" - - "docker push vllm/vllm-tpu:$BUILDKITE_COMMIT" - plugins: - - docker-login#v3.0.0: - username: vllmbot - password-env: DOCKERHUB_TOKEN - env: - DOCKER_BUILDKIT: "1" - - input: "Provide Release version here" id: input-release-version fields: From b13a44754674a0056d7c8113deb33ea858f6ef1c Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 4 Nov 2025 09:12:19 +0800 Subject: [PATCH 036/231] [Bugfix][ROCm] Fix ViT rotary embeddings for torch.compile compatibility on ROCm (#27748) Signed-off-by: vllmellm --- vllm/model_executor/layers/rotary_embedding/common.py | 11 +++++++---- vllm/model_executor/models/glm4_1v.py | 2 +- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py index 9e6ec9fdd523c..196533b617959 100644 --- a/vllm/model_executor/layers/rotary_embedding/common.py +++ b/vllm/model_executor/layers/rotary_embedding/common.py @@ -77,7 +77,11 @@ def dispatch_rotary_emb_function( if current_platform.is_cuda(): return apply_rotary_emb - if current_platform.is_rocm(): + # if torch compile is not enabled + # use rotary embedding function from flash_attn package + # otherwise use the naive pytorch embedding implementation + # is faster when torch compile is enabled. + if current_platform.is_rocm() and not torch.compiler.is_compiling(): if find_spec("flash_attn") is not None: from flash_attn.ops.triton.rotary import apply_rotary @@ -87,11 +91,10 @@ def dispatch_rotary_emb_function( "flash_attn is not installed. Falling back to PyTorch " "implementation for rotary embeddings." 
) - if default is not None: return default - else: - return apply_rotary_emb_torch + + return apply_rotary_emb_torch # yarn functions diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 3e243385fd049..121e84469c52f 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -370,7 +370,7 @@ class Glm4vVisionAttention(nn.Module): cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, - dropout_p=0, + dropout_p=0.0, causal=False, ) From 6ddae74054d4d9b55b367bfc9db82969f9d81930 Mon Sep 17 00:00:00 2001 From: li2haipeng <44383182+li2haipeng@users.noreply.github.com> Date: Mon, 3 Nov 2025 17:30:20 -0800 Subject: [PATCH 037/231] [LoRA] Lora shrink swizzle (#27694) Signed-off-by: li2haipeng <44383182+li2haipeng@users.noreply.github.com> Signed-off-by: Haipeng Li Co-authored-by: Jee Jee Li --- vllm/lora/ops/triton_ops/lora_shrink_op.py | 15 +++++++++++++-- vllm/lora/ops/triton_ops/utils.py | 1 + 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py index 8d126197f83ea..adc5c9dce5e84 100644 --- a/vllm/lora/ops/triton_ops/lora_shrink_op.py +++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py @@ -41,6 +41,7 @@ def _lora_shrink_kernel( BLOCK_K: tl.constexpr, EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, SLICE_NUM: tl.constexpr, ): cta_n_num = tl.cdiv(N, BLOCK_N) @@ -48,8 +49,16 @@ def _lora_shrink_kernel( pid_sk_m_n = tl.program_id(axis=0) pid_sk = pid_sk_m_n % SPLIT_K - pid_m = (pid_sk_m_n // SPLIT_K) % cta_m_num - pid_n = pid_sk_m_n // (SPLIT_K * cta_m_num) % cta_n_num + + pid_m_n = pid_sk_m_n // SPLIT_K + num_pid_in_group = GROUP_SIZE_M * cta_n_num + group_id = pid_m_n // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(cta_m_num - first_pid_m, GROUP_SIZE_M) + + # Column-major ordering within groups for better cache reuse + 
pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m) + pid_n = (pid_m_n % num_pid_in_group) // group_size_m slice_id = tl.program_id(axis=1) lora_idx = tl.program_id(axis=2) @@ -194,6 +203,7 @@ def _lora_shrink( NUM_WARPS = kernel_config["num_warps"] NUM_STAGES = kernel_config["num_stages"] NUM_CTAS = kernel_config["num_ctas"] + GROUP_SIZE_M = kernel_config.get("group_size_m", 8) EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 # type: ignore # TODO (varun): This grid formulation maximizes parallelization at the @@ -233,6 +243,7 @@ def _lora_shrink( BLOCK_K, EVEN_K, SPLIT_K, + GROUP_SIZE_M, NUM_SLICES, num_warps=NUM_WARPS, num_ctas=NUM_CTAS, diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py index 9ffb6dc3d85e5..368c5037d2e4d 100644 --- a/vllm/lora/ops/triton_ops/utils.py +++ b/vllm/lora/ops/triton_ops/utils.py @@ -199,6 +199,7 @@ def get_lora_op_configs( "split_k": 64 if batch < 128 else 8, "num_warps": 4, "num_ctas": 1, + "group_size_m": 8, "num_stages": 2, "max_nreg": None, } From c02fccdbd2794fe016ebd738e3a8b8c8d78eb78c Mon Sep 17 00:00:00 2001 From: Chauncey Date: Tue, 4 Nov 2025 10:10:10 +0800 Subject: [PATCH 038/231] [Refactor] Lazy import tool_parser (#27974) Signed-off-by: chaunceyjiang --- docs/features/tool_calling.md | 7 +- .../tool_use/test_deepseekv31_tool_parser.py | 4 +- .../tool_use/test_ernie45_moe_tool_parser.py | 2 +- tests/tool_use/test_glm4_moe_tool_parser.py | 4 +- tests/tool_use/test_jamba_tool_parser.py | 2 +- tests/tool_use/test_kimi_k2_tool_parser.py | 2 +- tests/tool_use/test_minimax_tool_parser.py | 2 +- tests/tool_use/test_openai_tool_parser.py | 2 +- tests/tool_use/test_seed_oss_tool_parser.py | 2 +- tests/tool_use/test_xlam_tool_parser.py | 2 +- vllm/entrypoints/openai/api_server.py | 2 +- vllm/entrypoints/openai/cli_args.py | 2 +- .../openai/tool_parsers/__init__.py | 195 +++++++++++++----- .../tool_parsers/abstract_tool_parser.py | 142 +++++++++---- .../tool_parsers/deepseekv31_tool_parser.py | 2 
- .../tool_parsers/deepseekv3_tool_parser.py | 2 - .../tool_parsers/ernie45_tool_parser.py | 2 - .../tool_parsers/glm4_moe_tool_parser.py | 2 - .../granite_20b_fc_tool_parser.py | 2 - .../tool_parsers/granite_tool_parser.py | 2 - .../openai/tool_parsers/hermes_tool_parser.py | 2 - .../tool_parsers/hunyuan_a13b_tool_parser.py | 2 - .../tool_parsers/internlm2_tool_parser.py | 2 - .../openai/tool_parsers/jamba_tool_parser.py | 3 +- .../tool_parsers/kimi_k2_tool_parser.py | 2 - .../llama4_pythonic_tool_parser.py | 2 - .../openai/tool_parsers/llama_tool_parser.py | 3 - .../tool_parsers/longcat_tool_parser.py | 2 - .../tool_parsers/minimax_m2_tool_parser.py | 2 - .../tool_parsers/minimax_tool_parser.py | 2 - .../tool_parsers/mistral_tool_parser.py | 2 - .../openai/tool_parsers/olmo3_tool_parser.py | 2 - .../openai/tool_parsers/openai_tool_parser.py | 2 - .../tool_parsers/phi4mini_tool_parser.py | 2 - .../tool_parsers/pythonic_tool_parser.py | 2 - .../tool_parsers/qwen3coder_tool_parser.py | 2 - .../tool_parsers/qwen3xml_tool_parser.py | 2 - .../tool_parsers/seed_oss_tool_parser.py | 2 - .../openai/tool_parsers/step3_tool_parser.py | 2 - .../openai/tool_parsers/xlam_tool_parser.py | 2 - 40 files changed, 266 insertions(+), 158 deletions(-) diff --git a/docs/features/tool_calling.md b/docs/features/tool_calling.md index 7a1b30096a56d..7e6c69e717dba 100644 --- a/docs/features/tool_calling.md +++ b/docs/features/tool_calling.md @@ -407,7 +407,6 @@ Here is a summary of a plugin file: # the name list in register_module can be used # in --tool-call-parser. you can define as many # tool parsers as you want here. 
- @ToolParserManager.register_module(["example"]) class ExampleToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) @@ -439,6 +438,12 @@ Here is a summary of a plugin file: return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=text) + # register the tool parser to ToolParserManager + ToolParserManager.register_lazy_module( + name="example", + module_path="vllm.entrypoints.openai.tool_parsers.example", + class_name="ExampleToolParser", + ) ``` diff --git a/tests/tool_use/test_deepseekv31_tool_parser.py b/tests/tool_use/test_deepseekv31_tool_parser.py index 9b7e71b49c05b..db5168071fbce 100644 --- a/tests/tool_use/test_deepseekv31_tool_parser.py +++ b/tests/tool_use/test_deepseekv31_tool_parser.py @@ -3,7 +3,9 @@ import pytest -from vllm.entrypoints.openai.tool_parsers import DeepSeekV31ToolParser +from vllm.entrypoints.openai.tool_parsers.deepseekv31_tool_parser import ( + DeepSeekV31ToolParser, +) from vllm.transformers_utils.tokenizer import get_tokenizer MODEL = "deepseek-ai/DeepSeek-V3.1" diff --git a/tests/tool_use/test_ernie45_moe_tool_parser.py b/tests/tool_use/test_ernie45_moe_tool_parser.py index 0862d14812d72..fb5af6e13a96b 100644 --- a/tests/tool_use/test_ernie45_moe_tool_parser.py +++ b/tests/tool_use/test_ernie45_moe_tool_parser.py @@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers import Ernie45ToolParser +from vllm.entrypoints.openai.tool_parsers.ernie45_tool_parser import Ernie45ToolParser from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer diff --git a/tests/tool_use/test_glm4_moe_tool_parser.py b/tests/tool_use/test_glm4_moe_tool_parser.py index 6f1f6671d9b3c..f545f52c02dcb 100644 --- a/tests/tool_use/test_glm4_moe_tool_parser.py +++ b/tests/tool_use/test_glm4_moe_tool_parser.py @@ -7,7 +7,9 @@ import 
json import pytest from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall -from vllm.entrypoints.openai.tool_parsers import Glm4MoeModelToolParser +from vllm.entrypoints.openai.tool_parsers.glm4_moe_tool_parser import ( + Glm4MoeModelToolParser, +) from vllm.transformers_utils.tokenizer import get_tokenizer pytestmark = pytest.mark.cpu_test diff --git a/tests/tool_use/test_jamba_tool_parser.py b/tests/tool_use/test_jamba_tool_parser.py index 6dcdd5ba2ce76..9eb73b80fa9b4 100644 --- a/tests/tool_use/test_jamba_tool_parser.py +++ b/tests/tool_use/test_jamba_tool_parser.py @@ -9,7 +9,7 @@ import pytest from partial_json_parser.core.options import Allow from vllm.entrypoints.openai.protocol import DeltaMessage, FunctionCall, ToolCall -from vllm.entrypoints.openai.tool_parsers import JambaToolParser +from vllm.entrypoints.openai.tool_parsers.jamba_tool_parser import JambaToolParser from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer diff --git a/tests/tool_use/test_kimi_k2_tool_parser.py b/tests/tool_use/test_kimi_k2_tool_parser.py index 43b8c70acbfc3..c358589dbc292 100644 --- a/tests/tool_use/test_kimi_k2_tool_parser.py +++ b/tests/tool_use/test_kimi_k2_tool_parser.py @@ -7,7 +7,7 @@ import json import pytest from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall -from vllm.entrypoints.openai.tool_parsers import KimiK2ToolParser +from vllm.entrypoints.openai.tool_parsers.kimi_k2_tool_parser import KimiK2ToolParser from vllm.transformers_utils.tokenizer import get_tokenizer pytestmark = pytest.mark.cpu_test diff --git a/tests/tool_use/test_minimax_tool_parser.py b/tests/tool_use/test_minimax_tool_parser.py index 8610656fa288d..4332984083dab 100644 --- a/tests/tool_use/test_minimax_tool_parser.py +++ b/tests/tool_use/test_minimax_tool_parser.py @@ -12,7 +12,7 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from 
vllm.entrypoints.openai.tool_parsers import MinimaxToolParser +from vllm.entrypoints.openai.tool_parsers.minimax_tool_parser import MinimaxToolParser from vllm.transformers_utils.tokenizer import get_tokenizer pytestmark = pytest.mark.cpu_test diff --git a/tests/tool_use/test_openai_tool_parser.py b/tests/tool_use/test_openai_tool_parser.py index f6223f3fdce4f..c874a9601ae70 100644 --- a/tests/tool_use/test_openai_tool_parser.py +++ b/tests/tool_use/test_openai_tool_parser.py @@ -15,7 +15,7 @@ from openai_harmony import ( ) from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall -from vllm.entrypoints.openai.tool_parsers import OpenAIToolParser +from vllm.entrypoints.openai.tool_parsers.openai_tool_parser import OpenAIToolParser from vllm.transformers_utils.tokenizer import get_tokenizer MODEL = "gpt2" diff --git a/tests/tool_use/test_seed_oss_tool_parser.py b/tests/tool_use/test_seed_oss_tool_parser.py index 1133b949f2270..1367ad87cb019 100644 --- a/tests/tool_use/test_seed_oss_tool_parser.py +++ b/tests/tool_use/test_seed_oss_tool_parser.py @@ -14,7 +14,7 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers import SeedOssToolParser +from vllm.entrypoints.openai.tool_parsers.seed_oss_tool_parser import SeedOssToolParser from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer diff --git a/tests/tool_use/test_xlam_tool_parser.py b/tests/tool_use/test_xlam_tool_parser.py index 8c27b2911f8f9..122b427d60409 100644 --- a/tests/tool_use/test_xlam_tool_parser.py +++ b/tests/tool_use/test_xlam_tool_parser.py @@ -12,7 +12,7 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers import xLAMToolParser +from vllm.entrypoints.openai.tool_parsers.xlam_tool_parser import xLAMToolParser from vllm.transformers_utils.detokenizer_utils import 
detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index c37aba2776aeb..e184f22f36307 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1943,7 +1943,7 @@ def create_server_unix_socket(path: str) -> socket.socket: def validate_api_server_args(args): - valid_tool_parses = ToolParserManager.tool_parsers.keys() + valid_tool_parses = ToolParserManager.list_registered() if args.enable_auto_tool_choice and args.tool_call_parser not in valid_tool_parses: raise KeyError( f"invalid tool call parser: {args.tool_call_parser} " diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 1a775d3d68094..476587c178237 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -219,7 +219,7 @@ class FrontendArgs: frontend_kwargs["middleware"]["default"] = [] # Special case: Tool call parser shows built-in options. 
- valid_tool_parsers = list(ToolParserManager.tool_parsers.keys()) + valid_tool_parsers = list(ToolParserManager.list_registered()) parsers_str = ",".join(valid_tool_parsers) frontend_kwargs["tool_call_parser"]["metavar"] = ( f"{{{parsers_str}}} or name registered in --tool-parser-plugin" diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 4541ca50822f7..7038d4c1f05cc 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -1,61 +1,142 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from .abstract_tool_parser import ToolParser, ToolParserManager -from .deepseekv3_tool_parser import DeepSeekV3ToolParser -from .deepseekv31_tool_parser import DeepSeekV31ToolParser -from .ernie45_tool_parser import Ernie45ToolParser -from .glm4_moe_tool_parser import Glm4MoeModelToolParser -from .granite_20b_fc_tool_parser import Granite20bFCToolParser -from .granite_tool_parser import GraniteToolParser -from .hermes_tool_parser import Hermes2ProToolParser -from .hunyuan_a13b_tool_parser import HunyuanA13BToolParser -from .internlm2_tool_parser import Internlm2ToolParser -from .jamba_tool_parser import JambaToolParser -from .kimi_k2_tool_parser import KimiK2ToolParser -from .llama4_pythonic_tool_parser import Llama4PythonicToolParser -from .llama_tool_parser import Llama3JsonToolParser -from .longcat_tool_parser import LongcatFlashToolParser -from .minimax_m2_tool_parser import MinimaxM2ToolParser -from .minimax_tool_parser import MinimaxToolParser -from .mistral_tool_parser import MistralToolParser -from .olmo3_tool_parser import Olmo3PythonicToolParser -from .openai_tool_parser import OpenAIToolParser -from .phi4mini_tool_parser import Phi4MiniJsonToolParser -from .pythonic_tool_parser import PythonicToolParser -from .qwen3coder_tool_parser import Qwen3CoderToolParser -from 
.qwen3xml_tool_parser import Qwen3XMLToolParser -from .seed_oss_tool_parser import SeedOssToolParser -from .step3_tool_parser import Step3ToolParser -from .xlam_tool_parser import xLAMToolParser +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, + ToolParserManager, +) -__all__ = [ - "ToolParser", - "ToolParserManager", - "Granite20bFCToolParser", - "GraniteToolParser", - "Hermes2ProToolParser", - "MistralToolParser", - "Internlm2ToolParser", - "Llama3JsonToolParser", - "JambaToolParser", - "Llama4PythonicToolParser", - "LongcatFlashToolParser", - "PythonicToolParser", - "Phi4MiniJsonToolParser", - "DeepSeekV3ToolParser", - "DeepSeekV31ToolParser", - "Ernie45ToolParser", - "xLAMToolParser", - "Olmo3PythonicToolParser", - "MinimaxToolParser", - "KimiK2ToolParser", - "HunyuanA13BToolParser", - "Glm4MoeModelToolParser", - "Qwen3CoderToolParser", - "Qwen3XMLToolParser", - "SeedOssToolParser", - "Step3ToolParser", - "OpenAIToolParser", - "MinimaxM2ToolParser", -] +__all__ = ["ToolParser", "ToolParserManager"] + + +""" +Register a lazy module mapping. 
+ +Example: + ToolParserManager.register_lazy_module( + name="kimi_k2", + module_path="vllm.entrypoints.openai.tool_parsers.kimi_k2_parser", + class_name="KimiK2ToolParser", + ) +""" + + +_TOOL_PARSERS_TO_REGISTER = { + "deepseek_v3": ( # name + "deepseekv3_tool_parser", # filename + "DeepSeekV3ToolParser", # class_name + ), + "deepseek_v31": ( + "deepseekv31_tool_parser", + "DeepSeekV31ToolParser", + ), + "ernie45": ( + "ernie45_tool_parser", + "Ernie45ToolParser", + ), + "glm45": ( + "glm4_moe_tool_parser", + "Glm4MoeModelToolParser", + ), + "granite-20b-fc": ( + "granite_20b_fc_tool_parser", + "Granite20bFCToolParser", + ), + "granite": ( + "granite_tool_parser", + "GraniteToolParser", + ), + "hermes": ( + "hermes_tool_parser", + "Hermes2ProToolParser", + ), + "hunyuan_a13b": ( + "hunyuan_a13b_tool_parser", + "HunyuanA13BToolParser", + ), + "internlm": ( + "internlm2_tool_parser", + "Internlm2ToolParser", + ), + "jamba": ( + "jamba_tool_parser", + "JambaToolParser", + ), + "kimi_k2": ( + "kimi_k2_tool_parser", + "KimiK2ToolParser", + ), + "llama3_json": ( + "llama_tool_parser", + "Llama3JsonToolParser", + ), + "llama4_json": ( + "llama_tool_parser", + "Llama4JsonToolParser", + ), + "llama4_pythonic": ( + "llama4_pythonic_tool_parser", + "Llama4PythonicToolParser", + ), + "longcat": ( + "longcat_tool_parser", + "LongcatFlashToolParser", + ), + "minimax_m2": ( + "minimax_m2_tool_parser", + "MinimaxM2ToolParser", + ), + "minimax": ( + "minimax_tool_parser", + "MinimaxToolParser", + ), + "mistral": ( + "mistral_tool_parser", + "MistralToolParser", + ), + "olmo3": ( + "olmo3_tool_parser", + "Olmo3PythonicToolParser", + ), + "openai": ( + "openai_tool_parser", + "OpenAIToolParser", + ), + "phi4_mini_json": ( + "phi4mini_tool_parser", + "Phi4MiniJsonToolParser", + ), + "pythonic": ( + "pythonic_tool_parser", + "PythonicToolParser", + ), + "qwen3_coder": ( + "qwen3coder_tool_parser", + "Qwen3CoderToolParser", + ), + "qwen3_xml": ( + "qwen3xml_tool_parser", + 
"Qwen3XmlToolParser", + ), + "seed_oss": ( + "seed_oss_tool_parser", + "SeedOsSToolParser", + ), + "step3": ( + "step3_tool_parser", + "Step3ToolParser", + ), + "xlam": ( + "xlam_tool_parser", + "xLAMToolParser", + ), +} + + +def register_lazy_tool_parsers(): + for name, (file_name, class_name) in _TOOL_PARSERS_TO_REGISTER.items(): + module_path = f"vllm.entrypoints.openai.tool_parsers.{file_name}" + ToolParserManager.register_lazy_module(name, module_path, class_name) + + +register_lazy_tool_parsers() diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py index 212326fdafb1e..8d520f5bf8ef6 100644 --- a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import importlib import os from collections.abc import Callable, Sequence from functools import cached_property @@ -99,89 +100,158 @@ class ToolParser: class ToolParserManager: - tool_parsers: dict[str, type] = {} + """ + Central registry for ToolParser implementations. + + Supports two modes: + - Eager (immediate) registration via `register_module` + - Lazy registration via `register_lazy_module` + """ + + tool_parsers: dict[str, type[ToolParser]] = {} + lazy_parsers: dict[str, tuple[str, str]] = {} # name -> (module_path, class_name) @classmethod - def get_tool_parser(cls, name) -> type: + def get_tool_parser(cls, name: str) -> type[ToolParser]: """ - Get tool parser by name which is registered by `register_module`. + Retrieve a registered or lazily registered ToolParser class. - Raise a KeyError exception if the name is not registered. + If the parser is lazily registered, + it will be imported and cached on first access. + Raises KeyError if not found. 
""" if name in cls.tool_parsers: return cls.tool_parsers[name] - raise KeyError(f"tool helper: '{name}' not found in tool_parsers") + if name in cls.lazy_parsers: + return cls._load_lazy_parser(name) + + raise KeyError(f"Tool parser '{name}' not found.") + + @classmethod + def _load_lazy_parser(cls, name: str) -> type[ToolParser]: + """Import and register a lazily loaded parser.""" + module_path, class_name = cls.lazy_parsers[name] + try: + mod = importlib.import_module(module_path) + parser_cls = getattr(mod, class_name) + if not issubclass(parser_cls, ToolParser): + raise TypeError( + f"{class_name} in {module_path} is not a ToolParser subclass." + ) + cls.tool_parsers[name] = parser_cls # cache + return parser_cls + except Exception as e: + logger.exception( + "Failed to import lazy tool parser '%s' from %s: %s", + name, + module_path, + e, + ) + raise @classmethod def _register_module( cls, - module: type, + module: type[ToolParser], module_name: str | list[str] | None = None, force: bool = True, ) -> None: + """Register a ToolParser class immediately.""" if not issubclass(module, ToolParser): raise TypeError( f"module must be subclass of ToolParser, but got {type(module)}" ) + if module_name is None: module_name = module.__name__ + if isinstance(module_name, str): - module_name = [module_name] - for name in module_name: + module_names = [module_name] + elif is_list_of(module_name, str): + module_names = module_name + else: + raise TypeError("module_name must be str, list[str], or None.") + + for name in module_names: if not force and name in cls.tool_parsers: - existed_module = cls.tool_parsers[name] - raise KeyError( - f"{name} is already registered at {existed_module.__module__}" - ) + existed = cls.tool_parsers[name] + raise KeyError(f"{name} is already registered at {existed.__module__}") cls.tool_parsers[name] = module + @classmethod + def register_lazy_module(cls, name: str, module_path: str, class_name: str) -> None: + """ + Register a lazy module 
mapping. + + Example: + ToolParserManager.register_lazy_module( + name="kimi_k2", + module_path="vllm.entrypoints.openai.tool_parsers.kimi_k2_parser", + class_name="KimiK2ToolParser", + ) + """ + cls.lazy_parsers[name] = (module_path, class_name) + @classmethod def register_module( cls, name: str | list[str] | None = None, force: bool = True, - module: type | None = None, - ) -> type | Callable: + module: type[ToolParser] | None = None, + ) -> type[ToolParser] | Callable[[type[ToolParser]], type[ToolParser]]: """ - Register module with the given name or name list. it can be used as a - decoder(with module as None) or normal function(with module as not - None). + Register module immediately or lazily (as a decorator). + + Usage: + @ToolParserManager.register_module("kimi_k2") + class KimiK2ToolParser(ToolParser): + ... + + Or: + ToolParserManager.register_module(module=SomeToolParser) """ if not isinstance(force, bool): raise TypeError(f"force must be a boolean, but got {type(force)}") - # raise the error ahead of time - if not (name is None or isinstance(name, str) or is_list_of(name, str)): - raise TypeError( - "name must be None, an instance of str, or a sequence of str, " - f"but got {type(name)}" - ) - - # use it as a normal method: x.register_module(module=SomeClass) + # Immediate registration if module is not None: cls._register_module(module=module, module_name=name, force=force) return module - # use it as a decorator: @x.register_module() - def _register(module): - cls._register_module(module=module, module_name=name, force=force) - return module + # Decorator usage + def _decorator(obj: type[ToolParser]) -> type[ToolParser]: + module_path = obj.__module__ + class_name = obj.__name__ - return _register + if isinstance(name, str): + names = [name] + elif is_list_of(name, str): + names = name + else: + names = [class_name] + + for n in names: + # Lazy mapping only: do not import now + cls.lazy_parsers[n] = (module_path, class_name) + + return obj + + return 
_decorator + + @classmethod + def list_registered(cls) -> list[str]: + """Return names of all eagerly and lazily registered tool parsers.""" + return sorted(set(cls.tool_parsers.keys()) | set(cls.lazy_parsers.keys())) @classmethod def import_tool_parser(cls, plugin_path: str) -> None: - """ - Import a user-defined tool parser by the path of the tool parser define - file. - """ - module_name = os.path.splitext(os.path.basename(plugin_path))[0] + """Import a user-defined parser file from arbitrary path.""" + module_name = os.path.splitext(os.path.basename(plugin_path))[0] try: import_from_path(module_name, plugin_path) except Exception: logger.exception( "Failed to load module '%s' from %s.", module_name, plugin_path ) - return diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py index 14fd5cf0941c6..cbeb879969ece 100644 --- a/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py @@ -17,7 +17,6 @@ from vllm.entrypoints.openai.protocol import ( ) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, - ToolParserManager, ) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -25,7 +24,6 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) -@ToolParserManager.register_module("deepseek_v31") class DeepSeekV31ToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py index b256560fb4beb..bf7f6fa61ab90 100644 --- a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py @@ -17,7 +17,6 @@ from vllm.entrypoints.openai.protocol import ( ) from 
vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, - ToolParserManager, ) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -25,7 +24,6 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) -@ToolParserManager.register_module("deepseek_v3") class DeepSeekV3ToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) diff --git a/vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py index e4696334eb135..82370323cb00d 100644 --- a/vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py @@ -17,7 +17,6 @@ from vllm.entrypoints.openai.protocol import ( ) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, - ToolParserManager, ) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -25,7 +24,6 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) -@ToolParserManager.register_module("ernie45") class Ernie45ToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): """ diff --git a/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py index 5081b38240ce6..120e63b929b16 100644 --- a/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py @@ -20,7 +20,6 @@ from vllm.entrypoints.openai.protocol import ( ) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, - ToolParserManager, ) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -28,7 +27,6 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) 
-@ToolParserManager.register_module("glm45") class Glm4MoeModelToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) diff --git a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py index c5246685f4071..ae9217426fb51 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py @@ -21,7 +21,6 @@ from vllm.entrypoints.openai.protocol import ( ) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, - ToolParserManager, ) from vllm.entrypoints.openai.tool_parsers.utils import ( consume_space, @@ -35,7 +34,6 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) -@ToolParserManager.register_module("granite-20b-fc") class Granite20bFCToolParser(ToolParser): """ Tool call parser for the granite-20b-functioncalling model intended diff --git a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py index cc1f500342353..d29c427694dc9 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py @@ -19,7 +19,6 @@ from vllm.entrypoints.openai.protocol import ( ) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, - ToolParserManager, ) from vllm.entrypoints.openai.tool_parsers.utils import ( consume_space, @@ -33,7 +32,6 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) -@ToolParserManager.register_module("granite") class GraniteToolParser(ToolParser): """ Tool call parser for the granite 3.0 models. 
Intended diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index 6332de42f424e..4336a5438109f 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -20,7 +20,6 @@ from vllm.entrypoints.openai.protocol import ( ) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, - ToolParserManager, ) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer @@ -28,7 +27,6 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer logger = init_logger(__name__) -@ToolParserManager.register_module("hermes") class Hermes2ProToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) diff --git a/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py index b32e6e39b3e5c..920675c8389b8 100644 --- a/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py @@ -19,7 +19,6 @@ from vllm.entrypoints.openai.protocol import ( ) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, - ToolParserManager, ) from vllm.entrypoints.openai.tool_parsers.utils import consume_space from vllm.logger import init_logger @@ -29,7 +28,6 @@ from vllm.utils import random_uuid logger = init_logger(__name__) -@ToolParserManager.register_module("hunyuan_a13b") class HunyuanA13BToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py index c87bab4353b5b..1dd327f645b3a 100644 --- a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +++ 
b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -19,7 +19,6 @@ from vllm.entrypoints.openai.protocol import ( ) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, - ToolParserManager, ) from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.logger import init_logger @@ -28,7 +27,6 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) -@ToolParserManager.register_module(["internlm"]) class Internlm2ToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) diff --git a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py index 21ee2b762cd0a..6f53ddea4f0ef 100644 --- a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py @@ -18,7 +18,7 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager +from vllm.entrypoints.openai.tool_parsers import ToolParser from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -27,7 +27,6 @@ from vllm.transformers_utils.tokenizers import MistralTokenizer logger = init_logger(__name__) -@ToolParserManager.register_module("jamba") class JambaToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) diff --git a/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py index 3fff3b371dbe3..0453db58361a9 100644 --- a/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py @@ -17,7 +17,6 @@ from vllm.entrypoints.openai.protocol import ( ) from 
vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, - ToolParserManager, ) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -25,7 +24,6 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) -@ToolParserManager.register_module(["kimi_k2"]) class KimiK2ToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) diff --git a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py index dd622b69525de..1d6de9244066e 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py @@ -20,7 +20,6 @@ from vllm.entrypoints.openai.protocol import ( ) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, - ToolParserManager, ) from vllm.logger import init_logger @@ -31,7 +30,6 @@ class _UnexpectedAstError(Exception): pass -@ToolParserManager.register_module("llama4_pythonic") class Llama4PythonicToolParser(ToolParser): """ Toolcall parser for Llama4 that produce tool calls in a pythonic style diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py index 8c7b3cefb200e..02fc9b8a4d34e 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py @@ -21,7 +21,6 @@ from vllm.entrypoints.openai.protocol import ( ) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, - ToolParserManager, ) from vllm.entrypoints.openai.tool_parsers.utils import ( find_common_prefix, @@ -33,8 +32,6 @@ from vllm.logger import init_logger logger = init_logger(__name__) -@ToolParserManager.register_module("llama3_json") -@ToolParserManager.register_module("llama4_json") class 
Llama3JsonToolParser(ToolParser): """ Tool call parser for Llama 3.x and 4 models intended for use with the diff --git a/vllm/entrypoints/openai/tool_parsers/longcat_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/longcat_tool_parser.py index 1dc1a0290c8d9..c6c8ae8ae95f1 100644 --- a/vllm/entrypoints/openai/tool_parsers/longcat_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/longcat_tool_parser.py @@ -3,12 +3,10 @@ import regex as re -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParserManager from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser from vllm.transformers_utils.tokenizer import AnyTokenizer -@ToolParserManager.register_module("longcat") class LongcatFlashToolParser(Hermes2ProToolParser): def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) diff --git a/vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py index d083ece892d50..05f4826028c12 100644 --- a/vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py @@ -19,7 +19,6 @@ from vllm.entrypoints.openai.protocol import ( ) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, - ToolParserManager, ) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -27,7 +26,6 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) -@ToolParserManager.register_module("minimax_m2") class MinimaxM2ToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) diff --git a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py index 4b12bf68b3670..982518a52e3da 100644 --- a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py +++ 
b/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py @@ -19,7 +19,6 @@ from vllm.entrypoints.openai.protocol import ( ) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, - ToolParserManager, ) from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.logger import init_logger @@ -28,7 +27,6 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) -@ToolParserManager.register_module("minimax") class MinimaxToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index dbdf0085367bc..85671271522d3 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -22,7 +22,6 @@ from vllm.entrypoints.openai.protocol import ( ) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, - ToolParserManager, ) from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.logger import init_logger @@ -53,7 +52,6 @@ def _is_fn_name_regex_support(model_tokenizer: AnyTokenizer) -> bool: ) -@ToolParserManager.register_module("mistral") class MistralToolParser(ToolParser): """ Tool call parser for Mistral 7B Instruct v0.3, intended for use with diff --git a/vllm/entrypoints/openai/tool_parsers/olmo3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/olmo3_tool_parser.py index ed5633aac02d4..baff33bd7e8ac 100644 --- a/vllm/entrypoints/openai/tool_parsers/olmo3_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/olmo3_tool_parser.py @@ -20,7 +20,6 @@ from vllm.entrypoints.openai.protocol import ( ) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, - ToolParserManager, ) from vllm.logger import init_logger @@ -31,7 +30,6 @@ class 
_UnexpectedAstError(Exception): pass -@ToolParserManager.register_module("olmo3") class Olmo3PythonicToolParser(ToolParser): """ Tool call parser for Olmo 3 models that produce tool calls as diff --git a/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py index f44876943ac28..d1b36a297e0b1 100644 --- a/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py @@ -14,7 +14,6 @@ from vllm.entrypoints.openai.protocol import ( ) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, - ToolParserManager, ) from vllm.logger import init_logger @@ -26,7 +25,6 @@ else: logger = init_logger(__name__) -@ToolParserManager.register_module("openai") class OpenAIToolParser(ToolParser): def __init__(self, tokenizer: "AnyTokenizer"): super().__init__(tokenizer) diff --git a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py index a8387ba1494df..acb25ea2768e1 100644 --- a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py @@ -18,14 +18,12 @@ from vllm.entrypoints.openai.protocol import ( ) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, - ToolParserManager, ) from vllm.logger import init_logger logger = init_logger(__name__) -@ToolParserManager.register_module("phi4_mini_json") class Phi4MiniJsonToolParser(ToolParser): """ Tool call parser for phi-4-mini models intended for use with the diff --git a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py index 4945e7b5ab20a..abeb923b93227 100644 --- a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py @@ -21,7 +21,6 @@ from 
vllm.entrypoints.openai.protocol import ( ) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, - ToolParserManager, ) from vllm.logger import init_logger @@ -32,7 +31,6 @@ class _UnexpectedAstError(Exception): pass -@ToolParserManager.register_module("pythonic") class PythonicToolParser(ToolParser): """ Tool call parser for models that produce tool calls in a pythonic style, diff --git a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py index ad56972e6387e..26261c0065ead 100644 --- a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py @@ -20,7 +20,6 @@ from vllm.entrypoints.openai.protocol import ( ) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, - ToolParserManager, ) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -28,7 +27,6 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) -@ToolParserManager.register_module("qwen3_coder") class Qwen3CoderToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) diff --git a/vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py index 9964d1ac25c40..cf2fa30d01547 100644 --- a/vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py @@ -21,7 +21,6 @@ from vllm.entrypoints.openai.protocol import ( ) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, - ToolParserManager, ) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -1165,7 +1164,6 @@ class StreamingXMLToolCallParser: self.deferred_param_raw_value = "" -@ToolParserManager.register_module("qwen3_xml") 
class Qwen3XMLToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) diff --git a/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py index f50a2df53bc04..8aed7f0e9fc96 100644 --- a/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py @@ -23,7 +23,6 @@ from vllm.entrypoints.openai.protocol import ( ) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, - ToolParserManager, ) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -31,7 +30,6 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) -@ToolParserManager.register_module("seed_oss") class SeedOssToolParser(ToolParser): TOOL_CALL_START = "" TOOL_CALL_END = "" diff --git a/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py index d0255ec085391..adcb9f4765473 100644 --- a/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py @@ -19,7 +19,6 @@ from vllm.entrypoints.openai.protocol import ( ) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, - ToolParserManager, ) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -28,7 +27,6 @@ from vllm.utils import random_uuid logger = init_logger(__name__) -@ToolParserManager.register_module(["step3"]) class Step3ToolParser(ToolParser): """ Tool parser for a model that uses a specific XML-like format for tool calls. 
diff --git a/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py index c1f0d29cc0873..9d308af4de601 100644 --- a/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py @@ -19,7 +19,6 @@ from vllm.entrypoints.openai.protocol import ( ) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, - ToolParserManager, ) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -28,7 +27,6 @@ from vllm.utils import random_uuid logger = init_logger(__name__) -@ToolParserManager.register_module("xlam") class xLAMToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) From 14a125a06df7275923fe9748f67e27e449412d1f Mon Sep 17 00:00:00 2001 From: liuzhenwei Date: Tue, 4 Nov 2025 11:28:35 +0800 Subject: [PATCH 039/231] [NIXL][XPU] Pin NIXL version to 0.7.0 (#27849) Signed-off-by: zhenwei-intel --- tools/install_nixl_from_source_ubuntu.py | 31 ++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/tools/install_nixl_from_source_ubuntu.py b/tools/install_nixl_from_source_ubuntu.py index 742aab6b0de75..4a20b6b7bb8fb 100644 --- a/tools/install_nixl_from_source_ubuntu.py +++ b/tools/install_nixl_from_source_ubuntu.py @@ -3,9 +3,11 @@ # install_prerequisites.py import argparse import glob +import json import os import subprocess import sys +import urllib.request # --- Configuration --- WHEELS_CACHE_HOME = os.environ.get("WHEELS_CACHE_HOME", "/tmp/wheels_cache") @@ -18,6 +20,20 @@ NIXL_REPO_URL = "https://github.com/ai-dynamo/nixl.git" # --- Helper Functions --- +def get_latest_nixl_version(): + """Helper function to get latest release version of NIXL""" + try: + nixl_release_url = "https://api.github.com/repos/ai-dynamo/nixl/releases/latest" + with urllib.request.urlopen(nixl_release_url) as response: + data = 
json.load(response) + return data.get("tag_name", "0.7.0") + except Exception: + return "0.7.0" + + +NIXL_VERSION = os.environ.get("NIXL_VERSION", get_latest_nixl_version()) + + def run_command(command, cwd=".", env=None): """Helper function to run a shell command and check for errors.""" print(f"--> Running command: {' '.join(command)} in '{cwd}'", flush=True) @@ -37,7 +53,7 @@ def is_pip_package_installed(package_name): def find_nixl_wheel_in_cache(cache_dir): """Finds a nixl wheel file in the specified cache directory.""" # The repaired wheel will have a 'manylinux' tag, but this glob still works. - search_pattern = os.path.join(cache_dir, "nixl*.whl") + search_pattern = os.path.join(cache_dir, f"nixl*{NIXL_VERSION}*.whl") wheels = glob.glob(search_pattern) if wheels: # Sort to get the most recent/highest version if multiple exist @@ -146,6 +162,10 @@ def build_and_install_prerequisites(args): print("\n[2/3] Building NIXL wheel from source...", flush=True) if not os.path.exists(NIXL_DIR): run_command(["git", "clone", NIXL_REPO_URL, NIXL_DIR]) + else: + run_command(["git", "fetch", "--tags"], cwd=NIXL_DIR) + run_command(["git", "checkout", NIXL_VERSION], cwd=NIXL_DIR) + print(f"--> Checked out NIXL version: {NIXL_VERSION}", flush=True) build_env = os.environ.copy() build_env["PKG_CONFIG_PATH"] = os.path.join(ucx_install_path, "lib", "pkgconfig") @@ -203,7 +223,14 @@ def build_and_install_prerequisites(args): {os.path.basename(newly_built_wheel)}. 
Now installing...", flush=True, ) - install_command = [sys.executable, "-m", "pip", "install", newly_built_wheel] + install_command = [ + sys.executable, + "-m", + "pip", + "install", + "--no-deps", # w/o "no-deps", it will install cuda-torch + newly_built_wheel, + ] if args.force_reinstall: install_command.insert(-1, "--force-reinstall") From 380ba6816d4646be99d9b6d207ba7bc7fce8290e Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Tue, 4 Nov 2025 04:35:36 +0000 Subject: [PATCH 040/231] [Metrics] Enable sleep state metric outside of dev mode (#27867) Signed-off-by: Mark McLoughlin --- vllm/v1/metrics/loggers.py | 50 ++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 27 deletions(-) diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 67b6ceaa847f6..e85f85bfb0aab 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -9,7 +9,6 @@ from typing import TypeAlias from prometheus_client import Counter, Gauge, Histogram -import vllm.envs as envs from vllm.config import SupportsMetricsInfo, VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( KVConnectorLogging, @@ -395,32 +394,32 @@ class PrometheusStatLogger(AggregateStatLoggerBase): self.gauge_scheduler_waiting = make_per_engine( gauge_scheduler_waiting, engine_indexes, model_name ) - if envs.VLLM_SERVER_DEV_MODE: - gauge_engine_sleep_state = self._gauge_cls( - name="vllm:engine_sleep_state", - documentation=( - "Engine sleep state; awake = 0 means engine is sleeping; " - "awake = 1 means engine is awake; " - "weights_offloaded = 1 means sleep level 1; " - "discard_all = 1 means sleep level 2." 
- ), - labelnames=labelnames + ["sleep_state"], - multiprocess_mode="mostrecent", - ) - self.gauge_engine_sleep_state = {} - sleep_state = ["awake", "weights_offloaded", "discard_all"] + gauge_engine_sleep_state = self._gauge_cls( + name="vllm:engine_sleep_state", + documentation=( + "Engine sleep state; awake = 0 means engine is sleeping; " + "awake = 1 means engine is awake; " + "weights_offloaded = 1 means sleep level 1; " + "discard_all = 1 means sleep level 2." + ), + labelnames=labelnames + ["sleep_state"], + multiprocess_mode="mostrecent", + ) - for s in sleep_state: - self.gauge_engine_sleep_state[s] = { - idx: gauge_engine_sleep_state.labels( - engine=idx, model_name=model_name, sleep_state=s - ) - for idx in engine_indexes - } + self.gauge_engine_sleep_state = {} + sleep_state = ["awake", "weights_offloaded", "discard_all"] - # Setting default values - self.record_sleep_state() + for s in sleep_state: + self.gauge_engine_sleep_state[s] = { + idx: gauge_engine_sleep_state.labels( + engine=idx, model_name=model_name, sleep_state=s + ) + for idx in engine_indexes + } + + # Setting default values + self.record_sleep_state() # GPU cache # @@ -1052,9 +1051,6 @@ class PrometheusStatLogger(AggregateStatLoggerBase): self.gauge_lora_info.labels(**lora_info_labels).set_to_current_time() def record_sleep_state(self, sleep: int = 0, level: int = 0): - if not envs.VLLM_SERVER_DEV_MODE: - return - awake = 1 discard_all = 0 weights_offloaded = 0 From 7e4be741044bfead91afc418100ff9a4d804bf7f Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Tue, 4 Nov 2025 01:05:55 -0500 Subject: [PATCH 041/231] [Bug] Batch invariant: Fix flash attn MLA `RuntimeError: scheduler_metadata must have shape (metadata_size)` (#27884) --- vllm/model_executor/layers/batch_invariant.py | 2 ++ vllm/v1/attention/backends/mla/flashattn_mla.py | 6 +++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git 
a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py index 39e77b935d3d5..0234f228d700a 100644 --- a/vllm/model_executor/layers/batch_invariant.py +++ b/vllm/model_executor/layers/batch_invariant.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib +import functools import os from collections import namedtuple from collections.abc import Callable @@ -846,6 +847,7 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize: return AttentionBlockSize(block_m=16, block_n=16) +@functools.cache def vllm_is_batch_invariant(): env_key = "VLLM_BATCH_INVARIANT" is_overridden = False diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index a6aac701b784b..6baf45efccb54 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -163,6 +163,9 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata] # we only set num_splits when using cuda graphs. 
max_num_splits = self.max_num_splits + if vllm_is_batch_invariant(): + max_num_splits = 1 + scheduler_metadata = self._schedule_decode( num_reqs=seq_lens_cpu.numel(), cu_query_lens=query_start_loc_device, @@ -188,9 +191,6 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata] self.scheduler_metadata[n:] = 0 scheduler_metadata = self.scheduler_metadata[:n] - if vllm_is_batch_invariant(): - max_num_splits = 1 - metadata = FlashAttnMLADecodeMetadata( block_table=block_table_tensor, seq_lens=seq_lens_device, From f32cbc9a0c456966df300076a3a9f2889151b024 Mon Sep 17 00:00:00 2001 From: xiangze-arm Date: Tue, 4 Nov 2025 14:33:23 +0800 Subject: [PATCH 042/231] [CPU]Improve dynamic 4bit moe performance (#27240) Signed-off-by: Zhang Xiangze --- csrc/moe/dynamic_4bit_int_moe_cpu.cpp | 33 ++++++++++----------------- 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/csrc/moe/dynamic_4bit_int_moe_cpu.cpp b/csrc/moe/dynamic_4bit_int_moe_cpu.cpp index 1d06fc6b5b0a0..df47bb8dd1d7d 100644 --- a/csrc/moe/dynamic_4bit_int_moe_cpu.cpp +++ b/csrc/moe/dynamic_4bit_int_moe_cpu.cpp @@ -87,30 +87,23 @@ torch::Tensor dynamic_4bit_int_moe_cpu( const int64_t g_eff_13 = (group_size != -1) ? group_size : H; const int64_t g_eff_2 = (group_size != -1) ? 
group_size : I; - // Per-expert outputs filled in parallel - std::vector y_list(E); - y_list.resize(E); + auto X_all = x_c.index_select(/*dim=*/0, expert_tokens); + if (apply_router_weight_on_input) { + X_all = X_all.mul(expert_gates.unsqueeze(1)); + } + auto Y_all = at::empty({offsets[E], H}, x_c.options()); at::parallel_for(0, E, 1, [&](int64_t e_begin, int64_t e_end) { + c10::InferenceMode guard; for (int64_t e = e_begin; e < e_end; ++e) { const int64_t te = counts[e]; if (te == 0) { - y_list[e] = at::empty({0, H}, x_c.options()); continue; } const int64_t start = offsets[e]; - auto sel_tokens = - expert_tokens.narrow(/*dim=*/0, /*start=*/start, /*length=*/te); - auto gates_e = - expert_gates.narrow(/*dim=*/0, /*start=*/start, /*length=*/te); - - auto x_e = x_c.index_select(/*dim=*/0, sel_tokens); - - if (apply_router_weight_on_input) { - x_e = x_e.mul(gates_e.unsqueeze(1)); - } + auto x_e = X_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te); auto w13_e = w13_packed.select(/*dim=*/0, e); auto w2_e = w2_packed.select(/*dim=*/0, e); @@ -137,17 +130,15 @@ torch::Tensor dynamic_4bit_int_moe_cpu( // W2 auto y = mm(act, w2_e, g_eff_2, /*in_features=*/I, /*out_features=*/H); - if (!apply_router_weight_on_input) { - y = y.mul(gates_e.unsqueeze(1)); - } - // Store per-expert result - y_list[e] = y; + Y_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te).copy_(y); } }); - // Concatenate all expert outputs to match expert_tokens order - auto Y_all = at::cat(y_list, /*dim=*/0); + if (!apply_router_weight_on_input) { + Y_all = Y_all.mul(expert_gates.unsqueeze(1)); + } + auto out = at::zeros({T, H}, x.options()); out = at::index_add(out, /*dim=*/0, /*index=*/expert_tokens, /*source=*/Y_all); From 2f84ae1f27eb628a195ee9ccd4e884baeb451d1c Mon Sep 17 00:00:00 2001 From: Zhewen Li Date: Mon, 3 Nov 2025 22:36:40 -0800 Subject: [PATCH 043/231] [CI/Build] Update LM Eval Version in AMD CI (#27944) Signed-off-by: zhewenli --- docker/Dockerfile.rocm | 1 - 
requirements/rocm-test.txt | 15 +++++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index adb0879f20d47..06d229f315bdc 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -75,7 +75,6 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace RUN cd /vllm-workspace \ && rm -rf vllm \ && python3 -m pip install -e tests/vllm_test_utils \ - && python3 -m pip install lm-eval[api]==0.4.4 \ && python3 -m pip install pytest-shard # ----------------------- diff --git a/requirements/rocm-test.txt b/requirements/rocm-test.txt index 541fa1e267cb0..432e11977872d 100644 --- a/requirements/rocm-test.txt +++ b/requirements/rocm-test.txt @@ -4,7 +4,7 @@ tblib==3.1.0 bm25s==0.2.13 pystemmer==3.0.0 -# entrypoints test +# Entrypoints test # librosa==0.10.2.post1 # required by audio tests in entrypoints/openai audioread==3.0.1 cffi==1.17.1 @@ -17,11 +17,11 @@ soundfile==0.13.1 soxr==0.5.0.post1 librosa==0.10.2.post1 -# entrypoints test +# Entrypoints test #vllm[video] # required by entrypoints/openai/test_video.py decord==0.6.0 -# entrypoints test +# Entrypoints test #sentence-transformers # required by entrypoints/openai/test_score.py sentence-transformers==3.4.1 @@ -32,7 +32,10 @@ matplotlib==3.10.3 blobfile==3.0.0 # Required for openai schema test. 
-schemathesis==3.39.15 +schemathesis==3.39.15 -# required for mteb test -mteb[bm25s]>=1.38.11, <2 +# Required for mteb test +mteb[bm25s]>=1.38.11, <2 + +# Required for eval tests +lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d From 58279c60b52c7e6e286799a313416949f43aeefe Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Tue, 4 Nov 2025 07:00:49 +0000 Subject: [PATCH 044/231] [KV Connector] Make KVCacheConfig an explicit constructor argument (#27887) Signed-off-by: Mark McLoughlin --- .../unit/test_backwards_compatibility.py | 275 ++++++++++++++++++ tests/v1/kv_connector/unit/utils.py | 2 +- .../kv_transfer/kv_connector/factory.py | 41 ++- .../kv_transfer/kv_connector/v1/base.py | 16 +- .../kv_connector/v1/decode_bench_connector.py | 12 +- .../kv_connector/v1/lmcache_connector.py | 12 +- .../kv_connector/v1/multi_connector.py | 14 +- .../kv_connector/v1/nixl_connector.py | 12 +- .../kv_connector/v1/offloading_connector.py | 10 +- .../kv_connector/v1/p2p/p2p_nccl_connector.py | 16 +- .../v1/shared_storage_connector.py | 16 +- .../kv_transfer/kv_transfer_state.py | 11 +- vllm/v1/core/sched/scheduler.py | 12 +- vllm/v1/worker/gpu_worker.py | 4 +- 14 files changed, 410 insertions(+), 43 deletions(-) create mode 100644 tests/v1/kv_connector/unit/test_backwards_compatibility.py diff --git a/tests/v1/kv_connector/unit/test_backwards_compatibility.py b/tests/v1/kv_connector/unit/test_backwards_compatibility.py new file mode 100644 index 0000000000000..f51001a6ec12a --- /dev/null +++ b/tests/v1/kv_connector/unit/test_backwards_compatibility.py @@ -0,0 +1,275 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Unit tests for backwards compatibility with external KV connector implementations. 
+ +This test ensures that external connectors (loaded via kv_connector_module_path) +implemented with the old signature continue to work: +- Old signature: __init__(self, vllm_config, role) +- New signature: __init__(self, vllm_config, role, kv_cache_config) +""" + +from typing import TYPE_CHECKING +from unittest.mock import patch + +import pytest + +from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory +from vllm.distributed.kv_transfer.kv_connector.v1 import ( + KVConnectorBase_V1, + KVConnectorRole, +) +from vllm.v1.core.sched.output import SchedulerOutput + +from .utils import create_scheduler, create_vllm_config + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.config import VllmConfig + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig + from vllm.v1.request import Request + + +class OldStyleTestConnector(KVConnectorBase_V1): + """ + Test connector using the old signature with 2 required arguments. + This simulates external connectors that haven't been updated yet. 
+ """ + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + # Old-style call to super().__init__ with only 2 arguments + super().__init__(vllm_config=vllm_config, role=role) + + def get_num_new_matched_tokens( + self, request: "Request", num_computed_tokens: int + ) -> tuple[int | None, bool]: + return 0, False + + def update_state_after_alloc( + self, + request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int, + ): + pass + + def build_connector_meta(self, scheduler_output: SchedulerOutput): + return None + + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: + pass + + def wait_for_layer_load(self, layer_name: str) -> None: + pass + + def save_kv_layer( + self, + layer_name: str, + kv_layer, + attn_metadata: "AttentionMetadata", + **kwargs, + ) -> None: + pass + + def wait_for_save(self): + pass + + +class NewStyleTestConnector(KVConnectorBase_V1): + """ + Test connector using the new signature with 3 required arguments. 
+ """ + + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: "KVCacheConfig", + ): + # New-style call to super().__init__ with all 3 arguments + super().__init__( + vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config + ) + + def get_num_new_matched_tokens( + self, request: "Request", num_computed_tokens: int + ) -> tuple[int | None, bool]: + return 0, False + + def update_state_after_alloc( + self, + request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int, + ): + pass + + def build_connector_meta(self, scheduler_output: SchedulerOutput): + return None + + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: + pass + + def wait_for_layer_load(self, layer_name: str) -> None: + pass + + def save_kv_layer( + self, + layer_name: str, + kv_layer, + attn_metadata: "AttentionMetadata", + **kwargs, + ) -> None: + pass + + def wait_for_save(self): + pass + + +@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER]) +def test_external_old_signature_factory_instantiation(role): + """ + Test that external connectors with old signature (2 required args) loaded + via kv_connector_module_path are correctly instantiated with backwards + compatibility support. 
+ """ + vllm_config = create_vllm_config() + vllm_config.kv_transfer_config.kv_connector = "OldStyleTestConnector" + vllm_config.kv_transfer_config.kv_connector_module_path = ( + "tests.v1.kv_connector.unit.test_backwards_compatibility" + ) + + scheduler = create_scheduler(vllm_config) + kv_cache_config = scheduler.kv_cache_config + + connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config) + + assert connector is not None + assert isinstance(connector, OldStyleTestConnector) + assert connector.role == role + assert connector._kv_cache_config is None + + +@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER]) +def test_external_new_signature_factory_instantiation(role): + """ + Test that external connectors with new signature (3 required args) loaded + via kv_connector_module_path are correctly instantiated. + """ + vllm_config = create_vllm_config() + vllm_config.kv_transfer_config.kv_connector = "NewStyleTestConnector" + vllm_config.kv_transfer_config.kv_connector_module_path = ( + "tests.v1.kv_connector.unit.test_backwards_compatibility" + ) + + scheduler = create_scheduler(vllm_config) + kv_cache_config = scheduler.kv_cache_config + + connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config) + + assert connector is not None + assert isinstance(connector, NewStyleTestConnector) + assert connector.role == role + assert connector._kv_cache_config is not None + assert connector._kv_cache_config == kv_cache_config + + +@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER]) +def test_old_signature_super_init(role): + """ + Test that old-style connectors can call super().__init__() without + kv_cache_config parameter. 
+ """ + vllm_config = create_vllm_config() + + connector = OldStyleTestConnector(vllm_config, role) + + assert connector is not None + assert connector.role == role + assert connector._kv_cache_config is None + + +def test_old_signature_super_init_with_kwargs(): + """ + Test that old-style connectors can call super().__init__() with keyword + arguments in different orders. + """ + vllm_config = create_vllm_config() + + # Test with vllm_config= and role= kwargs + connector1 = OldStyleTestConnector( + vllm_config=vllm_config, role=KVConnectorRole.SCHEDULER + ) + assert connector1 is not None + assert connector1._kv_cache_config is None + + # Test with role= and vllm_config= in reversed order + connector2 = OldStyleTestConnector( + role=KVConnectorRole.WORKER, vllm_config=vllm_config + ) + assert connector2 is not None + assert connector2._kv_cache_config is None + + +def test_internal_connector_uses_new_signature(): + """ + Test that internal connectors (registered in factory) always use the new + signature and get kv_cache_config. + """ + from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( + SharedStorageConnector, + ) + + vllm_config = create_vllm_config() + vllm_config.kv_transfer_config.kv_connector = "SharedStorageConnector" + + scheduler = create_scheduler(vllm_config) + kv_cache_config = scheduler.kv_cache_config + + connector = KVConnectorFactory.create_connector( + vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config + ) + + assert connector is not None + assert isinstance(connector, SharedStorageConnector) + assert connector._kv_cache_config is not None + assert connector._kv_cache_config == kv_cache_config + + +def test_signature_detection_with_mocking(): + """ + Test that the factory correctly applies compat_sig flag returned from + _get_connector_class_with_compat. 
+ """ + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + kv_cache_config = scheduler.kv_cache_config + + # Mock _get_connector_class_with_compat to return old-style connector + with patch.object( + KVConnectorFactory, + "_get_connector_class_with_compat", + return_value=(OldStyleTestConnector, True), + ): + old_connector = KVConnectorFactory.create_connector( + vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config + ) + assert old_connector is not None + assert isinstance(old_connector, OldStyleTestConnector) + assert old_connector._kv_cache_config is None + + # Mock _get_connector_class_with_compat to return new-style connector + with patch.object( + KVConnectorFactory, + "_get_connector_class_with_compat", + return_value=(NewStyleTestConnector, False), + ): + new_connector = KVConnectorFactory.create_connector( + vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config + ) + assert new_connector is not None + assert isinstance(new_connector, NewStyleTestConnector) + assert new_connector._kv_cache_config is not None + assert new_connector._kv_cache_config == kv_cache_config diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 46ea46e53084e..c1c0e13f77539 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -254,7 +254,7 @@ def create_model_runner_output( class TestSharedStorageConnector(SharedStorageConnector): - def __init__(self, config: VllmConfig, role): + def __init__(self, config: VllmConfig, role, kv_cache_config): self.name = config.kv_transfer_config.kv_connector_extra_config["name"] self._connector = SharedStorageConnector(config, role) self.call_record: dict[str, int] = defaultdict(int) diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index c64996f13cd5d..8d14200c52407 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ 
b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -3,10 +3,9 @@ import importlib from collections.abc import Callable -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Optional, cast import vllm.envs as envs -from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.base import ( KVConnectorBase, KVConnectorBaseType, @@ -16,9 +15,12 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import ( supports_hma, ) from vllm.logger import init_logger +from vllm.utils.func_utils import supports_kw if TYPE_CHECKING: + from vllm.config import VllmConfig from vllm.config.kv_transfer import KVTransferConfig + from vllm.v1.kv_cache_interface import KVCacheConfig logger = init_logger(__name__) @@ -41,8 +43,9 @@ class KVConnectorFactory: @classmethod def create_connector( cls, - config: VllmConfig, + config: "VllmConfig", role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, ) -> KVConnectorBase: if not envs.VLLM_USE_V1: raise ValueError( @@ -53,7 +56,9 @@ class KVConnectorFactory: kv_transfer_config = config.kv_transfer_config if kv_transfer_config is None: raise ValueError("kv_transfer_config must be set to create a connector") - connector_cls = cls.get_connector_class(kv_transfer_config) + connector_cls, compat_sig = cls._get_connector_class_with_compat( + kv_transfer_config + ) # check if the connector supports HMA hma_enabled = not config.scheduler_config.disable_hybrid_kv_cache_manager @@ -76,7 +81,12 @@ class KVConnectorFactory: # - Co-locate with worker process # - Should only be used inside the forward context & attention layer # We build separately to enforce strict separation - return connector_cls(config, role) + if compat_sig: + # Old signature: __init__(self, vllm_config, role) + return connector_cls(config, role) + else: + # New signature: __init__(self, vllm_config, role, kv_cache_config) + return connector_cls(config, role, kv_cache_config) @classmethod def 
get_connector_class_by_name(
@@ -97,13 +107,13 @@ class KVConnectorFactory:
         return cls._registry[connector_name]()
 
     @classmethod
-    def get_connector_class(
+    def _get_connector_class_with_compat(
         cls, kv_transfer_config: "KVTransferConfig"
-    ) -> type[KVConnectorBaseType]:
-        """Get the connector class by name."""
+    ) -> tuple[type[KVConnectorBaseType], bool]:
         connector_name = kv_transfer_config.kv_connector
         if connector_name is None:
             raise ValueError("Connector name is not set in KVTransferConfig")
+        compat_sig = False
         if connector_name in cls._registry:
             connector_cls = cls._registry[connector_name]()
         else:
@@ -118,6 +128,21 @@ class KVConnectorFactory:
                     f"Class {connector_name} not found in {connector_module_path}"
                 ) from e
             connector_cls = cast(type[KVConnectorBaseType], connector_cls)
+        if not supports_kw(connector_cls, "kv_cache_config"):
+            compat_sig = True
+            logger.warning(
+                "Connector %s uses deprecated signature with 2 required arguments. "
+                "Please update to include kv_cache_config as the third argument.",
+                connector_cls.__name__,
+            )
+        return connector_cls, compat_sig
+
+    @classmethod
+    def get_connector_class(
+        cls, kv_transfer_config: "KVTransferConfig"
+    ) -> type[KVConnectorBaseType]:
+        """Get the connector class by name."""
+        connector_cls, _ = cls._get_connector_class_with_compat(kv_transfer_config)
         return connector_cls
 
 
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
index cb9f208a839f2..354aa9a87183d 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -58,6 +58,7 @@ if TYPE_CHECKING:
     )
     from vllm.forward_context import ForwardContext
     from vllm.v1.core.kv_cache_manager import KVCacheBlocks
+    from vllm.v1.kv_cache_interface import KVCacheConfig
     from vllm.v1.request import Request
 
 # s_tensor_list, d_tensor_list, s_indices, d_indices, direction
@@ -141,7 +142,12 @@ class KVConnectorMetadata(ABC):  # noqa: 
B024 class KVConnectorBase_V1(ABC): - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): logger.warning( "Initializing KVConnectorBase_V1. This API is experimental and " "subject to change in the future as we iterate the design." @@ -152,6 +158,14 @@ class KVConnectorBase_V1(ABC): self._kv_transfer_config = vllm_config.kv_transfer_config else: raise ValueError("kv_transfer_config must be set for KVConnectorBase_V1") + self._kv_cache_config = kv_cache_config + if self._kv_cache_config is None: + logger.warning( + "KVConnectorBase_V1 initialized without kv_cache_config. " + "This is deprecated - please update your connector to accept " + "kv_cache_config as the third constructor argument and pass it " + "to super().__init__()." + ) self._role = role @property diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py index ca251cd0c6ebd..9cd7d93c92fa3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py @@ -32,7 +32,7 @@ Usage: """ from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional import torch @@ -50,6 +50,7 @@ if TYPE_CHECKING: from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request logger = init_logger(__name__) @@ -79,8 +80,13 @@ class DecodeBenchConnector(KVConnectorBase_V1): testing of the decoder with larger input sequence lengths. 
""" - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config, role) + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + super().__init__(vllm_config, role, kv_cache_config) self.connector_scheduler: DecodeBenchConnectorScheduler | None = None self.connector_worker: DecodeBenchConnectorWorker | None = None diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index 7232d947030cb..575ab468be566 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -20,14 +20,22 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request logger = init_logger(__name__) class LMCacheConnectorV1(KVConnectorBase_V1): - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config=vllm_config, role=role) + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: "KVCacheConfig", + ): + super().__init__( + vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config + ) assert vllm_config.kv_transfer_config is not None use_native = vllm_config.kv_transfer_config.get_from_extra_config( "use_native", False diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index d56f30bd11e5b..d7bbf02c83677 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -31,6 +31,7 @@ if TYPE_CHECKING: from 
vllm.distributed.kv_events import KVCacheEvent from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request logger = init_logger(__name__) @@ -109,15 +110,22 @@ class MultiConnector(KVConnectorBase_V1): - Save to all connectors. """ - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config=vllm_config, role=role) + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: "KVCacheConfig", + ): + super().__init__( + vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config + ) self._connectors: list[KVConnectorBase_V1] = [] self._ktc_kv_transfer_config = [] for connector_cls, temp_config in self._get_connector_classes_and_configs( vllm_config ): - self._connectors.append(connector_cls(temp_config, role)) + self._connectors.append(connector_cls(temp_config, role, kv_cache_config)) self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config) # A mapping from request id to the index of the connector chosen to diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 4651cedbc7dfa..ff9770b72bd38 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -13,7 +13,7 @@ from collections import defaultdict from collections.abc import Iterator from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional import msgspec import numpy as np @@ -52,6 +52,7 @@ from vllm.v1.core.sched.output import SchedulerOutput if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from 
vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request Transfer = tuple[int, float] # (xfer_handle, start_time) @@ -150,7 +151,14 @@ class NixlConnectorMetadata(KVConnectorMetadata): class NixlConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + def __init__( + self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + super().__init__(vllm_config, role, kv_cache_config) + assert vllm_config.kv_transfer_config is not None assert vllm_config.kv_transfer_config.engine_id is not None self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py index 7567c7fae5789..582e42cc466ae 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py @@ -21,6 +21,7 @@ from vllm.logger import init_logger from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_utils import BlockHash from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_offload.abstract import OffloadingManager from vllm.v1.kv_offload.factory import OffloadingSpecFactory from vllm.v1.kv_offload.mediums import GPULoadStoreSpec @@ -41,8 +42,13 @@ class OffloadingConnectorMetadata(KVConnectorMetadata): class OffloadingConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): - super().__init__(vllm_config, role) + def __init__( + self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: KVCacheConfig | None = None, + ): + super().__init__(vllm_config, role, kv_cache_config) spec = OffloadingSpecFactory.create_spec(vllm_config) diff --git 
a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py index 780dd12fccda3..a124a0d519db8 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional import regex as re import torch @@ -25,6 +25,7 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request logger = init_logger(__name__) @@ -71,8 +72,17 @@ class P2pNcclConnectorMetadata(KVConnectorMetadata): class P2pNcclConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config=vllm_config, role=role) + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + super().__init__( + vllm_config=vllm_config, + role=role, + kv_cache_config=kv_cache_config, + ) self._block_size = vllm_config.cache_config.block_size self._requests_need_load: dict[str, Any] = {} self.is_producer = self._kv_transfer_config.is_kv_producer diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index 9c230d7d0d2f4..016d1d45b3593 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -3,7 +3,7 @@ import hashlib import os from dataclasses import dataclass, field -from typing 
import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional import safetensors import torch @@ -22,6 +22,7 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request logger = init_logger(__name__) @@ -86,8 +87,17 @@ class SharedStorageConnector(KVConnectorBase_V1): # It does extra work which will overwrite the existing prefix-cache in GPU # - to remove the overhead, need to add some "mask" in the ReqMeta class - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config=vllm_config, role=role) + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + super().__init__( + vllm_config=vllm_config, + role=role, + kv_cache_config=kv_cache_config, + ) self._block_size = vllm_config.cache_config.block_size self._requests_need_load: dict[str, Request] = {} self._storage_path = self._kv_transfer_config.get_from_extra_config( diff --git a/vllm/distributed/kv_transfer/kv_transfer_state.py b/vllm/distributed/kv_transfer/kv_transfer_state.py index cabfc10e7f942..7501f0b373d46 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_state.py +++ b/vllm/distributed/kv_transfer/kv_transfer_state.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from vllm import envs from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType @@ -12,6 +12,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import ( if TYPE_CHECKING: from vllm.config import VllmConfig + from vllm.v1.kv_cache_interface import KVCacheConfig _KV_CONNECTOR_AGENT: KVConnectorBaseType | None = None @@ 
-48,7 +49,9 @@ def is_v1_kv_transfer_group(connector: KVConnectorBaseType | None = None) -> boo return isinstance(connector, KVConnectorBase_V1) -def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: +def ensure_kv_transfer_initialized( + vllm_config: "VllmConfig", kv_cache_config: Optional["KVCacheConfig"] = None +) -> None: """ Initialize KV cache transfer parallel group. """ @@ -64,7 +67,9 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: ): if envs.VLLM_USE_V1: _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector( - config=vllm_config, role=KVConnectorRole.WORKER + config=vllm_config, + role=KVConnectorRole.WORKER, + kv_cache_config=kv_cache_config, ) else: raise ValueError("V0 is no longer supported") diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index f51744eb2640b..aeb9869c52813 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import copy import itertools import time from collections import defaultdict @@ -92,15 +91,10 @@ class Scheduler(SchedulerInterface): assert not self.is_encoder_decoder, ( "Encoder-decoder models are not currently supported with KV connectors" ) - - connector_vllm_config = copy.copy(self.vllm_config) - - # We're dynamically inserting a kv_cache_config variable into the - # connector_vllm_config. This is distinct from the cache_config - # that is already in there. 
- connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config) # type: ignore[attr-defined] self.connector = KVConnectorFactory.create_connector( - config=connector_vllm_config, role=KVConnectorRole.SCHEDULER + config=self.vllm_config, + role=KVConnectorRole.SCHEDULER, + kv_cache_config=self.kv_cache_config, ) if self.log_stats: self.connector_prefix_cache_stats = PrefixCacheStats() diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index c2bf1419bebd7..f3fe202cec062 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -380,9 +380,7 @@ class Worker(WorkerBase): # NOTE(Kuntai): This need to be done before `initialize_kv_cache`, # because `initialize_kv_cache` will inject kv cache groups not # related to kv cache connector (e.g. kv cache sharing layers). - connector_vllm_config = copy.copy(self.vllm_config) - connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config) - ensure_kv_transfer_initialized(connector_vllm_config) + ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config) if self.vllm_config.model_config.enable_sleep_mode: from vllm.device_allocator.cumem import CuMemAllocator From 43a6acfb7de8c7ad839d41bc2109fafe692b77ba Mon Sep 17 00:00:00 2001 From: CSWYF3634076 Date: Tue, 4 Nov 2025 15:16:46 +0800 Subject: [PATCH 045/231] [Model] fix ernie45 reasoning_parser (#27973) Signed-off-by: wangyafeng --- vllm/reasoning/ernie45_reasoning_parser.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/reasoning/ernie45_reasoning_parser.py b/vllm/reasoning/ernie45_reasoning_parser.py index f9d4a30398cfd..8dfbcc0ce46bf 100644 --- a/vllm/reasoning/ernie45_reasoning_parser.py +++ b/vllm/reasoning/ernie45_reasoning_parser.py @@ -36,8 +36,8 @@ class Ernie45ReasoningParser(BaseThinkingReasoningParser): """The token that ends reasoning content.""" return "" - def __init__(self, tokenizer: PreTrainedTokenizerBase): - super().__init__(tokenizer) + def __init__(self, tokenizer: 
PreTrainedTokenizerBase, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) if not self.model_tokenizer: raise ValueError( From 53f6e81dfd9cdba797ddade119a5e33389a35957 Mon Sep 17 00:00:00 2001 From: Zhewen Li Date: Mon, 3 Nov 2025 23:20:50 -0800 Subject: [PATCH 046/231] [CI/Build] Fix OpenAI API correctness on AMD CI (#28022) Signed-off-by: zhewenli --- .buildkite/test-amd.yaml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index c023457fb03e4..5abf6122a5c39 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -629,15 +629,16 @@ steps: - label: OpenAI API correctness # 22min timeout_in_minutes: 30 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking source_file_dependencies: - csrc/ - vllm/entrypoints/openai/ - vllm/model_executor/models/whisper.py - commands: # LMEval+Transcription WER check - - pytest -s entrypoints/openai/correctness/ + commands: # LMEval + # Transcription WER check is skipped because encoder-decoder models are not supported on ROCm, see https://github.com/vllm-project/vllm/issues/27442 + - pytest -s entrypoints/openai/correctness/ --ignore entrypoints/openai/correctness/test_transcription_api_correctness.py - label: OpenAI-Compatible Tool Use # 23 min timeout_in_minutes: 35 From 4022a9d279d09efe1b8a36ff3531bf1d4c8f08ca Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Tue, 4 Nov 2025 02:56:21 -0500 Subject: [PATCH 047/231] [BugFix][Performance] Restore flashinfer autotuning for all scenarios (#27904) --- tests/quantization/test_blackwell_moe.py | 16 ++--------- .../layers/fused_moe/trtllm_moe.py | 11 ++++++-- .../layers/quantization/mxfp4.py | 4 +-- vllm/model_executor/warmup/kernel_warmup.py | 27 +------------------ 4 files changed, 14 insertions(+), 44 deletions(-) diff --git a/tests/quantization/test_blackwell_moe.py 
b/tests/quantization/test_blackwell_moe.py index 3cae6f46147bf..8dd4551ff4b96 100644 --- a/tests/quantization/test_blackwell_moe.py +++ b/tests/quantization/test_blackwell_moe.py @@ -172,21 +172,9 @@ def test_gptoss_mxfp4mxfp8_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch can_initialize("openai/gpt-oss-20b", hf_overrides=HF_OVERRIDE_TEXT) -def test_gptoss_dp2_mxfp4mxfp8_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1") - monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "deepep_high_throughput") +def test_gptoss_eager(monkeypatch: pytest.MonkeyPatch): can_initialize( "openai/gpt-oss-20b", - extra_args=["--data-parallel-size", "2", "--enable-expert-parallel"], - hf_overrides=HF_OVERRIDE_TEXT, - ) - - -def test_gptoss_dp2_mxfp4bf16_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "1") - monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "deepep_high_throughput") - can_initialize( - "openai/gpt-oss-20b", - extra_args=["--data-parallel-size", "2", "--enable-expert-parallel"], hf_overrides=HF_OVERRIDE_TEXT, + extra_args=["--enforce-eager"], ) diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py index e305483eb17db..132d35e65aba8 100644 --- a/vllm/model_executor/layers/fused_moe/trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/trtllm_moe.py @@ -127,10 +127,17 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): "routing_method_type": 1, "do_finalize": True, "output": output, - "tune_max_num_tokens": self.max_capture_size, + "tune_max_num_tokens": max(self.max_capture_size, 1), } from flashinfer import trtllm_fp4_block_scale_routed_moe - trtllm_fp4_block_scale_routed_moe(**kwargs) + from vllm.utils.flashinfer import autotune + + with autotune(False): + # Enable autotune when, + # https://github.com/flashinfer-ai/flashinfer/issues/2023 is + # resolved. 
+ trtllm_fp4_block_scale_routed_moe(**kwargs) + return output diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 597ee1b6bafe1..bf34ec0f38996 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -1047,7 +1047,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): None, 1 if renormalize else 0, # routing_method_type, renormalize True, # do finalize - tune_max_num_tokens=self.max_capture_size, + tune_max_num_tokens=max(self.max_capture_size, 1), )[0] return trtllm_gen_output elif ( @@ -1122,7 +1122,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): tp_rank=self.moe.tp_rank, ep_size=self.moe.ep_size, ep_rank=self.moe.ep_rank, - tune_max_num_tokens=self.max_capture_size, + tune_max_num_tokens=max(self.max_capture_size, 1), **extra_kwargs, ) diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index ffa3bc8f021ef..28792338f036f 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -11,7 +11,6 @@ from typing import TYPE_CHECKING import torch import vllm.envs as envs -from vllm.config import CUDAGraphMode, VllmConfig from vllm.logger import init_logger from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup from vllm.platforms import current_platform @@ -25,26 +24,6 @@ if TYPE_CHECKING: logger = init_logger(__name__) -def flashinfer_autotune_supported(vllm_config: VllmConfig) -> bool: - """ - Record known issues with vllm + flashinfer autotune here. Return True if - and only if flashinfer autotune will run through without issues. 
- """ - is_tp_or_dp = (vllm_config.parallel_config.data_parallel_size > 1) or ( - vllm_config.parallel_config.tensor_parallel_size > 1 - ) - is_fi_mxfp4_backend = ( - envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 - or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 - or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS - ) or ( - current_platform.is_cuda() and current_platform.is_device_capability(100) - ) # on >=sm100, default mxfp4 backend is flashinfer - is_eager = vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE - - return not (is_tp_or_dp and is_fi_mxfp4_backend and is_eager) - - def kernel_warmup(worker: "Worker"): # Deep GEMM warmup do_deep_gemm_warmup = ( @@ -58,11 +37,7 @@ def kernel_warmup(worker: "Worker"): deep_gemm_warmup(model, max_tokens) # FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs - if ( - has_flashinfer() - and current_platform.has_device_capability(90) - and flashinfer_autotune_supported(worker.vllm_config) - ): + if has_flashinfer() and current_platform.has_device_capability(90): flashinfer_autotune(worker.model_runner) # FlashInfer attention warmup From 2ec401bc39daf0c8daa7f7c6bffe4f5e15cb7c79 Mon Sep 17 00:00:00 2001 From: yugong333 Date: Tue, 4 Nov 2025 02:27:35 -0800 Subject: [PATCH 048/231] Load tuned fused_moe_lora shrink and expand kernel configs separately (#27435) Signed-off-by: Yu Gong Co-authored-by: Jee Jee Li --- benchmarks/kernels/benchmark_lora.py | 478 ++++++++++++++++-- tests/lora/test_fused_moe_lora_kernel.py | 11 + vllm/lora/layers/fused_moe.py | 103 +++- vllm/lora/ops/triton_ops/README_TUNING.md | 11 +- vllm/lora/ops/triton_ops/__init__.py | 9 +- vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 356 ++++++++++--- vllm/lora/ops/triton_ops/utils.py | 43 +- vllm/lora/punica_wrapper/punica_base.py | 3 +- vllm/lora/punica_wrapper/punica_gpu.py | 22 +- 9 files changed, 911 insertions(+), 125 deletions(-) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index 
bf1512268fe0b..6715c9b548aa1 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -19,13 +19,24 @@ from torch.utils.benchmark import Measurement as TMeasurement from utils import ArgPool, Bench, CudaGraphBenchParams from weight_shapes import WEIGHT_SHAPES -from vllm.triton_utils import HAS_TRITON +from vllm.lora.ops.triton_ops.utils import get_lora_op_configs +from vllm.triton_utils import HAS_TRITON, triton if HAS_TRITON: - from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink + from vllm.lora.ops.triton_ops import ( ## added fused_moe_lora + LoRAKernelMeta, + fused_moe_lora_expand, + fused_moe_lora_shrink, + lora_expand, + lora_shrink, + ) + from vllm.lora.ops.triton_ops.fused_moe_lora_op import ( + _LORA_PTR_DICT, ## added _LORA_PTR_DICT for fused_moe_lora + ) from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT - +from vllm import _custom_ops as ops from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.math_utils import round_up DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) DEFAULT_TP_SIZES = [1] @@ -59,6 +70,8 @@ DEFAULT_NUM_LORAS = [1, 2, 3, 4] DEFAULT_SORT_BY_LORA_IDS = [False, True] DEFAULT_SEQ_LENGTHS = [1] DEFAULT_EXPAND_FN_ADD_INPUTS = [True, False] +DEFAULT_TOP_K_NUMS = [1] # Added for MoE LoRA top_k +DEFAULT_NUM_EXPERTS = [8] # Added for MoE LoRA num_experts # Utilities @@ -191,6 +204,11 @@ class OpType(Enum): LORA_SHRINK = auto() LORA_EXPAND = auto() + ## Adding support for fused moe lora + FUSED_MOE_LORA_GATE_UP_SHRINK = auto() ## Gate/Up projection variant with shrink + FUSED_MOE_LORA_GATE_UP_EXPAND = auto() ## Gate/Up projection variant with expand + FUSED_MOE_LORA_DOWN_SHRINK = auto() ## Down projection variant with shrink + FUSED_MOE_LORA_DOWN_EXPAND = auto() ## Down projection variant with expand @staticmethod def from_str(s: str) -> "OpType": @@ -198,6 +216,15 @@ class OpType(Enum): return OpType.LORA_SHRINK if s.lower() == 
"lora_expand": return OpType.LORA_EXPAND + # Adding support for fused moe lora, both in gate_up and down + if s.lower() == "fused_moe_lora_gate_up_shrink": ## Gate/Up variant with shrink + return OpType.FUSED_MOE_LORA_GATE_UP_SHRINK + if s.lower() == "fused_moe_lora_gate_up_expand": ## Gate/Up variant with expand + return OpType.FUSED_MOE_LORA_GATE_UP_EXPAND + if s.lower() == "fused_moe_lora_down_shrink": ## Down variant with shrink + return OpType.FUSED_MOE_LORA_DOWN_SHRINK + if s.lower() == "fused_moe_lora_down_expand": ## Down variant with expand + return OpType.FUSED_MOE_LORA_DOWN_EXPAND raise ValueError(f"Unrecognized str {s} to convert to OpType") def is_shrink_fn(self) -> bool: @@ -206,19 +233,56 @@ class OpType(Enum): def is_expand_fn(self) -> bool: return self in [OpType.LORA_EXPAND] + def is_fused_moe_lora_fn(self) -> bool: ## adding for fused MoE LoRA + return self in [ + OpType.FUSED_MOE_LORA_GATE_UP_SHRINK, + OpType.FUSED_MOE_LORA_DOWN_SHRINK, + OpType.FUSED_MOE_LORA_GATE_UP_EXPAND, + OpType.FUSED_MOE_LORA_DOWN_EXPAND, + ] + + def is_fused_moe_lora_gate_up_fn( + self, + ) -> bool: ## adding for fused MoE LoRA Gate/Up + return self in [ + OpType.FUSED_MOE_LORA_GATE_UP_SHRINK, + OpType.FUSED_MOE_LORA_GATE_UP_EXPAND, + ] + + def is_fused_moe_lora_down_fn(self) -> bool: ## adding for fused MoE LoRA Down + return self in [ + OpType.FUSED_MOE_LORA_DOWN_SHRINK, + OpType.FUSED_MOE_LORA_DOWN_EXPAND, + ] + + def is_fused_moe_lora_shrink_fn(self) -> bool: + return self in [ + OpType.FUSED_MOE_LORA_GATE_UP_SHRINK, + OpType.FUSED_MOE_LORA_DOWN_SHRINK, + ] + + def is_fused_moe_lora_expand_fn(self) -> bool: + return self in [ + OpType.FUSED_MOE_LORA_GATE_UP_EXPAND, + OpType.FUSED_MOE_LORA_DOWN_EXPAND, + ] + def num_slices(self) -> list[int]: + if self.is_fused_moe_lora_gate_up_fn(): + return [2] + elif self.is_fused_moe_lora_down_fn(): + return [1] return [1, 2, 3] def mkn( self, batch_size: int, seq_length: int, hidden_size: int, lora_rank: int ) -> tuple[int, int, 
int]: num_tokens = batch_size * seq_length - if self.is_shrink_fn(): + if self.is_shrink_fn() or self.is_fused_moe_lora_fn(): m = num_tokens k = hidden_size n = lora_rank - else: - assert self.is_expand_fn() + elif self.is_expand_fn(): m = num_tokens k = lora_rank n = hidden_size @@ -232,9 +296,36 @@ class OpType(Enum): """ if self.is_shrink_fn(): return op_dtype, op_dtype, torch.float32 - else: - assert self.is_expand_fn() + elif self.is_expand_fn(): return torch.float32, op_dtype, op_dtype + else: + assert self.is_fused_moe_lora_fn() + return op_dtype, op_dtype, op_dtype + + def matmul_shapes_fused_moe_lora( + self, + m: int, + n: int, + k: int, + num_loras: int, + num_slices: int, + top_k_num: int, + num_experts: int, + ) -> tuple[tuple[int], tuple[int], tuple[int], tuple[int]]: + if self.is_fused_moe_lora_shrink_fn(): + input_shape = ( + (m * top_k_num, n) + if self in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] + else (m, n) + ) + output_shape = (num_slices, m, top_k_num, k) + weight_shape = (num_loras, num_experts, k, n) + else: + assert self.is_fused_moe_lora_expand_fn() + input_shape = (num_slices, m, top_k_num, k) + output_shape = (m, top_k_num, n * num_slices) + weight_shape = (num_loras, num_experts, n, k) + return (input_shape, weight_shape, output_shape) def matmul_shapes( self, @@ -244,6 +335,8 @@ class OpType(Enum): lora_rank: int, num_loras: int, num_slices: int, + top_k_num: int | None = None, + num_experts: int | None = None, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: """ Given num_slices, return the shapes of the A, B, and C matrices @@ -258,6 +351,16 @@ class OpType(Enum): if self in [OpType.LORA_EXPAND]: # LoRA expand kernels support num_slices inherently in the kernel return ((num_slices, m, k), b_shape, (m, n * num_slices)) + if self.is_fused_moe_lora_fn(): + return self.matmul_shapes_fused_moe_lora( + m, + k, + n, + num_loras, + num_slices, + top_k_num, + num_experts, + ) raise ValueError(f"Unrecognized op_type {self}") def 
bench_fn(self) -> Callable: @@ -265,6 +368,16 @@ class OpType(Enum): return lora_shrink if self == OpType.LORA_EXPAND: return lora_expand + if self in [ + OpType.FUSED_MOE_LORA_GATE_UP_SHRINK, + OpType.FUSED_MOE_LORA_DOWN_SHRINK, + ]: + return fused_moe_lora_shrink + if self in [ + OpType.FUSED_MOE_LORA_GATE_UP_EXPAND, + OpType.FUSED_MOE_LORA_DOWN_EXPAND, + ]: + return fused_moe_lora_expand raise ValueError(f"Unrecognized optype {self}") @@ -318,6 +431,8 @@ class BenchmarkContext: sort_by_lora_id: bool dtype: torch.dtype seq_length: int | None = None + num_experts: int | None = None # num_experts for MoE based ops + top_k_num: int | None = None # top_k for MoE based ops num_slices: int | None = None # num_slices for slice based ops def with_seq_length(self, seq_length: int) -> "BenchmarkContext": @@ -373,6 +488,11 @@ class BenchmarkTensors: f"{dtype_to_str(self.output.dtype)}" ) + def get_num_tokens(self, size: int, top_k_num: int, op_type: OpType): + return ( + size * top_k_num if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] else size + ) + @staticmethod def make( ctx: BenchmarkContext, op_type: OpType, device: str = "cuda" @@ -385,6 +505,8 @@ class BenchmarkTensors: ctx.lora_rank, ctx.num_loras, ctx.num_slices, + ctx.top_k_num, + ctx.num_experts, ) a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype) input_tensor, lora_weights, output_tensor = make_rand_tensors( @@ -432,17 +554,27 @@ class BenchmarkTensors: prompt_lora_indices_tensor, ) - def sanity_check(self) -> None: + def sanity_check(self, ctx: BenchmarkContext, op_type: OpType) -> None: """ Fails asserts when non-conformality is detected. 
""" - num_tokens = self.input.shape[-2] + num_tokens = ( + self.input.shape[1] + if op_type.is_fused_moe_lora_expand_fn() + else self.input.shape[-2] + ) # check metadata tensors - assert torch.sum(self.seq_lens) == num_tokens + ## In down shrink case, each token is repeated top_k_num times + assert num_tokens == self.get_num_tokens( + torch.sum(self.seq_lens), ctx.top_k_num, op_type + ), f"Expected {num_tokens} tokens, but got {torch.sum(self.seq_lens)}" num_seqs = self.seq_lens.shape[0] # assert self.seq_start_loc.shape[0] == num_seqs + ## In down shrink case, each prompt corresponds to top_k_num sequences assert self.prompt_lora_mapping.shape[0] == num_seqs - assert self.lora_kernel_meta.token_lora_mapping.shape[0] == num_tokens + assert self.get_num_tokens( + self.lora_kernel_meta.token_lora_mapping.shape[0], ctx.top_k_num, op_type + ) def to_device(self, device: str): """ @@ -471,21 +603,111 @@ class BenchmarkTensors: to_device(field) if field_name != "no_lora_flag_cpu" else field, ) - def metadata(self) -> tuple[int, int, int]: + def metadata(self, ctx: BenchmarkContext, op_type: OpType) -> tuple[int, int, int]: """ Return num_seqs, num_tokens and max_seq_len """ num_seqs = self.seq_lens.shape[0] - num_tokens = self.lora_kernel_meta.token_lora_mapping.shape[0] + num_tokens = self.get_num_tokens( + self.lora_kernel_meta.token_lora_mapping.shape[0], ctx.top_k_num, op_type + ) max_seq_len = torch.max(self.seq_lens).item() num_slices = len(self.lora_weights_lst) return num_seqs, num_tokens, max_seq_len, num_slices - def as_lora_shrink_kwargs(self) -> dict[str, Any]: - self.sanity_check() + def fused_moe_lora_data_prepare( + self, + block_size: int, + token_lora_mapping: torch.Tensor, + ctx: BenchmarkContext, + ): + def moe_lora_align_block_size( + topk_ids: torch.Tensor, + token_lora_mapping: torch.Tensor, + block_size: int, + num_experts: int, + max_loras: int, + expert_map: torch.Tensor | None = None, + pad_sorted_ids: bool = False, + ) -> tuple[torch.Tensor, 
torch.Tensor, torch.Tensor]: + """ + Aligns tokens and experts into block-sized chunks for LoRA-based + mixture-of-experts (MoE) execution. + """ + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + if pad_sorted_ids: + max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + sorted_ids = torch.empty( + (max_loras * max_num_tokens_padded,), + dtype=torch.int32, + device=topk_ids.device, + ) + max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) + # Expert ids must be set default to -1 to prevent a blank block + expert_ids = torch.empty( + (max_loras * max_num_m_blocks,), + dtype=torch.int32, + device=topk_ids.device, + ) + num_tokens_post_pad = torch.empty( + (max_loras), dtype=torch.int32, device=topk_ids.device + ) + + ops.moe_lora_align_block_size( + topk_ids, + token_lora_mapping, + num_experts, + block_size, + max_loras, + max_num_tokens_padded, + max_num_m_blocks, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + if expert_map is not None: + expert_ids = expert_map[expert_ids] + + return sorted_ids, expert_ids, num_tokens_post_pad + + num_tokens = ctx.batch_size + curr_topk_ids = torch.randint( + 0, + ctx.num_experts, + (num_tokens, ctx.top_k_num), + device="cuda", + dtype=torch.int32, + ) + topk_weights = torch.randint( + 0, + ctx.num_experts, + (num_tokens, ctx.top_k_num), + device="cuda", + dtype=torch.int32, + ) + + (sorted_token_ids_lora, expert_ids_lora, num_tokens_post_padded_lora) = ( + moe_lora_align_block_size( + topk_ids=curr_topk_ids, + token_lora_mapping=token_lora_mapping, + block_size=block_size, + num_experts=ctx.num_experts, + max_loras=ctx.num_loras, + ) + ) + + sorted_token_ids = sorted_token_ids_lora.view(ctx.num_loras, -1) + expert_ids = expert_ids_lora.view(ctx.num_loras, -1) + num_tokens_post_padded = num_tokens_post_padded_lora + return (topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded) + + def as_lora_shrink_kwargs( + self, ctx: BenchmarkContext, op_type: 
OpType + ) -> dict[str, Any]: + self.sanity_check(ctx, op_type) self.to_device(self.input.device) - _, num_tokens, _, num_slices = self.metadata() + _, num_tokens, _, num_slices = self.metadata(ctx, op_type) # Sanity check matrix shapes. i_shape, lw_shape, o_shape = ( @@ -520,11 +742,13 @@ class BenchmarkTensors: "no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu, } - def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]: - self.sanity_check() + def as_lora_expand_kwargs( + self, ctx: BenchmarkContext, op_type: OpType, add_inputs: bool + ) -> dict[str, Any]: + self.sanity_check(ctx, op_type) self.to_device(self.input.device) - _, num_tokens, _, num_slices = self.metadata() + _, num_tokens, _, num_slices = self.metadata(ctx, op_type) # Sanity check matrix shapes. i_shape, lw_shape, o_shape = ( @@ -561,18 +785,173 @@ class BenchmarkTensors: "no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu, } - def bench_fn_kwargs( - self, op_type: OpType, add_inputs: bool | None = None + def as_fused_moe_lora_shrink_kwargs( + self, ctx: BenchmarkContext, op_type: OpType ) -> dict[str, Any]: - if op_type.is_shrink_fn(): + self.sanity_check(ctx, op_type) + self.to_device(self.input.device) + + _, num_tokens, _, num_slices = self.metadata(ctx, op_type) + + # Sanity check matrix shapes. 
+ i_shape, lw_shape, o_shape = ( + self.input.shape, + self.lora_weights_lst[0].shape, + self.output.shape, + ) + # Expected input shape : [num_tokens, hidden_size] for gate_up + # Expected input shape : [top_k_num * num_tokens, hidden_size] for down + assert len(i_shape) == 2 + assert i_shape[0] == num_tokens + hidden_size = i_shape[1] + # Expected lora weight shape [max_lora, num_experts, lora_rank, hidden_size] + assert len(lw_shape) == 4 + assert lw_shape[-1] == hidden_size + lora_rank = lw_shape[-2] + # Expected output shape : [num_slices, num_tokens, top_k_num, lora_rank] + assert len(o_shape) == 4 + assert ( + o_shape + == (num_slices, num_tokens // ctx.top_k_num, ctx.top_k_num, lora_rank) + if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] + else o_shape == (num_slices, num_tokens, ctx.top_k_num, lora_rank) + ) + kernel_config = get_lora_op_configs( + op_type.name.lower(), + max_loras=lw_shape[0], + batch=num_tokens, + hidden_size=hidden_size, + rank=lora_rank, + num_slices=num_slices, + add_inputs=False, + ) + + (topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded) = ( + self.fused_moe_lora_data_prepare( + block_size=kernel_config["BLOCK_SIZE_M"], + token_lora_mapping=self.lora_kernel_meta.token_lora_mapping, + ctx=ctx, + ) + ) + + return { + "qcurr_hidden_states": self.input, + "lora_a_stacked": self.lora_weights_lst, + "a_intermediate_cache1": self.output, + "topk_weights": topk_weights, + "sorted_token_ids": sorted_token_ids, + "expert_ids": expert_ids, + "num_tokens_post_padded": num_tokens_post_padded, + "top_k_num": ctx.top_k_num, + "device": self.input.device, + "N": lora_rank, + "M": topk_weights.shape[0], + "EM": sorted_token_ids.shape[1], + "K": self.input.shape[1], + "num_tokens": num_tokens, + "num_experts": ctx.num_experts, + "num_slices": num_slices, + "shrink_block_size_m": kernel_config["BLOCK_SIZE_M"], + "shrink_block_size_n": kernel_config["BLOCK_SIZE_N"], + "shrink_block_size_k": kernel_config["BLOCK_SIZE_K"], + 
"shrink_group_size_m": kernel_config["GROUP_SIZE_M"], + "shrink_num_warps": kernel_config["NUM_WARPS"], + "shrink_num_stages": kernel_config["NUM_STAGES"], + "shrink_split_k": kernel_config.get("SPLIT_K", 1), + "mul_routed_weight": op_type.is_fused_moe_lora_down_fn(), + } + + def as_fused_moe_lora_expand_kwargs( + self, ctx: BenchmarkContext, op_type: OpType + ) -> dict[str, Any]: + self.sanity_check(ctx, op_type) + self.to_device(self.input.device) + + _, num_tokens, _, num_slices = self.metadata(ctx, op_type) + + # Sanity check matrix shapes. + i_shape, lw_shape, o_shape = ( + self.input.shape, + self.lora_weights_lst[0].shape, + self.output.shape, + ) + + # Expected input shape : [num_slices, num_tokens, top_k_num, lora_rank] + assert len(i_shape) == 4 + assert i_shape[0] == num_slices + assert i_shape[1] == num_tokens + lora_rank = i_shape[-1] + # Expected lora weight shape : [num_loras, num_experts, hidden_size, lora_rank] + assert len(lw_shape) == 4 + assert lw_shape[-1] == lora_rank + hidden_size = lw_shape[-2] + # Expected output shape : [num_tokens, top_k_num, hidden_size * num_slices] + assert len(o_shape) == 3 + assert o_shape == (num_tokens, ctx.top_k_num, hidden_size * num_slices) + + kernel_config = get_lora_op_configs( + op_type.name.lower(), + max_loras=lw_shape[0], + batch=num_tokens, + hidden_size=hidden_size, + rank=lora_rank, + num_slices=num_slices, + add_inputs=False, + ) + + (topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded) = ( + self.fused_moe_lora_data_prepare( + block_size=kernel_config["BLOCK_SIZE_M"], + token_lora_mapping=self.lora_kernel_meta.token_lora_mapping, + ctx=ctx, + ) + ) + + return { + "a_intermediate_cache1": self.input, + "lora_b_stacked": self.lora_weights_lst, + "output": self.output, + "topk_weights": topk_weights, + "sorted_token_ids": sorted_token_ids, + "expert_ids": expert_ids, + "num_tokens_post_padded": num_tokens_post_padded, + "top_k_num": ctx.top_k_num, + "device": self.input.device, + "N": 
lora_rank, + "M": topk_weights.shape[0], + "EM": sorted_token_ids.shape[1], + "K": self.input.shape[1], + "num_tokens": num_tokens, + "num_experts": ctx.num_experts, + "num_slices": num_slices, + "max_lora_rank": lora_rank, + "w1_output_dim_size": lw_shape[2], + "expand_block_size_m": kernel_config["BLOCK_SIZE_M"], + "expand_block_size_n": kernel_config["BLOCK_SIZE_N"], + "expand_block_size_k": kernel_config["BLOCK_SIZE_K"], + "expand_group_size_m": kernel_config["GROUP_SIZE_M"], + "expand_num_warps": kernel_config["NUM_WARPS"], + "expand_num_stages": kernel_config["NUM_STAGES"], + "expand_split_k": kernel_config.get("SPLIT_K", 1), + "mul_routed_weight": op_type.is_fused_moe_lora_down_fn(), + } + + def bench_fn_kwargs( + self, ctx: BenchmarkContext, op_type: OpType, add_inputs: bool | None = None + ) -> dict[str, Any]: + if op_type.is_shrink_fn() or op_type.is_fused_moe_lora_fn(): assert add_inputs is None else: assert add_inputs is not None if op_type == OpType.LORA_SHRINK: - return self.as_lora_shrink_kwargs() + return self.as_lora_shrink_kwargs(ctx, op_type) if op_type == OpType.LORA_EXPAND: - return self.as_lora_expand_kwargs(add_inputs) + return self.as_lora_expand_kwargs(ctx, op_type, add_inputs) + if op_type.is_fused_moe_lora_shrink_fn(): + return self.as_fused_moe_lora_shrink_kwargs(ctx, op_type) + if op_type.is_fused_moe_lora_expand_fn(): + return self.as_fused_moe_lora_expand_kwargs(ctx, op_type) raise ValueError(f"Unrecognized optype {self}") def test_correctness( @@ -617,7 +996,7 @@ def bench_optype( test_correctness: bool = False, ) -> TMeasurement: assert arg_pool_size >= 1 - if op_type.is_shrink_fn(): + if op_type.is_shrink_fn() or op_type.is_fused_moe_lora_fn(): assert expand_fn_add_inputs is None else: assert expand_fn_add_inputs is not None @@ -627,23 +1006,30 @@ def bench_optype( BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size) ] for bt in bench_tensors: - bt.sanity_check() + bt.sanity_check(ctx, op_type) # Test correctness of 
our implementation. if test_correctness: + assert op_type in [OpType.LORA_SHRINK, OpType.LORA_EXPAND], ( + f"Correctness testing is not supported for {op_type.name}." + ) assert all( - [bt.test_correctness(op_type, expand_fn_add_inputs) for bt in bench_tensors] + [ + bt.test_correctness(ctx, op_type, expand_fn_add_inputs) + for bt in bench_tensors + ] ) # BenchmarkTensors -> dict (kwargs) kwargs_list = [ - bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs) + bt.bench_fn_kwargs(ctx, op_type, add_inputs=expand_fn_add_inputs) for bt in bench_tensors ] # Clear LoRA optimization hash-maps. _LORA_A_PTR_DICT.clear() _LORA_B_PTR_DICT.clear() + _LORA_PTR_DICT.clear() # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up for kwargs in kwargs_list: op_type.bench_fn()(**kwargs) @@ -793,7 +1179,9 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]): # Benchmark bench_op expand_fn_add_inputs = ( - [None] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs + [None] + if bench_op.is_shrink_fn() or bench_op.is_fused_moe_lora_fn() + else args.expand_fn_add_inputs ) for add_input_arg in expand_fn_add_inputs: seq_len_timers.append( @@ -831,12 +1219,22 @@ def as_benchmark_contexts( hidden_sizes: list[int], lora_ranks: list[int], args: argparse.Namespace ) -> list[BenchmarkContext]: ctxs: list[BenchmarkContext] = [] - for batch_size, hidden_size, lora_rank, num_loras, sort_by_lora_id in product( # noqa + for ( + batch_size, + hidden_size, + lora_rank, + num_loras, + sort_by_lora_id, + top_k_num, + num_experts, + ) in product( # noqa args.batch_sizes, list(hidden_sizes), lora_ranks, args.num_loras, args.sort_by_lora_id, + args.top_k_nums, + args.num_experts, ): ctxs.append( BenchmarkContext( @@ -851,6 +1249,8 @@ def as_benchmark_contexts( seq_length=None, sort_by_lora_id=sort_by_lora_id, dtype=args.dtype, + top_k_num=top_k_num, + num_experts=num_experts, # To be filled based on the OpType to benchmark num_slices=None, ) @@ 
-1012,6 +1412,22 @@ if __name__ == "__main__": ), ) + p.add_argument( + "--top-k-nums", + nargs="+", + type=int, + default=DEFAULT_TOP_K_NUMS, + help="Top-K values for MoE LoRA operations", + ) + + p.add_argument( + "--num-experts", + nargs="+", + type=int, + default=DEFAULT_NUM_EXPERTS, + help="Number of experts for MoE LoRA operations", + ) + parser = FlexibleArgumentParser( description=f""" Benchmark LoRA kernels: diff --git a/tests/lora/test_fused_moe_lora_kernel.py b/tests/lora/test_fused_moe_lora_kernel.py index 318a0e58805d3..91ab4a87c65f8 100644 --- a/tests/lora/test_fused_moe_lora_kernel.py +++ b/tests/lora/test_fused_moe_lora_kernel.py @@ -158,6 +158,8 @@ def use_fused_moe_lora_kernel( "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, + "NUM_WARPS": 4, + "NUM_STAGES": 3, "SPLIT_K": 1, } @@ -182,6 +184,15 @@ def use_fused_moe_lora_kernel( config["BLOCK_SIZE_N"], config["BLOCK_SIZE_K"], config["GROUP_SIZE_M"], + config["NUM_WARPS"], + config["NUM_STAGES"], + config["SPLIT_K"], + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_K"], + config["GROUP_SIZE_M"], + config["NUM_WARPS"], + config["NUM_STAGES"], config["SPLIT_K"], mul_routed_weight, ) diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index 7711f5c3208bc..f5a766dd5e45a 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -13,6 +13,7 @@ from vllm.distributed.parallel_state import ( get_tensor_model_parallel_world_size, ) from vllm.lora.layers.base import BaseLayerWithLoRA +from vllm.lora.ops.triton_ops.utils import get_lora_op_configs from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe.config import ( _get_config_dtype_str, @@ -39,6 +40,64 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): self.device = base_layer.w2_weight.device self._inject_lora_into_fused_moe() + def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]: + normalized_config = {} + 
for key, value in config.items(): + if key.islower(): + if key.startswith("block_"): + normalized_key = "BLOCK_SIZE_" + key.split("_")[-1].upper() + else: + normalized_key = key.upper() + else: + normalized_key = key + normalized_config[normalized_key] = value + return normalized_config + + def _get_lora_moe_configs( + self, + op_prefix: str, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + num_slices: int, + M: int, + layer: FusedMoE, + top_k: int, + config_dtype: str, + ): + if envs.VLLM_TUNED_CONFIG_FOLDER: + shrink_config = get_lora_op_configs( + op_type=f"fused_moe_lora_{op_prefix}_shrink", + max_loras=lora_a_stacked.shape[0], + batch=M, + hidden_size=lora_a_stacked.shape[-1], + rank=lora_a_stacked.shape[-2], + num_slices=num_slices, + moe_intermediate_size=lora_b_stacked.shape[-2], + ) + expand_config = get_lora_op_configs( + op_type=f"fused_moe_lora_{op_prefix}_expand", + max_loras=lora_a_stacked.shape[0], + batch=M, + hidden_size=lora_a_stacked.shape[-1], + rank=lora_a_stacked.shape[-2], + num_slices=num_slices, + moe_intermediate_size=lora_b_stacked.shape[-2], + ) + else: # fall back to the default config + get_config_func = functools.partial( + try_get_optimal_moe_config, + layer.w13_weight.size(), + layer.w2_weight.size(), + top_k, + config_dtype, + block_shape=layer.quant_method.moe_quant_config.block_shape, + ) + shrink_config = get_config_func(M) + expand_config = get_config_func(M) + shrink_config = self._normalize_keys(shrink_config) + expand_config = self._normalize_keys(expand_config) + return shrink_config, expand_config + def _inject_lora_into_fused_moe(self): moe_state_dict = {} top_k = self.base_layer.top_k @@ -90,17 +149,19 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): num_tokens = hidden_states.size(0) M = min(num_tokens, CHUNK_SIZE) - get_config_func = functools.partial( - try_get_optimal_moe_config, - layer.w13_weight.size(), - layer.w2_weight.size(), - top_k, - config_dtype, - 
block_shape=layer.quant_method.moe_quant_config.block_shape, + shrink_config, expand_config = self._get_lora_moe_configs( + op_prefix="w13", + lora_a_stacked=self.w1_lora_a_stacked, + lora_b_stacked=self.w1_lora_b_stacked, + num_slices=2, + M=M, + layer=layer, + top_k=top_k, + config_dtype=config_dtype, ) + # get the block size of m from customized config or default config max_loras = self.w1_lora_a_stacked.shape[0] - config = get_config_func(M) ( sorted_token_ids_lora, expert_ids_lora, @@ -108,7 +169,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ) = self.punica_wrapper.moe_lora_align_block_size( curr_topk_ids, num_tokens, - config["BLOCK_SIZE_M"], + shrink_config["BLOCK_SIZE_M"], self.base_layer.local_num_experts, max_loras, self.adapter_enabled, @@ -138,7 +199,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): num_tokens_post_padded_lora, max_lora_rank, top_k, - config, + shrink_config, ## pass the shrink config + expand_config, ## pass the expand config self.adapter_enabled, ) @@ -164,17 +226,17 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): num_tokens = hidden_states.size(0) M = min(num_tokens, CHUNK_SIZE) - get_config_func = functools.partial( - try_get_optimal_moe_config, - layer.w13_weight.size(), - layer.w2_weight.size(), - top_k, - config_dtype, - block_shape=layer.quant_method.moe_quant_config.block_shape, + shrink_config, expand_config = self._get_lora_moe_configs( + op_prefix="w2", + lora_a_stacked=self.w2_lora_a_stacked, + lora_b_stacked=self.w2_lora_b_stacked, + num_slices=1, + M=M, + layer=layer, + top_k=top_k, + config_dtype=config_dtype, ) - config = get_config_func(M) - sorted_token_ids_lora = moe_state_dict["sorted_token_ids_lora"] expert_ids_lora = moe_state_dict["expert_ids_lora"] num_tokens_post_padded_lora = moe_state_dict[ @@ -197,7 +259,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): num_tokens_post_padded_lora, max_lora_rank, top_k, - config, + shrink_config, ## pass the shrink config + expand_config, ## pass the expand config 
self.adapter_enabled, True, ) diff --git a/vllm/lora/ops/triton_ops/README_TUNING.md b/vllm/lora/ops/triton_ops/README_TUNING.md index fda95ea71891f..d576e261557a4 100644 --- a/vllm/lora/ops/triton_ops/README_TUNING.md +++ b/vllm/lora/ops/triton_ops/README_TUNING.md @@ -44,8 +44,17 @@ For `shrink`, the config file is named as `{gpu_name}_SHRINK.json`, e.g. `NVIDIA For `expand`, the config fileis named as `{gpu_name}_EXPAND_{add_input}.json`, e.g. `NVIDIA_H200_EXPAND_TRUE.json`. +For `fused_moe_lora_w13_shrink`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_W13_SHRINK.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_W13_SHRINK.json`. + +For `fused_moe_lora_w13_expand`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_W13_EXPAND.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_W13_EXPAND.json`. + +For `fused_moe_lora_w2_shrink`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_W2_SHRINK.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_W2_SHRINK.json`. + +For `fused_moe_lora_w2_expand`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_W2_EXPAND.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_W2_EXPAND.json`. + The `gpu_name` can be automatically detected by calling `torch.cuda.get_device_name()` ### Json Structure -Optimal kernel configuration files are saved as JSON files with the structure `config_data[max_loras][num_slices][m][k][n]` +Optimal kernel configuration files are saved as JSON files with the structure `config_data[max_loras][num_slices][m][k][n][i]` +where `i` is an optional dimension in the `fused_moe_lora` configuration, representing the intermediate size of the MoE layer. 
diff --git a/vllm/lora/ops/triton_ops/__init__.py b/vllm/lora/ops/triton_ops/__init__.py index 436ea4ed00c82..7e8b9a79add39 100644 --- a/vllm/lora/ops/triton_ops/__init__.py +++ b/vllm/lora/ops/triton_ops/__init__.py @@ -1,7 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.lora.ops.triton_ops.fused_moe_lora_op import fused_moe_lora + +from vllm.lora.ops.triton_ops.fused_moe_lora_op import ( + fused_moe_lora, + fused_moe_lora_expand, + fused_moe_lora_shrink, +) from vllm.lora.ops.triton_ops.lora_expand_op import lora_expand from vllm.lora.ops.triton_ops.lora_kernel_metadata import LoRAKernelMeta from vllm.lora.ops.triton_ops.lora_shrink_op import lora_shrink @@ -11,4 +16,6 @@ __all__ = [ "lora_shrink", "LoRAKernelMeta", "fused_moe_lora", + "fused_moe_lora_shrink", + "fused_moe_lora_expand", ] diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index 539605c7c534a..8f85f926aa4f1 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -176,88 +176,50 @@ def _fused_moe_lora_kernel( @torch.inference_mode() -def _fused_moe_lora( - output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),) +def _fused_moe_lora_shrink( + a_intermediate_cache1: torch.Tensor, + # (num_slices, num_tokens, top_k_num, max_lora_rank) qcurr_hidden_states: torch.Tensor, # (num_tokens, K,) lora_a_stacked: list[ torch.Tensor ], # [(max_loras, num_experts, max_lora_rank, K,),...] - lora_b_stacked: list[ - torch.Tensor - ], # [(max_loras, num_experts, N, max_lora_rank,),...] 
topk_weights: torch.Tensor, # (num_tokens, top_k_num) sorted_token_ids: torch.Tensor, # (max_loras, _) expert_ids: torch.Tensor, # (max_loras, _ ,) num_tokens_post_padded: torch.Tensor, # (max_loras, ) - max_lora_rank: int, top_k_num: int, lora_ids: torch.Tensor, adapter_enabled: torch.Tensor, + ## adding for kernel + device: torch.device, + N: int, + M: int, + EM: int, + K: int, + num_tokens: int, + num_experts: int, + num_slices: int, block_size_m: int, block_size_n: int, block_size_k: int, group_size_m: int, + num_warps: int, + num_stages: int, split_k: int, mul_routed_weight: bool = False, ) -> None: - assert len(lora_a_stacked) == len(lora_b_stacked) > 0 - assert ( - sorted_token_ids.dim() - == expert_ids.dim() - == topk_weights.dim() - == qcurr_hidden_states.dim() - == 2 - ) - assert ( - sorted_token_ids.shape[0] - == expert_ids.shape[0] - == num_tokens_post_padded.shape[0] - ) - assert len(lora_b_stacked) * lora_b_stacked[0].shape[-2] == output.shape[-1] - assert output.shape[0] == topk_weights.shape[0] - assert top_k_num == topk_weights.shape[1] + w1_lora_a_stacked = lora_a_stacked[0] - for lora_a, lora_b in zip(lora_a_stacked, lora_b_stacked): - assert lora_a.dtype == lora_b.dtype == output.dtype == qcurr_hidden_states.dtype - assert lora_a.dtype in [torch.float16, torch.bfloat16] - - device = qcurr_hidden_states.device - num_slices = len(lora_a_stacked) - - config = { + shrink_config = { "BLOCK_SIZE_M": block_size_m, "BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k, "GROUP_SIZE_M": group_size_m, + "num_warps": num_warps, + "num_stages": num_stages, "SPLIT_K": split_k, } - w1_lora_a_stacked = lora_a_stacked[0] - w1_lora_b_stacked = lora_b_stacked[0] - num_experts = lora_a_stacked[0].shape[1] - - N = max_lora_rank - M = topk_weights.shape[0] - EM = sorted_token_ids.shape[1] - K = qcurr_hidden_states.shape[1] - num_tokens = M * top_k_num - w1_output_dim_size = w1_lora_b_stacked.shape[2] - - lora_intermediate_cache1 = torch.zeros( - (num_slices * M 
* top_k_num * (max_lora_rank + w1_output_dim_size)), - dtype=output.dtype, - device=device, - ) - - # slices - a_intermediate_size = num_slices * M * top_k_num * max_lora_rank - a_intermediate_cache1 = lora_intermediate_cache1[:a_intermediate_size].view( - num_slices, M, top_k_num, max_lora_rank - ) - b_intermediate_cache1 = lora_intermediate_cache1[a_intermediate_size:].view( - num_slices, M, top_k_num, w1_output_dim_size - ) - b_ptr = _get_ptr(lora_a_stacked, device) grid = lambda META: ( @@ -299,19 +261,70 @@ def _fused_moe_lora( num_slice_c=num_slices, top_k=1 if mul_routed_weight else top_k_num, MUL_ROUTED_WEIGHT=False, - **config, + **shrink_config, ) + +@torch.inference_mode() +def _fused_moe_lora_expand( + output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),) + a_intermediate_cache1: torch.Tensor, # (num_slices, M, top_k_num, max_lora_rank) + lora_b_stacked: list[ + torch.Tensor + ], # [(max_loras, num_experts, max_lora_rank, K,),...] + topk_weights: torch.Tensor, # (num_tokens, top_k_num) + sorted_token_ids: torch.Tensor, # (max_loras, _) + expert_ids: torch.Tensor, # (max_loras, _ ,) + num_tokens_post_padded: torch.Tensor, # (max_loras, ) + top_k_num: int, + lora_ids: torch.Tensor, + adapter_enabled: torch.Tensor, + ## adding for kernel + device: torch.device, + N: int, + M: int, + EM: int, + K: int, + num_tokens: int, + num_experts: int, + num_slices: int, + max_lora_rank: int, + w1_output_dim_size: int, + block_size_m: int, + block_size_n: int, + block_size_k: int, + group_size_m: int, + num_warps: int, + num_stages: int, + split_k: int, + mul_routed_weight: bool = False, +) -> None: b_ptr = _get_ptr(lora_b_stacked, device) K = max_lora_rank N = w1_output_dim_size + w1_lora_b_stacked = lora_b_stacked[0] + a_intermediate_cache1 = a_intermediate_cache1.view( -1, a_intermediate_cache1.shape[3] ) - # Set split_k = 1 for expand calls - config["SPLIT_K"] = 1 + b_intermediate_cache1 = torch.zeros( + (num_slices, M, top_k_num, 
w1_output_dim_size), + dtype=output.dtype, + device=device, + ) + + expand_config = { + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": group_size_m, + "num_warps": num_warps, + "num_stages": num_stages, + "SPLIT_K": split_k, # Set split_k = 1 for expand calls + } + grid = lambda META: ( triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), len(lora_b_stacked), @@ -348,12 +361,142 @@ def _fused_moe_lora( num_slice_c=num_slices, top_k=1, MUL_ROUTED_WEIGHT=mul_routed_weight, - **config, + **expand_config, ) for i in range(num_slices): output[:, :, i * N : (i + 1) * N] += b_intermediate_cache1[i] +@torch.inference_mode() +def _fused_moe_lora( + output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),) + qcurr_hidden_states: torch.Tensor, # (num_tokens, K,) + lora_a_stacked: list[ + torch.Tensor + ], # [(max_loras, num_experts, max_lora_rank, K,),...] + lora_b_stacked: list[ + torch.Tensor + ], # [(max_loras, num_experts, N, max_lora_rank,),...] 
+ topk_weights: torch.Tensor, # (num_tokens, top_k_num) + sorted_token_ids: torch.Tensor, # (max_loras, _) + expert_ids: torch.Tensor, # (max_loras, _ ,) + num_tokens_post_padded: torch.Tensor, # (max_loras, ) + max_lora_rank: int, + top_k_num: int, + lora_ids: torch.Tensor, + adapter_enabled: torch.Tensor, + shrink_block_size_m: int, + shrink_block_size_n: int, + shrink_block_size_k: int, + shrink_group_size_m: int, + shrink_num_warps: int, + shrink_num_stages: int, + shrink_split_k: int, + expand_block_size_m: int, + expand_block_size_n: int, + expand_block_size_k: int, + expand_group_size_m: int, + expand_num_warps: int, + expand_num_stages: int, + expand_split_k: int, + mul_routed_weight: bool = False, +) -> None: + assert len(lora_a_stacked) == len(lora_b_stacked) > 0 + assert ( + sorted_token_ids.dim() + == expert_ids.dim() + == topk_weights.dim() + == qcurr_hidden_states.dim() + == 2 + ) + assert ( + sorted_token_ids.shape[0] + == expert_ids.shape[0] + == num_tokens_post_padded.shape[0] + ) + assert len(lora_b_stacked) * lora_b_stacked[0].shape[-2] == output.shape[-1] + assert output.shape[0] == topk_weights.shape[0] + assert top_k_num == topk_weights.shape[1] + device = qcurr_hidden_states.device + num_slices = len(lora_a_stacked) + w1_lora_b_stacked = lora_b_stacked[0] + num_experts = lora_a_stacked[0].shape[1] + N = max_lora_rank + M = topk_weights.shape[0] + EM = sorted_token_ids.shape[1] + K = qcurr_hidden_states.shape[1] + num_tokens = M * top_k_num + w1_output_dim_size = w1_lora_b_stacked.shape[2] + + a_intermediate_cache1 = torch.zeros( + (num_slices, M, top_k_num, max_lora_rank), + dtype=output.dtype, + device=device, + ) + + _fused_moe_lora_shrink( + a_intermediate_cache1, + qcurr_hidden_states, + lora_a_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + top_k_num, + lora_ids, + adapter_enabled, + ## adding for kernel + device, + N, + M, + EM, + K, + num_tokens, + num_experts, + num_slices, + shrink_block_size_m, 
+ shrink_block_size_n, + shrink_block_size_k, + shrink_group_size_m, + shrink_num_warps, + shrink_num_stages, + shrink_split_k, + mul_routed_weight, + ) + + _fused_moe_lora_expand( + output, + a_intermediate_cache1, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + top_k_num, + lora_ids, + adapter_enabled, + ## adding for kernel + device, + N, + M, + EM, + K, + num_tokens, + num_experts, + num_slices, + max_lora_rank, + w1_output_dim_size, + expand_block_size_m, + expand_block_size_n, + expand_block_size_k, + expand_group_size_m, + expand_num_warps, + expand_num_stages, + expand_split_k, + mul_routed_weight, + ) + + def _fused_moe_lora_fake( output: torch.Tensor, qcurr_hidden_states: torch.Tensor, @@ -367,10 +510,84 @@ def _fused_moe_lora_fake( top_k_num: int, lora_ids: torch.Tensor, adapter_enabled: torch.Tensor, + shrink_block_size_m: int, + shrink_block_size_n: int, + shrink_block_size_k: int, + shrink_group_size_m: int, + shrink_num_warps: int, + shrink_num_stages: int, + shrink_split_k: int, + expand_block_size_m: int, + expand_block_size_n: int, + expand_block_size_k: int, + expand_group_size_m: int, + expand_num_warps: int, + expand_num_stages: int, + expand_split_k: int, + mul_routed_weight: bool = False, +) -> None: + return + + +def _fused_moe_lora_shrink_fake( + a_intermediate_cache1: torch.Tensor, + qcurr_hidden_states: torch.Tensor, + lora_a_stacked: list[torch.Tensor], + topk_weights: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + top_k_num: int, + lora_ids: torch.Tensor, + adapter_enabled: torch.Tensor, + device: torch.device, + N: int, + M: int, + EM: int, + K: int, + num_tokens: int, + num_experts: int, + num_slices: int, block_size_m: int, block_size_n: int, block_size_k: int, group_size_m: int, + num_warps: int, + num_stages: int, + split_k: int, + mul_routed_weight: bool = False, +) -> None: + return + + +def 
_fused_moe_lora_expand_fake( + output: torch.Tensor, + a_intermediate_cache1: torch.Tensor, + lora_b_stacked: list[torch.Tensor], + topk_weights: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + top_k_num: int, + lora_ids: torch.Tensor, + adapter_enabled: torch.Tensor, + device: torch.device, + N: int, + M: int, + EM: int, + K: int, + num_tokens: int, + num_experts: int, + num_slices: int, + max_lora_rank: int, + w1_output_dim_size: int, + block_size_m: int, + block_size_n: int, + block_size_k: int, + group_size_m: int, + num_warps: int, + num_stages: int, + split_k: int, mul_routed_weight: bool = False, ) -> None: return @@ -383,7 +600,26 @@ try: mutates_args=["output"], fake_impl=_fused_moe_lora_fake, ) + + direct_register_custom_op( + op_name="fused_moe_lora_shrink", + op_func=_fused_moe_lora_shrink, + mutates_args=["a_intermediate_cache1"], + fake_impl=_fused_moe_lora_shrink_fake, + ) + + direct_register_custom_op( + op_name="fused_moe_lora_expand", + op_func=_fused_moe_lora_expand, + mutates_args=["output"], + fake_impl=_fused_moe_lora_expand_fake, + ) + fused_moe_lora = torch.ops.vllm.fused_moe_lora + fused_moe_lora_shrink = torch.ops.vllm.fused_moe_lora_shrink + fused_moe_lora_expand = torch.ops.vllm.fused_moe_lora_expand except AttributeError: fused_moe_lora = _fused_moe_lora + fused_moe_lora_shrink = _fused_moe_lora_shrink + fused_moe_lora_expand = _fused_moe_lora_expand diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py index 368c5037d2e4d..bd413a6db26b8 100644 --- a/vllm/lora/ops/triton_ops/utils.py +++ b/vllm/lora/ops/triton_ops/utils.py @@ -154,13 +154,13 @@ def load_lora_op_config(op_type: str, add_inputs: bool | None) -> dict | None: gpu_name = gpu_name.replace("-", "_") config_fname = None - if op_type == "shrink": - config_fname = f"{gpu_name}_{op_type.upper()}.json" - else: - assert op_type == "expand" + # only expand op needs to consider 
add_inputs + if op_type == "expand": config_fname = ( f"{gpu_name}_{op_type.upper()}_{str(add_inputs).upper()}.json" ) + else: + config_fname = f"{gpu_name}_{op_type.upper()}.json" config_path = Path(f"{user_defined_config_folder}/{config_fname}") if not config_path.exists(): @@ -186,8 +186,17 @@ def get_lora_op_configs( rank: int, num_slices: int, add_inputs: bool | None = None, + moe_intermediate_size: int | None = None, ) -> dict[str, int | None]: - assert op_type in ["shrink", "expand"] + # Add support for fused_moe_lora ops + assert op_type in [ + "shrink", + "expand", + "fused_moe_lora_w13_shrink", + "fused_moe_lora_w13_expand", + "fused_moe_lora_w2_shrink", + "fused_moe_lora_w2_expand", + ] # default config default = {} @@ -203,6 +212,22 @@ def get_lora_op_configs( "num_stages": 2, "max_nreg": None, } + # The default config for fused_moe_lora ops + elif op_type in [ + "fused_moe_lora_w13_shrink", + "fused_moe_lora_w13_expand", + "fused_moe_lora_w2_shrink", + "fused_moe_lora_w2_expand", + ]: + default = { + "block_m": 64, + "block_n": 64, + "block_k": 32, + "num_warps": 4, + "num_stages": 3, + "group_size_m": 8, + "split_k": 1, + } else: default = { "block_m": 64, @@ -247,5 +272,13 @@ def get_lora_op_configs( or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - n))] ) + # slice by moe-intermediate-size if applicable + if moe_intermediate_size is not None: + i = moe_intermediate_size + config_data = ( + config_data.get(str(i)) + or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - i))] + ) + assert config_data is not None return config_data diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index c552412cfd62e..b6186e8561529 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -479,7 +479,8 @@ class PunicaWrapperBase(PunicaWrapperABC): num_tokens_post_padded: torch.Tensor, max_lora_rank: int, top_k_num: int, - config, + shrink_config, + 
expand_config, adapter_enabled: torch.Tensor, mul_routed_weight=False, ): diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 30def90380db1..1bb80e516d3f8 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -367,7 +367,8 @@ class PunicaWrapperGPU(PunicaWrapperBase): num_tokens_post_padded: torch.Tensor, max_lora_rank: int, top_k_num: int, - config, + shrink_config, + expand_config, adapter_enabled: torch.Tensor, mul_routed_weight=False, ): @@ -388,10 +389,19 @@ class PunicaWrapperGPU(PunicaWrapperBase): top_k_num, lora_ids, adapter_enabled, - config["BLOCK_SIZE_M"], - config["BLOCK_SIZE_N"], - config["BLOCK_SIZE_K"], - config["GROUP_SIZE_M"], - config.get("SPLIT_K", 1), + shrink_config.get("BLOCK_SIZE_M", 64), + shrink_config.get("BLOCK_SIZE_N", 64), + shrink_config.get("BLOCK_SIZE_K", 32), + shrink_config.get("GROUP_SIZE_M", 8), + shrink_config.get("NUM_WARPS", 4), + shrink_config.get("NUM_STAGES", 3), + shrink_config.get("SPLIT_K", 1), + expand_config.get("BLOCK_SIZE_M", 64), + expand_config.get("BLOCK_SIZE_N", 64), + expand_config.get("BLOCK_SIZE_K", 32), + expand_config.get("GROUP_SIZE_M", 8), + expand_config.get("NUM_WARPS", 4), + expand_config.get("NUM_STAGES", 3), + expand_config.get("SPLIT_K", 1), mul_routed_weight, ) From 03c4c4aa9deb2ad09a95c7997d2e5578c8db68d6 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 4 Nov 2025 03:00:57 -0800 Subject: [PATCH 049/231] Support using Int4PreshuffledTensor after loading (#26066) Signed-off-by: Jerry Zhang --- tests/quantization/test_torchao.py | 146 +++++++++++++++++- .../layers/quantization/torchao.py | 66 +++++++- 2 files changed, 208 insertions(+), 4 deletions(-) diff --git a/tests/quantization/test_torchao.py b/tests/quantization/test_torchao.py index cab198a2a15e2..82413f36e997f 100644 --- a/tests/quantization/test_torchao.py +++ b/tests/quantization/test_torchao.py @@ -99,7 +99,7 @@ def 
test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner): @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") -def test_on_the_fly_quant_config_dict_json(vllm_runner): +def test_online_quant_config_dict_json(vllm_runner): """Testing on the fly quantization, load_weights integration point, with config dict serialized to json string """ @@ -133,7 +133,7 @@ def test_on_the_fly_quant_config_dict_json(vllm_runner): @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") -def test_on_the_fly_quant_config_file(vllm_runner): +def test_online_quant_config_file(vllm_runner): """Testing on the fly quantization, load_weights integration point, with config file """ @@ -252,6 +252,148 @@ def test_opt_125m_module_fqn_to_config_regex_model(vllm_runner): ) as llm: output = llm.generate_greedy(["The capital of France is"], max_tokens=4) + assert output + + +@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") +@pytest.mark.skip( + reason="since torchao nightly is only compatible with torch nightly" + "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip " + "torchao tests that requires newer versions (0.14.0.dev+) for now" +) +def test_opt_125m_int4wo_model_running_preshuffled_kernel(vllm_runner, monkeypatch): + """We load a model with Int4Tensor (plain format) linear weights + and verify that the weight is updated to Int4PreshuffledTensor + after loading in vllm + """ + from torchao.quantization import Int4PreshuffledTensor + from torchao.utils import _is_fbgemm_gpu_genai_available, is_sm_at_least_90 + + torch._dynamo.reset() + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + model_name = "torchao-testing/opt-125m-Int4WeightOnlyConfig-v2-0.14.0.dev" + # Note: using enforce_eager=True because the `bf16i4bf16_shuffled` doesn't + # have meta kernel implemented yet, can remove this flag after that is implemented + with vllm_runner( + model_name=model_name, + 
quantization="torchao", + dtype="bfloat16", + pt_load_map_location="cuda:0", + enforce_eager=True, + ) as llm: + + def has_int4_preshuffled_tensor_weight(model): + return isinstance( + model.model.decoder.layers[0].self_attn.qkv_proj.weight, + Int4PreshuffledTensor, + ) + + def get_weight_attrs(model): + weight = model.model.decoder.layers[0].self_attn.qkv_proj.weight + return [ + weight.requires_grad, + weight.input_dim, + weight.output_dim, + hasattr(weight, "weight_loader"), + ] + + llm_engine = llm.get_llm().llm_engine + has_int4_preshuffled_tensor = any( + llm_engine.apply_model(has_int4_preshuffled_tensor_weight) + ) + weight_attrs = llm_engine.apply_model(get_weight_attrs)[0] + + # making sure we are using Int4PreshuffledTensor on H100 GPU, when + # fbgemm_gpu_genai + # library is installed, otherwise it should be using Int4Tensor + if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90(): + assert has_int4_preshuffled_tensor + else: + assert not has_int4_preshuffled_tensor + + assert weight_attrs == [False, 1, 0, True] + output = llm.generate_greedy(["The capital of France is"], max_tokens=32) + + assert output + + +@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") +@pytest.mark.skip( + reason="since torchao nightly is only compatible with torch nightly" + "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip " + "torchao tests that requires newer versions (0.14.0.dev+) for now" +) +def test_opt_125m_int4wo_model_running_preshuffled_kernel_online_quant( + vllm_runner, monkeypatch +): + """We load a bf16 model and online quantize the model to int4, then verify that + the weights are updated to Int4PreshuffledTensor after online quantization + """ + from torchao.quantization import Int4PreshuffledTensor + from torchao.utils import _is_fbgemm_gpu_genai_available, is_sm_at_least_90 + + torch._dynamo.reset() + model_name = "facebook/opt-125m" + + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + + 
import json + + from torchao.core.config import config_to_dict + from torchao.quantization import Int4WeightOnlyConfig + + torchao_quant_config = Int4WeightOnlyConfig( + group_size=128, int4_packing_format="plain" + ) + hf_overrides = { + "quantization_config_dict_json": json.dumps( + config_to_dict(torchao_quant_config) + ) + } + + # Note: using enforce_eager=True because the `bf16i4bf16_shuffled` doesn't + # have meta kernel implemented yet, can remove this flag after that is implemented + with vllm_runner( + model_name=model_name, + quantization="torchao", + dtype="bfloat16", + pt_load_map_location="cuda:0", + hf_overrides=hf_overrides, + enforce_eager=True, + ) as llm: + + def has_int4_preshuffled_tensor_weight(model): + return isinstance( + model.model.decoder.layers[0].self_attn.qkv_proj.weight, + Int4PreshuffledTensor, + ) + + def get_weight_attrs(model): + weight = model.model.decoder.layers[0].self_attn.qkv_proj.weight + return [ + weight.requires_grad, + weight.input_dim, + weight.output_dim, + hasattr(weight, "weight_loader"), + ] + + llm_engine = llm.get_llm().llm_engine + has_int4_preshuffled_tensor = any( + llm_engine.apply_model(has_int4_preshuffled_tensor_weight) + ) + weight_attrs = llm_engine.apply_model(get_weight_attrs)[0] + + # making sure we are using Int4PreshuffledTensor on H100 GPU, when + # fbgemm_gpu_genai + # library is installed, otherwise it should be using Int4Tensor + if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90(): + assert has_int4_preshuffled_tensor + else: + assert not has_int4_preshuffled_tensor + + assert weight_attrs == [False, 1, 0, True] + output = llm.generate_greedy(["The capital of France is"], max_tokens=32) + assert output diff --git a/vllm/model_executor/layers/quantization/torchao.py b/vllm/model_executor/layers/quantization/torchao.py index f42c45dae76d2..3fee71e193db5 100644 --- a/vllm/model_executor/layers/quantization/torchao.py +++ b/vllm/model_executor/layers/quantization/torchao.py @@ -2,6 +2,7 @@ # 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project import importlib import json +import types from importlib.util import find_spec from typing import Any, Optional @@ -27,6 +28,39 @@ from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) +def _bond_method_to_cls(func, obj): + if hasattr(func, "__self__") or not callable(func): + # If the function is already bound to an instance, return it as is + return func + else: + return types.MethodType(func, obj) + + +def _get_weight_attrs(param): + # record attributes attached to the weight, so we can + # recover later + recorded_weight_attr = {} + for key in param.__dict__: + if hasattr(param, key): + attr = getattr(param, key) + if not callable(attr): + recorded_weight_attr[key] = attr + elif hasattr(attr, "__self__") and param is attr.__self__: + # if attr is a bonded method for an instance, and + # attr.__self__ points to the instance (param) + # we'll record the underlying function object + recorded_weight_attr[key] = attr.__func__ + else: + recorded_weight_attr[key] = attr + return recorded_weight_attr + + +def _restore_weight_attrs(param, recorded_weight_attr): + for attr_name, attr in recorded_weight_attr.items(): + if not hasattr(param, attr_name): + setattr(param, attr_name, _bond_method_to_cls(attr, param)) + + def torchao_version_at_least(torchao_version: str) -> bool: if find_spec("torchao"): try: @@ -57,6 +91,14 @@ def should_skip(prefix: str, skip_modules: list[str]) -> bool: return False +if torchao_version_at_least("0.15.0"): + from torchao.prototype.tensor_conversion.api import ( + convert_to_packed_tensor_based_on_current_hardware, + ) +else: + convert_to_packed_tensor_based_on_current_hardware = lambda t: t + + class TorchAOConfig(QuantizationConfig): """Config class for torchao.""" @@ -307,12 +349,32 @@ class TorchAOLinearMethod(LinearMethodBase): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if 
self.quant_config.is_checkpoint_torchao_serialized: + if not hasattr(layer, "weight"): + return + + # record attributes attached to the weight, so we can + # recover later + recorded_weight_attr = _get_weight_attrs(layer.weight) + + layer.weight = Parameter( + convert_to_packed_tensor_based_on_current_hardware(layer.weight), + requires_grad=layer.weight.requires_grad, + ) + + _restore_weight_attrs(layer.weight, recorded_weight_attr) return - # quantize the weight on the fly if the checkpoint is not already + # online quantize the weight if the checkpoint is not already # quantized by torchao + recorded_weight_attr = _get_weight_attrs(layer.weight) + weight = torchao_quantize_param_data( layer.weight, self.quant_config.torchao_config ) - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + weight = torch.nn.Parameter( + convert_to_packed_tensor_based_on_current_hardware(weight), + weight.requires_grad, + ) + + _restore_weight_attrs(weight, recorded_weight_attr) layer.register_parameter("weight", weight) From 300a2659785fb925f347637d5639d74cc2c5a9f5 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Tue, 4 Nov 2025 04:13:35 -0800 Subject: [PATCH 050/231] [Core] Enable StatLogger in LLMEngine (#28020) Signed-off-by: Zhuohan Li --- vllm/v1/engine/llm_engine.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index f44b6b2070d9f..995642a8356fc 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -58,11 +58,6 @@ class LLMEngine: use_cached_outputs: bool = False, multiprocess_mode: bool = False, ) -> None: - if stat_loggers is not None: - raise NotImplementedError( - "Passing StatLoggers to LLMEngine is not yet supported." 
- ) - self.vllm_config = vllm_config self.observability_config = vllm_config.observability_config self.model_config = vllm_config.model_config From 77f8001f533021ece46779f5b7e69edc1d3b514f Mon Sep 17 00:00:00 2001 From: tomeras91 <57313761+tomeras91@users.noreply.github.com> Date: Tue, 4 Nov 2025 14:28:36 +0200 Subject: [PATCH 051/231] [Model][Bugfix] fix pipeline parallelism support for NemotronH (#27968) Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- vllm/model_executor/models/nemotron_h.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 457d3910d0e57..324b63c1732fe 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -20,6 +20,7 @@ import typing from collections.abc import Callable, Iterable +from itertools import islice import torch from torch import nn @@ -549,7 +550,7 @@ class NemotronHModel(nn.Module): self.start_layer, self.end_layer, self.layers = make_layers( len(config.hybrid_override_pattern), get_layer, prefix=f"{prefix}.layers" ) - self.make_empty_intmd_tensors = make_empty_intermediate_tensors_factory( + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size ) @@ -564,7 +565,7 @@ class NemotronHModel(nn.Module): positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, - ) -> torch.Tensor: + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -576,8 +577,7 @@ class NemotronHModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - residual = None - for i, layer in enumerate(self.layers): + for layer in islice(self.layers, self.start_layer, self.end_layer): 
hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, @@ -633,6 +633,9 @@ class NemotronHModel(nn.Module): if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -678,6 +681,9 @@ class NemotronHModel(nn.Module): if is_expert_weight: continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] weight_loader = getattr( param, "weight_loader", default_weight_loader @@ -792,7 +798,9 @@ class NemotronHForCausalLM( self.unpadded_vocab_size, config.vocab_size ) - self.make_empty_intmd_tensors = self.model.make_empty_intmd_tensors + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) # Set MoE hyperparameters if self.model.has_moe: From e4ee6586721cd9e09ac50207cb5e754d7a4a773e Mon Sep 17 00:00:00 2001 From: tomeras91 <57313761+tomeras91@users.noreply.github.com> Date: Tue, 4 Nov 2025 14:59:43 +0200 Subject: [PATCH 052/231] [Model] add optimal triton fused moe configs for NemotronH MoE (#27967) Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- benchmarks/kernels/benchmark_moe.py | 1 + ...856,device_name=NVIDIA_H100_80GB_HBM3.json | 147 ++++++++++++++++++ .../E=128,N=1856,device_name=NVIDIA_L40S.json | 147 ++++++++++++++++++ ...928,device_name=NVIDIA_H100_80GB_HBM3.json | 147 ++++++++++++++++++ .../E=128,N=928,device_name=NVIDIA_L40S.json | 147 ++++++++++++++++++ 5 files changed, 589 insertions(+) create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=128,N=1856,device_name=NVIDIA_H100_80GB_HBM3.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=128,N=1856,device_name=NVIDIA_L40S.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=128,N=928,device_name=NVIDIA_H100_80GB_HBM3.json create mode 100644 
vllm/model_executor/layers/fused_moe/configs/E=128,N=928,device_name=NVIDIA_L40S.json diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index bc6cf83bc21fd..33c83574467cc 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -590,6 +590,7 @@ def main(args: argparse.Namespace): "DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM", "Glm4MoeForCausalLM", + "NemotronHForCausalLM", ): E = config.n_routed_experts topk = config.num_experts_per_tok diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=1856,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=1856,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000000..ee8a28b833d5a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=1856,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.5.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + 
"GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=1856,device_name=NVIDIA_L40S.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=1856,device_name=NVIDIA_L40S.json new file mode 100644 index 0000000000000..09d3fa584edd8 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=1856,device_name=NVIDIA_L40S.json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.5.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + 
"num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "2048": { + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=928,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=928,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000000..fc6454ebfb2fe --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=928,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.5.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=928,device_name=NVIDIA_L40S.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=928,device_name=NVIDIA_L40S.json new file mode 100644 index 0000000000000..48997646d99b6 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=928,device_name=NVIDIA_L40S.json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.5.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + 
"GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 
8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} From 938772af03ce01590c7e92b0d3fd0a5bdc899d19 Mon Sep 17 00:00:00 2001 From: bnellnm <49004751+bnellnm@users.noreply.github.com> Date: Tue, 4 Nov 2025 08:59:45 -0500 Subject: [PATCH 053/231] [Kernels] Isolate modular kernel code from FusedMoEMethodBase subclasses. (#27123) --- .../base_device_communicator.py | 4 +- vllm/model_executor/layers/fused_moe/layer.py | 261 +++++++++++++----- .../layers/fused_moe/modular_kernel.py | 6 + .../layers/quantization/awq_marlin.py | 2 - .../layers/quantization/bitsandbytes.py | 3 +- .../compressed_tensors_moe.py | 47 ---- .../layers/quantization/experts_int8.py | 2 - .../model_executor/layers/quantization/fp8.py | 35 +-- .../layers/quantization/gguf.py | 2 - .../layers/quantization/gptq_marlin.py | 2 - .../layers/quantization/modelopt.py | 50 +--- .../layers/quantization/moe_wna16.py | 2 - .../layers/quantization/mxfp4.py | 105 +------ .../layers/quantization/quark/quark_moe.py | 53 ++-- .../model_executor/layers/quantization/rtn.py | 2 - .../model_executor/warmup/deep_gemm_warmup.py | 6 +- 16 files changed, 271 insertions(+), 311 deletions(-) diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 9566dbac7f22f..3a849da70e4cb 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -266,14 +266,14 @@ class DeviceCommunicatorBase: module for module in model.modules() # TODO(bnell): Should use isinstance but can't. Maybe search for - # presence of quant_method.init_prepare_finalize? 
+ # presence of quant_method.maybe_init_modular_kernel? if ( module.__class__.__name__ == "FusedMoE" or module.__class__.__name__ == "SharedFusedMoE" ) ] for module in moe_modules: - module.quant_method.init_prepare_finalize(module) + module.maybe_init_modular_kernel() def dispatch( self, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 55aa2593193ab..118d5fa6b45c4 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -117,10 +117,8 @@ class FusedMoeWeightScaleSupported(Enum): class FusedMoEMethodBase(QuantizeMethodBase): def __init__(self, moe: FusedMoEConfig): super().__init__() - self.moe = moe + self.moe: FusedMoEConfig = moe self.moe_quant_config: FusedMoEQuantConfig | None = None - self.fused_experts: FusedMoEModularKernel | None = None - self.topk_indices_dtype = None @abstractmethod def create_weights( @@ -245,9 +243,9 @@ class FusedMoEMethodBase(QuantizeMethodBase): else: return None - # Note: init_prepare_finalize should only be called by - # prepare_communication_buffer_for_model. - def init_prepare_finalize(self, layer: torch.nn.Module): + def maybe_init_modular_kernel( + self, layer: torch.nn.Module + ) -> FusedMoEModularKernel | None: assert self.moe is not None # We must get the quant config here so that the layer is @@ -261,17 +259,14 @@ class FusedMoEMethodBase(QuantizeMethodBase): logger.debug( "%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self) ) - assert self.topk_indices_dtype is None - assert self.fused_experts is None, ( - f"Attempt to override experts for {id(self)}!" 
- ) - self.topk_indices_dtype = prepare_finalize.topk_indices_dtype() experts = self.select_gemm_impl(prepare_finalize, layer) - self.fused_experts = FusedMoEModularKernel( + return FusedMoEModularKernel( prepare_finalize, experts, layer.shared_experts, ) + else: + return None def select_gemm_impl( self, @@ -292,8 +287,16 @@ class FusedMoEMethodBase(QuantizeMethodBase): raise NotImplementedError @property - def using_modular_kernel(self) -> bool: - return self.fused_experts is not None + def topk_indices_dtype(self) -> torch.dtype | None: + return None + + @property + def supports_eplb(self) -> bool: + return False + + @property + def allow_inplace(self) -> bool: + return False @abstractmethod def apply( @@ -322,6 +325,138 @@ class FusedMoEMethodBase(QuantizeMethodBase): raise NotImplementedError +@CustomOp.register("modular_fused_moe") +class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): + def __init__( + self, + old_quant_method: FusedMoEMethodBase, + fused_experts: FusedMoEModularKernel, + ): + super().__init__(old_quant_method.moe) + # Find better way to copy attributes? Should we even copy attributes? 
+ # self.__dict__.update(old_quant_method.__dict__) + self.moe_quant_config = old_quant_method.moe_quant_config + self.fused_experts = fused_experts + self.disable_expert_map = getattr( + old_quant_method, + "disable_expert_map", + not fused_experts.supports_expert_map(), + ) + self.old_quant_method = old_quant_method + logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__) + + @property + def topk_indices_dtype(self) -> torch.dtype | None: + return self.fused_experts.prepare_finalize.topk_indices_dtype() + + @property + def supports_eplb(self) -> bool: + return self.old_quant_method.supports_eplb + + @property + def allow_inplace(self) -> bool: + return self.old_quant_method.allow_inplace + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + raise NotImplementedError + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> FusedMoEQuantConfig | None: + return self.moe_quant_config + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: int | None = None, + num_expert_group: int | None = None, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + # Is getattr needed? 
+ zero_expert_num = getattr(layer, "zero_expert_num", 0) + zero_expert_type = getattr(layer, "zero_expert_type", None) + + if enable_eplb: + if self.supports_eplb: + assert expert_load_view is not None + assert logical_to_physical_map is not None + assert logical_replica_count is not None + assert isinstance(layer, FusedMoE) + else: + raise NotImplementedError( + "EPLB is not supported for " + f"{self.old_quant_method.__class__.__name__}." + ) + + topk_weights, topk_ids, zero_expert_result = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, + enable_eplb=enable_eplb, + expert_map=expert_map, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + global_num_experts=global_num_experts, + zero_expert_num=zero_expert_num, + zero_expert_type=zero_expert_type, + ) + + result = self.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=self.allow_inplace, + activation=activation, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=None if self.disable_expert_map else expert_map, + ) + + if zero_expert_num != 0 and zero_expert_type is not None: + assert not isinstance(result, tuple), ( + "Shared + zero experts are mutually exclusive not yet supported" + ) + return result, zero_expert_result + else: + return result + + @CustomOp.register("unquantized_fused_moe") class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" @@ -378,6 
+513,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ) self.flashinfer_cutlass_moe = None # type: ignore + @property + def supports_eplb(self) -> bool: + return True + + @property + def allow_inplace(self) -> bool: + return True + def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None: if self.rocm_aiter_moe_enabled: return None @@ -650,7 +793,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ) if self.rocm_aiter_moe_enabled: - assert self.fused_experts is None result = self.rocm_aiter_fused_experts( hidden_states=x, w1=layer.w13_weight, @@ -671,21 +813,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, ) - elif self.fused_experts is not None: - result = self.fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, - ) else: - assert fused_experts is not None result = fused_experts( hidden_states=x, w1=layer.w13_weight, @@ -1267,7 +1395,7 @@ class FusedMoE(CustomOp): "Only softmax scoring function is supported for non-grouped topk." ) - moe = FusedMoEConfig( + self.moe_config: FusedMoEConfig = FusedMoEConfig( num_experts=self.global_num_experts, experts_per_token=top_k, hidden_dim=hidden_size, @@ -1279,24 +1407,26 @@ class FusedMoE(CustomOp): is_act_and_mul=is_act_and_mul, is_lora_enabled=vllm_config.lora_config is not None, ) - self.moe_config: FusedMoEConfig = moe + self.moe_quant_config: FusedMoEQuantConfig | None = None self.quant_config = quant_config + def _get_quant_method() -> FusedMoEMethodBase: + """ + Helper method to ensure self.quant_method is never None and + of the proper type. 
+ """ + quant_method = None + if self.quant_config is not None: + quant_method = self.quant_config.get_quant_method(self, prefix) + if quant_method is None: + quant_method = UnquantizedFusedMoEMethod(self.moe_config) + assert isinstance(quant_method, FusedMoEMethodBase) + return quant_method + # Note: get_quant_method will look at the layer's local_num_experts # for heuristic purposes, so it must be initialized first. - quant_method: QuantizeMethodBase | None = None - quant_method = ( - UnquantizedFusedMoEMethod(moe) - if quant_config is None - else quant_config.get_quant_method(self, prefix) - ) - if quant_method is None: - quant_method = UnquantizedFusedMoEMethod(moe) - - assert quant_method is not None - assert isinstance(quant_method, FusedMoEMethodBase) - self.quant_method = quant_method + self.quant_method: FusedMoEMethodBase = _get_quant_method() if not self.moe_config.is_act_and_mul: # Avoid circular import @@ -1305,7 +1435,7 @@ class FusedMoE(CustomOp): ) if not isinstance( - quant_method, (UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod) + self.quant_method, (UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod) ): raise NotImplementedError( "is_act_and_mul=False is supported only for unquantized " @@ -1316,20 +1446,18 @@ class FusedMoE(CustomOp): "is_act_and_mul=False is supported only for CUDA for now" ) - if self.enable_eplb: - from vllm.model_executor.layers.quantization.fp8 import Fp8MoEMethod - - if not isinstance(quant_method, (Fp8MoEMethod, UnquantizedFusedMoEMethod)): - # TODO: Add support for additional quantization methods. - # The implementation for other quantization methods does not - # contain essential differences, but the current quant API - # design causes duplicated work when extending to new - # quantization methods, so I'm leaving it for now. - # If you plan to add support for more quantization methods, - # please refer to the implementation in `Fp8MoEMethod`. 
- raise NotImplementedError( - "EPLB is only supported for FP8 quantization for now." - ) + if self.enable_eplb and not self.quant_method.supports_eplb: + # TODO: Add support for additional quantization methods. + # The implementation for other quantization methods does not + # contain essential differences, but the current quant API + # design causes duplicated work when extending to new + # quantization methods, so I'm leaving it for now. + # If you plan to add support for more quantization methods, + # please refer to the implementation in `Fp8MoEMethod`. + raise NotImplementedError( + f"EPLB is not supported {self.quant_method.__class__.__name__}. " + "EPLB is only supported for FP8 quantization for now." + ) moe_quant_params = { "num_experts": self.local_num_experts, @@ -1353,6 +1481,15 @@ class FusedMoE(CustomOp): self.batched_hidden_states: torch.Tensor | None = None self.batched_router_logits: torch.Tensor | None = None + # Note: maybe_init_modular_kernel should only be called by + # prepare_communication_buffer_for_model. + # This is called after all weight loading and post-processing, so it + # should be safe to swap out the quant_method. 
+ def maybe_init_modular_kernel(self) -> None: + mk = self.quant_method.maybe_init_modular_kernel(self) + if mk is not None: + self.quant_method = FusedMoEModularMethod(self.quant_method, mk) + @property def shared_experts(self) -> torch.nn.Module | None: return None @@ -2167,7 +2304,7 @@ class FusedMoE(CustomOp): """ assert self.quant_method is not None return ( - self.quant_method.fused_experts is not None + isinstance(self.quant_method, FusedMoEModularMethod) and self.quant_method.fused_experts.output_is_reduced() ) @@ -2403,7 +2540,7 @@ class FusedMoE(CustomOp): self.ensure_dp_chunking_init() has_separate_shared_experts = ( - not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel) + not isinstance(self.quant_method, FusedMoEModularMethod) and self.shared_experts is not None ) @@ -2430,8 +2567,8 @@ class FusedMoE(CustomOp): hidden_states, router_logits, has_separate_shared_experts ) - do_naive_dispatch_combine: bool = ( - self.dp_size > 1 and not self.quant_method.using_modular_kernel + do_naive_dispatch_combine: bool = self.dp_size > 1 and not isinstance( + self.quant_method, FusedMoEModularMethod ) # If there are shared experts but we are not using a modular kernel, the diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 3b5916f8ccaf8..b5fa2c71bec58 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -707,6 +707,12 @@ class FusedMoEModularKernel(torch.nn.Module): f"{fused_experts.activation_formats[0]}" ) + def supports_expert_map(self) -> bool: + """ + A flag indicating whether or not this class supports expert maps. 
+ """ + return self.fused_experts.supports_expert_map() + def output_is_reduced(self) -> bool: """ Indicates whether or not the output of fused MoE kernel diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index daf7422963f3c..3e1f87b59a34d 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -617,8 +617,6 @@ class AWQMoEMethod(FusedMoEMethodBase): logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.fused_experts is None - if enable_eplb: raise NotImplementedError("EPLB not supported for `AWQMoEMethod` yet.") diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index ccd9b311cc932..e5a741e639ad9 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -518,12 +518,11 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from vllm.model_executor.layers.fused_moe import fused_experts - assert self.fused_experts is None - if enable_eplb: raise NotImplementedError( "EPLB not supported for `BitsAndBytesMoEMethod` yet." 
) + topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index bf38c15b47013..d95d49eddfe3a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -462,12 +462,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): indices_type=self.topk_indices_dtype, ) - # - # Note: the order here is important. self.fused_experts can override - # flashinfer cutlass, cutlass fp4 or fused_experts but not marlin. - # if self.use_marlin: - assert self.fused_experts is None return fused_marlin_moe( x, layer.w13_weight, @@ -488,24 +483,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): workspace=layer.workspace, ) - elif self.fused_experts is not None: - assert is_valid_flashinfer_cutlass_fused_moe( - x, layer.w13_weight, layer.w2_weight - ), "Flashinfer CUTLASS Fused MoE not applicable!" - - return self.fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=False, # TODO(shuw): fix later, now output is high prec - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - # FlashInfer fused experts path elif self.allow_flashinfer: from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 @@ -1066,13 +1043,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL - # - # Note: the order here is important. 
self.fused_experts can override - # cutlass fp8 or fused_experts but not marlin or rocm. - # if self.use_marlin: assert activation == "silu", f"{activation} not supported for Marlin MoE." - assert self.fused_experts is None return fused_marlin_moe( x, layer.w13_weight, @@ -1098,7 +1070,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): assert per_act_token == per_channel_quant assert self.moe_quant_config is not None - assert self.fused_experts is None return rocm_aiter_fused_experts( hidden_states=x, w1=layer.w13_weight, @@ -1111,18 +1082,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): quant_config=self.moe_quant_config, ) - elif self.fused_experts is not None: - return self.fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - activation=activation, - global_num_experts=global_num_experts, - expert_map=None if self.disable_expert_map else expert_map, - ) - # cutlass path elif self.use_cutlass: assert self.moe_quant_config is not None @@ -1318,8 +1277,6 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.fused_experts is None - if enable_eplb: raise NotImplementedError( "EPLB not supported for `CompressedTensorsW8A8Int8MoEMethod` yet." @@ -1636,8 +1593,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.fused_experts is None - if enable_eplb: raise NotImplementedError( "EPLB not supported for `CompressedTensorsWNA16MarlinMoEMethod` yet." 
@@ -1901,8 +1856,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.fused_experts is None - if enable_eplb: raise NotImplementedError( "EPLB not supported for `CompressedTensorsWNA16MoEMethod` yet." diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 754608af97c6b..5241f9a2301be 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -158,8 +158,6 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.fused_experts is None - if enable_eplb: raise NotImplementedError( "EPLB not supported for `ExpertsInt8MoEMethod` yet." 
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f82eccb88ce09..03eca199d536d 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -703,9 +703,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): self.quant_config = quant_config self.weight_block_size = self.quant_config.weight_block_size self.block_quant: bool = self.weight_block_size is not None - - self.fused_experts: mk.FusedMoEModularKernel | None = None # type: ignore - self.fp8_backend = get_fp8_moe_backend(self.block_quant) self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN @@ -1181,6 +1178,14 @@ class Fp8MoEMethod(FusedMoEMethodBase): block_shape=self.weight_block_size, ) + @property + def supports_eplb(self) -> bool: + return True + + @property + def allow_inplace(self) -> bool: + return True + def apply( self, layer: torch.nn.Module, @@ -1210,10 +1215,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): assert logical_replica_count is not None assert isinstance(layer, FusedMoE) - if ( - self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM - and self.fused_experts is None - ): + if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: assert activation == "silu", ( f"Expected 'silu' activation but got {activation}" ) @@ -1290,10 +1292,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): num_fused_shared_experts=layer.num_fused_shared_experts, ) - # - # Note: the order of checks is important since self.fused_experts - # can override fused_experts or cutlass but not rocm or marlin. 
- # topk_weights, topk_ids, zero_expert_result = select_result if self.rocm_aiter_moe_enabled: @@ -1301,7 +1299,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): rocm_aiter_fused_experts, ) - assert self.fused_experts is None result = rocm_aiter_fused_experts( x, layer.w13_weight, @@ -1315,7 +1312,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): ) elif self.use_marlin: assert activation == "silu", f"{activation} not supported for Marlin MoE." - assert self.fused_experts is None result = fused_marlin_moe( x, layer.w13_weight, @@ -1333,19 +1329,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): expert_map=expert_map, workspace=layer.workspace, ) - elif self.fused_experts: - result = self.fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - global_num_experts=global_num_experts, - apply_router_weight_on_input=apply_router_weight_on_input, - expert_map=expert_map, - ) elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: assert not self.block_quant assert not renormalize and custom_routing_function is not None diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 8a914c57a9f7d..caabcd0ca0ee5 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -585,8 +585,6 @@ class GGUFMoEMethod(FusedMoEMethodBase): logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.fused_experts is None - if enable_eplb: raise NotImplementedError("EPLB not supported for `GGUFMoEMethod` yet.") diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 0d5439357fda2..42a569e7770c0 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ 
b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -742,8 +742,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.fused_experts is None - if enable_eplb: raise NotImplementedError( "EPLB not supported for `GPTQMarlinMoEMethod` yet." diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 37b682984fc35..f61d2a52925d9 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -18,9 +18,6 @@ from vllm.model_executor.layers.fused_moe.config import ( fp8_w8a8_moe_quant_config, nvfp4_moe_quant_config, ) -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - is_valid_flashinfer_cutlass_fused_moe, -) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, @@ -605,7 +602,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ) if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: - assert self.fused_experts is None assert activation == "silu", ( f"Expected 'silu' activation but got {activation}" ) @@ -638,24 +634,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): indices_type=self.topk_indices_dtype, ) - # - # Note: the order here is important. self.fused_experts can override - # cutlass or fused_experts. 
- # - if self.fused_experts is not None: - return self.fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - inplace=False, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: assert not renormalize assert activation == "silu", ( f"Expected 'silu' activation but got {activation}" @@ -1647,8 +1626,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): from vllm.model_executor.models.llama4 import Llama4MoE - assert self.fused_experts is None - a1_gscale = layer.w13_input_scale_quant (hidden_states_fp4, hidden_states_scale_linear_fp4) = ( flashinfer.fp4_quantize( @@ -1720,13 +1697,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): indices_type=self.topk_indices_dtype, ) - # - # Note: the order here is important. self.fused_experts can override - # flashinfer cutlass, cutlass fp4 or fused_experts but not marlin or - # trtllm. - # if self.use_marlin: - assert self.fused_experts is None return fused_marlin_moe( x, layer.w13_weight, @@ -1747,23 +1718,24 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): workspace=layer.workspace, ) - elif self.fused_experts is not None: - assert ( - self.allow_flashinfer - and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS + elif ( + self.allow_flashinfer + and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS + ): + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 + flashinfer_cutlass_moe_fp4, ) - assert is_valid_flashinfer_cutlass_fused_moe( - x, layer.w13_weight, layer.w2_weight - ), "Flashinfer CUTLASS Fused MoE not applicable!" 
+ assert self.moe_quant_config is not None - return self.fused_experts( + return flashinfer_cutlass_moe_fp4( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, - inplace=False, # TODO(shuw): fix later, now output is high prec + quant_config=self.moe_quant_config, + inplace=False, activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index b0a268b9950b7..2090c86f78dc8 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -226,7 +226,6 @@ class MoeWNA16Method(FusedMoEMethodBase): params_dtype: torch.dtype, **extra_weight_attrs, ): - self.moe = layer layer.quant_config = self.quant_config bit8_pack_factor = self.quant_config.bit8_pack_factor group_size = self.quant_config.group_size @@ -381,7 +380,6 @@ class MoeWNA16Method(FusedMoEMethodBase): logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.fused_experts is None if enable_eplb: raise NotImplementedError("EPLB not supported for `MoeWNA16Method` yet.") diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index bf34ec0f38996..7b1600a03d55b 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -197,8 +197,6 @@ class Mxfp4Config(QuantizationConfig): class Mxfp4MoEMethod(FusedMoEMethodBase): def __init__(self, moe: FusedMoEConfig): super().__init__(moe) - self.topk_indices_dtype = None - self.moe = moe self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled) self.max_capture_size = ( get_current_vllm_config().compilation_config.max_cudagraph_capture_size @@ -815,6 +813,18 @@ class 
Mxfp4MoEMethod(FusedMoEMethodBase): "EP batched experts format" ) else: + layer.w13_weight = ( + self.w13_weight_triton_tensor + if layer.w13_weight is None + else layer.w13_weight + ) + layer.w2_weight = ( + self.w2_weight_triton_tensor + if layer.w2_weight is None + else layer.w2_weight + ) + assert all([w is not None for w in [layer.w13_weight, layer.w2_weight]]) + assert self.moe_quant_config is not None if ( self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM @@ -838,71 +848,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP" ) - def _route_and_experts( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, - ) -> torch.Tensor: - assert isinstance(self.fused_experts, mk.FusedMoEModularKernel) - - topk_weights, topk_ids, _ = FusedMoE.select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype, - enable_eplb=enable_eplb, - expert_map=expert_map, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - 
logical_replica_count=logical_replica_count, - ) - - w13_weight = ( - self.w13_weight_triton_tensor - if layer.w13_weight is None - else layer.w13_weight - ) - w2_weight = ( - self.w2_weight_triton_tensor if layer.w2_weight is None else layer.w2_weight - ) - assert all([w is not None for w in [w13_weight, w2_weight]]) - - return self.fused_experts( - hidden_states=x, - w1=w13_weight, - w2=w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) + @property + def allow_inplace(self) -> bool: + return True def apply( self, @@ -930,29 +878,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): if enable_eplb: raise NotImplementedError("EPLB is not supported for mxfp4") - if self.fused_experts is not None: - return self._route_and_experts( - layer, - x, - router_logits, - top_k, - renormalize, - use_grouped_topk, - topk_group, - num_expert_group, - global_num_experts, - expert_map, - custom_routing_function, - scoring_func, - e_score_correction_bias, - apply_router_weight_on_input, - activation, - enable_eplb, - expert_load_view, - logical_to_physical_map, - logical_replica_count, - ) - if self.mxfp4_backend == Mxfp4Backend.MARLIN: topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index a8f4b1b0db68d..8825611051e5d 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -310,7 +310,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): # Property to determine if AITER is used if self.rocm_aiter_moe_enabled: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501 - rocm_aiter_fused_experts, shuffle_weights, ) @@ -322,17 +321,11 @@ class 
QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) - self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts elif self.use_marlin: prepare_moe_fp8_layer_for_marlin(layer, False) # Activations not quantized for marlin. del layer.w13_input_scale del layer.w2_input_scale - self.fused_experts_func = None - else: - from vllm.model_executor.layers.fused_moe import fused_experts - - self.fused_experts_func = fused_experts def get_fused_moe_quant_config( self, layer: torch.nn.Module @@ -369,8 +362,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.fused_experts is None - if enable_eplb: raise NotImplementedError( "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet." @@ -392,7 +383,11 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ) if self.rocm_aiter_moe_enabled: - return self.rocm_aiter_fused_experts_func( + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + rocm_aiter_fused_experts, + ) + + return rocm_aiter_fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -403,7 +398,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): quant_config=self.moe_quant_config, expert_map=expert_map, ) - if self.use_marlin: + elif self.use_marlin: assert activation == "silu", f"{activation} not supported for Marlin MoE." 
return fused_marlin_moe( x, @@ -421,22 +416,22 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): global_num_experts=global_num_experts, expert_map=expert_map, ) + else: + from vllm.model_executor.layers.fused_moe import fused_experts - assert self.fused_experts_func is not None - - return self.fused_experts_func( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, - quant_config=self.moe_quant_config, - ) + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + quant_config=self.moe_quant_config, + ) class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): @@ -601,6 +596,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): block_shape=None, ) + @property + def allow_inplace(self) -> bool: + return True + def apply( self, layer: torch.nn.Module, @@ -624,8 +623,6 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.fused_experts is None - if enable_eplb: raise NotImplementedError( "EPLB not supported for `QuarkOCP_MX_MoEMethod` yet." 
diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py index e4f7ff8339569..52656263a601b 100644 --- a/vllm/model_executor/layers/quantization/rtn.py +++ b/vllm/model_executor/layers/quantization/rtn.py @@ -377,8 +377,6 @@ class RTNMoEMethod(FusedMoEMethodBase): logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.fused_experts is None - if enable_eplb: raise NotImplementedError("EPLB not supported for `RTNMoEMethod` yet.") diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py index 78cbcd8e5427f..bdcebd498ef01 100644 --- a/vllm/model_executor/warmup/deep_gemm_warmup.py +++ b/vllm/model_executor/warmup/deep_gemm_warmup.py @@ -13,7 +13,7 @@ import vllm.envs as envs from vllm.distributed.parallel_state import get_dp_group from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M -from vllm.model_executor.layers.fused_moe.layer import FusedMoE +from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEModularMethod from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( TritonOrDeepGemmExperts, @@ -160,8 +160,8 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool: ): return False - if not isinstance(module.quant_method.fused_experts, FusedMoEModularKernel): - # fused_experts could invoke deep_gemm_moe_fp8 + if not isinstance(module.quant_method, FusedMoEModularMethod): + # modular kernels could invoke deep_gemm_moe_fp8 return True mk: FusedMoEModularKernel = module.quant_method.fused_experts From 5a0a6dfd55e1b9b2b518e0d2e91bd2c1241a7694 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 4 Nov 2025 
07:38:16 -0800 Subject: [PATCH 054/231] [BugFix] Fix incorrect preallocated sampled_token_ids tensor size (#28025) Signed-off-by: Nick Hill --- vllm/v1/worker/gpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e700c09038e28..177542ed96c8e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -524,7 +524,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self._draft_token_ids: list[list[int]] | torch.Tensor | None = None self.transfer_event = torch.cuda.Event() self.sampled_token_ids_pinned_cpu = torch.empty( - (self.max_model_len, 1), + (self.max_num_reqs, 1), dtype=torch.int64, device="cpu", pin_memory=self.pin_memory, From 97e3dda84ba79100509fafb58d651bde25e3f32f Mon Sep 17 00:00:00 2001 From: lyrisz <145491716+LyrisZhong@users.noreply.github.com> Date: Tue, 4 Nov 2025 07:49:25 -0800 Subject: [PATCH 055/231] [Perf] SM100 - add swap AB optimization to CUTLASS FP8 GEMM (#27284) Signed-off-by: Faqin Zhong Co-authored-by: Faqin Zhong Co-authored-by: Michael Goin --- .../w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu | 9 +- .../c3x/scaled_mm_sm100_fp8_dispatch.cuh | 276 +++++++++++++++--- 2 files changed, 233 insertions(+), 52 deletions(-) diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu index cf2cccc913f62..62aeb927ccdcb 100644 --- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu +++ b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu @@ -1,6 +1,5 @@ #include "scaled_mm_kernels.hpp" #include "scaled_mm_sm100_fp8_dispatch.cuh" -#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" namespace vllm { @@ -13,11 +12,11 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, if (bias) { TORCH_CHECK(bias->dtype() == out.dtype(), "currently bias dtype must match output 
dtype ", out.dtype()); - return cutlass_scaled_mm_sm100_fp8_epilogue( - out, a, b, a_scales, b_scales, *bias); + return cutlass_scaled_mm_sm100_fp8_epilogue(out, a, b, a_scales, + b_scales, *bias); } else { - return cutlass_scaled_mm_sm100_fp8_epilogue( - out, a, b, a_scales, b_scales); + return cutlass_scaled_mm_sm100_fp8_epilogue(out, a, b, a_scales, + b_scales); } } diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh index f876b7d9acd87..c950008b4139a 100644 --- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh +++ b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh @@ -2,6 +2,7 @@ #include "scaled_mm.cuh" #include "cutlass_gemm_caller.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" /** * This file defines Gemm kernel configurations for SM100 (fp8) based on the @@ -12,8 +13,88 @@ namespace vllm { using c3x::cutlass_gemm_caller; -template typename Epilogue> +template typename Epilogue_, + typename TileShape, typename ClusterShape, typename KernelSchedule, + typename EpilogueSchedule, bool swap_ab_ = false> +struct cutlass_3x_gemm_sm100_fp8 { + using ElementAB = ElementAB_; + using ElementC = ElementD_; + using ElementD = ElementD_; + using ElementAcc = + typename std::conditional, int32_t, + float>::type; + + using Epilogue = Epilogue_; + + using EVTCompute = typename Epilogue::EVTCompute; + + static constexpr int AlignmentAB = + 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentCD = + 128 / cutlass::sizeof_bits::value; + + // Compile-time swap_ab flag + static constexpr bool swap_ab = swap_ab_; + + // ----------------------------------------------------------- + // Layout definitions + // ----------------------------------------------------------- + using LayoutA = cutlass::layout::RowMajor; + using LayoutA_T = typename cutlass::layout::LayoutTranspose::type; + + using LayoutB = 
cutlass::layout::ColumnMajor; + using LayoutB_T = typename cutlass::layout::LayoutTranspose::type; + + using LayoutD = cutlass::layout::RowMajor; + using LayoutD_Transpose = + typename cutlass::layout::LayoutTranspose::type; + + using LayoutC = LayoutD; + using LayoutC_Transpose = LayoutD_Transpose; + + // ----------------------------------------------------------- + // Collective epilogue (conditionally swap operands and layouts) + // ----------------------------------------------------------- + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape, + ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, float, ElementC, + conditional_t, AlignmentCD, + ElementD, conditional_t, + AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp; + + static constexpr size_t CEStorageSize = + sizeof(typename CollectiveEpilogue::SharedStorage); + + using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(CEStorageSize)>; + + // ----------------------------------------------------------- + // Collective mainloop (conditionally swap operands and layouts) + // ----------------------------------------------------------- + using CollectiveMainloop = conditional_t< + swap_ab, + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB, + LayoutB_T, AlignmentAB, // Swapped B (as A) + ElementAB, LayoutA_T, AlignmentAB, // Swapped A (as B) + ElementAcc, TileShape, ClusterShape, Stages, + KernelSchedule>::CollectiveOp, + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB, + LayoutA, AlignmentAB, ElementAB, LayoutB, AlignmentAB, ElementAcc, + TileShape, ClusterShape, Stages, KernelSchedule>::CollectiveOp>; + + // ----------------------------------------------------------- + // Kernel definition + // 
----------------------------------------------------------- + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, CollectiveMainloop, CollectiveEpilogue, void>; +}; + +template struct sm100_fp8_config_default { // M in (256, inf) static_assert(std::is_same()); @@ -22,12 +103,16 @@ struct sm100_fp8_config_default { using TileShape = Shape<_256, _128, _128>; using ClusterShape = Shape<_2, _2, _1>; using Cutlass3xGemm = - cutlass_3x_gemm_sm100; + conditional_t, + cutlass_3x_gemm_sm100_fp8< + InType, OutType, c3x::ScaledEpilogue, TileShape, + ClusterShape, KernelSchedule, EpilogueSchedule>>; }; -template typename Epilogue> +template struct sm100_fp8_config_M256 { // M in (64, 256] static_assert(std::is_same()); @@ -36,44 +121,127 @@ struct sm100_fp8_config_M256 { using TileShape = Shape<_128, _128, _128>; using ClusterShape = Shape<_2, _1, _1>; using Cutlass3xGemm = - cutlass_3x_gemm_sm100; + conditional_t, + cutlass_3x_gemm_sm100_fp8< + InType, OutType, c3x::ScaledEpilogue, TileShape, + ClusterShape, KernelSchedule, EpilogueSchedule>>; }; -template typename Epilogue> +template +struct sm100_fp8_config_M64_swap_ab { + // This config is for M in (16, 64] and K >= 4096 + static_assert(std::is_same()); + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; + using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; + using TileShape = Shape<_128, _64, _256>; + using ClusterShape = Shape<_4, _1, _1>; + + // Use ScaledEpilogueColumnBias instead of ScaledEpilogueBias when doing swap + // AB + using Cutlass3xGemm = conditional_t< + EnableBias, + cutlass_3x_gemm_sm100_fp8, + cutlass_3x_gemm_sm100_fp8>; +}; + +template struct sm100_fp8_config_M64 { - // M in (16, 64] + // This config is for M = 64 and K < 4096 (do not enable swap AB in such case) static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; using 
TileShape = Shape<_64, _64, _128>; using ClusterShape = Shape<_1, _1, _1>; + using Cutlass3xGemm = - cutlass_3x_gemm_sm100; + conditional_t, + cutlass_3x_gemm_sm100_fp8< + InType, OutType, c3x::ScaledEpilogue, TileShape, + ClusterShape, KernelSchedule, EpilogueSchedule>>; }; -template typename Epilogue> -struct sm100_fp8_config_M16 { +template +struct sm100_fp8_config_M16_swap_ab { // M in [1, 16] static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; - using TileShape = Shape<_64, _64, _128>; - using ClusterShape = Shape<_1, _4, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm_sm100; + using TileShape = Shape<_128, _32, _128>; + using ClusterShape = Shape<_4, _1, _1>; + + // Use ScaledEpilogueColumnBias instead of ScaledEpilogueBias when doing swap + // AB + using Cutlass3xGemm = conditional_t< + EnableBias, + cutlass_3x_gemm_sm100_fp8, + cutlass_3x_gemm_sm100_fp8>; }; -template typename Epilogue, +template +void cutlass_gemm_caller_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + EpilogueArgs&&... epilogue_params) { + static constexpr bool swap_ab = Gemm::swap_ab; + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + using GemmKernel = typename Gemm::GemmKernel; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + + int32_t m = a.size(0), n = b.size(1), k = a.size(1); + auto prob_shape = + swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1); + + StrideA a_stride = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); + StrideB b_stride = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); + StrideC c_stride = cutlass::make_cute_packed_stride( + StrideC{}, + swap_ab ? 
cute::make_shape(n, m, 1) : cute::make_shape(m, n, 1)); + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto c_ptr = static_cast(out.data_ptr()); + + typename GemmKernel::MainloopArguments mainloop_args = + swap_ab ? typename GemmKernel::MainloopArguments{b_ptr, b_stride, a_ptr, + a_stride} + : typename GemmKernel::MainloopArguments{a_ptr, a_stride, b_ptr, + b_stride}; + + typename GemmKernel::EpilogueArguments epilogue_args{ + Gemm::Epilogue::prepare_args( + std::forward(epilogue_params)...), + c_ptr, c_stride, c_ptr, c_stride}; + + c3x::cutlass_gemm_caller(a.device(), prob_shape, mainloop_args, + epilogue_args); +} + +template inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, EpilogueArgs&&... args) { static_assert(std::is_same()); TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); @@ -81,55 +249,69 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, using Cutlass3xGemmDefault = typename sm100_fp8_config_default::Cutlass3xGemm; - using Cutlass3xGemmM16 = - typename sm100_fp8_config_M16::Cutlass3xGemm; + EnableBias>::Cutlass3xGemm; + using Cutlass3xGemmM16SwapAB = + typename sm100_fp8_config_M16_swap_ab::Cutlass3xGemm; + using Cutlass3xGemmM64SwapAB = + typename sm100_fp8_config_M64_swap_ab::Cutlass3xGemm; using Cutlass3xGemmM64 = - typename sm100_fp8_config_M64::Cutlass3xGemm; + typename sm100_fp8_config_M64::Cutlass3xGemm; + using Cutlass3xGemmM256 = - typename sm100_fp8_config_M256::Cutlass3xGemm; + typename sm100_fp8_config_M256::Cutlass3xGemm; uint32_t const m = a.size(0); - uint32_t const mp2 = - std::max(static_cast(16), next_pow_2(m)); // next power of 2 + uint32_t const k = a.size(1); - if (mp2 <= 16) { + if (m <= 16) { // m in [1, 16] - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else if (mp2 <= 64) { + return cutlass_gemm_caller_sm100_fp8( + out, a, 
b, b_scales, a_scales, std::forward(args)...); + } else if (m <= 64) { // m in (16, 64] - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else if (mp2 <= 256) { + if (m == 64 && k < 4096) { + // do not enable swap AB + return cutlass_gemm_caller_sm100_fp8( + out, a, b, a_scales, b_scales, std::forward(args)...); + } + return cutlass_gemm_caller_sm100_fp8( + out, a, b, b_scales, a_scales, std::forward(args)...); + + } else if (m <= 256) { // m in (64, 256] - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); + return cutlass_gemm_caller_sm100_fp8( + out, a, b, a_scales, b_scales, std::forward(args)...); } else { // m in (256, inf) - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); + return cutlass_gemm_caller_sm100_fp8( + out, a, b, a_scales, b_scales, std::forward(args)...); } } -template