From cc079763c59adb8c03305663a5b8857ba85deb1b Mon Sep 17 00:00:00 2001 From: David Ben-David Date: Tue, 11 Nov 2025 09:39:36 +0200 Subject: [PATCH 01/98] [BugFix] Avoid calling KV connector layer APIs when metadata is unset (#28253) Signed-off-by: David Ben-David Co-authored-by: David Ben-David Co-authored-by: Mark McLoughlin --- vllm/attention/layer.py | 4 ++++ vllm/distributed/kv_transfer/kv_connector/v1/base.py | 9 ++++++++- .../kv_transfer/kv_connector/v1/multi_connector.py | 6 ++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 96272981692c0..acab0529f3520 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -837,6 +837,8 @@ def wait_for_kv_layer_from_connector(layer_name: str): return connector = get_kv_transfer_group() + if not connector.has_connector_metadata(): + return forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata @@ -854,6 +856,8 @@ def maybe_save_kv_layer_to_connector( return connector = get_kv_transfer_group() + if not connector.has_connector_metadata(): + return forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 354aa9a87183d..f85eb414b2222 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -204,11 +204,18 @@ class KVConnectorBase_V1(ABC): Returns: ConnectorMetadata: the connector metadata. """ - # Should only be called while set to valid metadata. assert self._connector_metadata is not None return self._connector_metadata + def has_connector_metadata(self) -> bool: + """Check whether the connector metadata is currently set. + + Returns: + bool: True if connector metadata exists, False otherwise. + """ + return self._connector_metadata is not None + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """ Initialize with the KV caches. Useful for pre-registering the diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index d7bbf02c83677..c9d08e9b78ed0 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -171,16 +171,22 @@ class MultiConnector(KVConnectorBase_V1): # We must override the base class method here because we need to bind # the metadata to each connector in the order of the connectors in the # MultiKVConnectorMetadata. + # + # Note: Call the base class method to ensure metadata is also set on the + # MultiConnector instance itself; otherwise, `has_connector_metadata()` will + # always return False. def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None: assert isinstance(connector_metadata, MultiKVConnectorMetadata) if connector_metadata.extra_async_saves: self._extra_async_saves.update(connector_metadata.extra_async_saves) for c, cm in zip(self._connectors, connector_metadata.metadata): c.bind_connector_metadata(cm) + super().bind_connector_metadata(connector_metadata) def clear_connector_metadata(self) -> None: for c in self._connectors: c.clear_connector_metadata() + super().clear_connector_metadata() def shutdown(self): exception: Exception | None = None From 4fd4b743a23cc6ccbd832f11be12317a8c2f0fbc Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Tue, 11 Nov 2025 00:07:24 -0800 Subject: [PATCH 02/98] [Bugfix] Fix max image size for PaddleOCR-VL (#28442) Signed-off-by: Roger Wang --- vllm/model_executor/models/paddleocr_vl.py | 36 +++++++++++++--------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py index 631475c964c0b..12ae15699e7d2 100644 --- a/vllm/model_executor/models/paddleocr_vl.py +++ b/vllm/model_executor/models/paddleocr_vl.py @@ -198,23 +198,18 @@ class PaddleOCRVLProcessingInfo(BaseProcessingInfo): if image_processor is None: image_processor = self.get_image_processor() - do_resize = True hf_config = self.get_hf_config() vision_config = hf_config.vision_config patch_size = vision_config.patch_size merge_size = vision_config.spatial_merge_size - - if do_resize: - resized_height, resized_width = smart_resize( - height=image_height, - width=image_width, - factor=patch_size * merge_size, - min_pixels=image_processor.min_pixels, - max_pixels=image_processor.max_pixels, - ) - preprocessed_size = ImageSize(width=resized_width, height=resized_height) - else: - preprocessed_size = ImageSize(width=image_width, height=image_height) + resized_height, resized_width = smart_resize( + height=image_height, + width=image_width, + factor=patch_size * merge_size, + min_pixels=image_processor.min_pixels, + max_pixels=image_processor.max_pixels, + ) + preprocessed_size = ImageSize(width=resized_width, height=resized_height) grid_t = 1 grid_h = preprocessed_size.height // patch_size @@ -227,8 +222,19 @@ class PaddleOCRVLProcessingInfo(BaseProcessingInfo): def get_image_size_with_most_features(self) -> ImageSize: hf_config = self.get_hf_config() - image_size = hf_config.vision_config.image_size - return ImageSize(height=image_size, width=image_size) + + # See `smart_resize` for the calculation of the image size. + merge_size = hf_config.vision_config.spatial_merge_size + patch_size = hf_config.vision_config.patch_size + factor = merge_size * patch_size + max_num_tokens = self.get_image_processor().max_pixels // (factor**2) + # Find factors of max_num_tokens close to its square root + # to create a dummy image with a reasonable aspect ratio. + h_patches = int(math.sqrt(max_num_tokens)) + while max_num_tokens % h_patches != 0: + h_patches -= 1 + w_patches = max_num_tokens // h_patches + return ImageSize(height=h_patches * factor, width=w_patches * factor) class PaddleOCRVLDummyInputsBuilder(BaseDummyInputsBuilder[PaddleOCRVLProcessingInfo]): From 798c7bebca5e3ea48b947af4cc7904a4507ba873 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Tue, 11 Nov 2025 00:19:51 -0800 Subject: [PATCH 03/98] [EPLB] Refactor balance_packing to use numpy and optimize GPU-CPU transfers in EPLB (#28369) Signed-off-by: Sage Moore --- vllm/distributed/eplb/rebalance_algo.py | 40 +++++++++++++++------- vllm/distributed/eplb/rebalance_execute.py | 14 +++++--- 2 files changed, 37 insertions(+), 17 deletions(-) diff --git a/vllm/distributed/eplb/rebalance_algo.py b/vllm/distributed/eplb/rebalance_algo.py index c9d30d6481ab6..e6645e524cc3e 100644 --- a/vllm/distributed/eplb/rebalance_algo.py +++ b/vllm/distributed/eplb/rebalance_algo.py @@ -12,6 +12,7 @@ Please find at [#12](https://github.com/deepseek-ai/EPLB/issues/12) an example on how the EPLB algorithm works. """ +import numpy as np import torch @@ -34,29 +35,44 @@ def balanced_packing( assert num_groups % num_packs == 0 groups_per_pack = num_groups // num_packs + device = weight.device + if groups_per_pack == 1: pack_index = torch.arange( - weight.size(-1), dtype=torch.int64, device=weight.device + weight.size(-1), dtype=torch.int64, device=device ).expand(weight.shape) - rank_in_pack = torch.zeros_like(weight, dtype=torch.int64) + rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device) return pack_index, rank_in_pack - indices = weight.float().sort(-1, descending=True).indices.cpu() - pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device="cpu") - rank_in_pack = torch.full_like(pack_index, fill_value=-1) + weight_np = weight.cpu().numpy() + + # Sort and get indices in decending order + indices_np = np.argsort(-weight_np, axis=-1) + + pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64) + rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64) + + # Run the packing algorithm for i in range(num_layers): - pack_weights = [0] * num_packs + pack_weights = [0.0] * num_packs pack_items = [0] * num_packs - for group in indices[i]: + + for group in indices_np[i]: + # Find a pack with capacity that has the lowest weight pack = min( - (i for i in range(num_packs) if pack_items[i] < groups_per_pack), + (j for j in range(num_packs) if pack_items[j] < groups_per_pack), key=pack_weights.__getitem__, ) + assert pack_items[pack] < groups_per_pack - pack_index[i, group] = pack - rank_in_pack[i, group] = pack_items[pack] - pack_weights[pack] += weight[i, group] + pack_index_np[i, group] = pack + rank_in_pack_np[i, group] = pack_items[pack] + pack_weights[pack] += weight_np[i, group] pack_items[pack] += 1 + + pack_index = torch.from_numpy(pack_index_np).to(device) + rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device) + return pack_index, rank_in_pack @@ -212,7 +228,7 @@ def rebalance_experts( replicas for each logical expert """ num_layers, num_logical_experts = weight.shape - weight = weight.float().cpu() + weight = weight.float() if num_groups % num_nodes == 0: # use hierarchical load-balance policy phy2log, phyrank, logcnt = rebalance_experts_hierarchical( diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index f8ec3e956401a..5c1efbaf03bab 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -321,15 +321,19 @@ def rearrange_expert_weights_inplace( ) return + old_global_expert_indices_cpu = old_global_expert_indices.cpu() + new_global_expert_indices_cpu = new_global_expert_indices.cpu() + + # NOTE(bowen): We need this synchronize to run, but I don't know why. + # If you figure out the reason, please let me know -- thank you! + torch.cuda.synchronize() + for layer in range(num_moe_layers): - # NOTE(bowen): We need this synchronize to run, but I don't know why. - # If you figure out the reason, please let me know -- thank you! - torch.cuda.synchronize() shuffle_layer( num_local_physical_experts, ep_rank, - old_global_expert_indices[layer].tolist(), - new_global_expert_indices[layer].tolist(), + old_global_expert_indices_cpu[layer].tolist(), + new_global_expert_indices_cpu[layer].tolist(), expert_weights[layer], expert_weights_buffer, ep_group, From f0359fffa434a4fce981389f9dff93a2a4c2b13e Mon Sep 17 00:00:00 2001 From: Jiangyun Zhu Date: Tue, 11 Nov 2025 16:24:28 +0800 Subject: [PATCH 04/98] [Bugfix] fix qwen3-next crash (#28202) Signed-off-by: zjy0516 --- vllm/model_executor/models/qwen3_next.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index aa7de5aa5f29c..ddb8693c16e23 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -587,7 +587,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): self.conv1d.bias, self.activation, conv_state_indices=non_spec_state_indices_tensor[ - : attn_metadata.num_decodes + : attn_metadata.num_actual_tokens ], validate_data=True, ) From c7991269dd8fe86096a3eee5040e855801ae9665 Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Tue, 11 Nov 2025 16:45:38 +0800 Subject: [PATCH 05/98] [BugFix] 'DeepseekV2Config' object has no attribute 'use_mla'` (#28387) Signed-off-by: Lin, Fanli --- vllm/model_executor/models/kimi_vl.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py index b54f53931d714..b79bdf8595ca9 100644 --- a/vllm/model_executor/models/kimi_vl.py +++ b/vllm/model_executor/models/kimi_vl.py @@ -456,7 +456,11 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), ] - if not config.use_mla: + use_mha = ( + config.model_type == "deepseek" + or config.qk_nope_head_dim + config.qk_rope_head_dim == 0 + ) + if use_mha: stacked_params_mapping += [ (".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".k_proj", "k"), From 9973e6e04ad3e4a6c74c51a2dc87b2d3ddc4837f Mon Sep 17 00:00:00 2001 From: Lukas Geiger Date: Tue, 11 Nov 2025 10:35:10 +0000 Subject: [PATCH 06/98] [Model][Qwen3VL] Slighly speedup `fast_pos_embed_interpolate` (#28434) Signed-off-by: Lukas Geiger --- vllm/model_executor/models/qwen3_vl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index fe0124ef3258b..1cd34bf54a35f 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -491,8 +491,8 @@ class Qwen3_VisionTransformer(nn.Module): weights = weights.to(dtype=self.dtype) embeds = self.pos_embed(indices) - weighted_embeds = embeds * weights - combined = weighted_embeds.sum(dim=0) + embeds *= weights + combined = embeds.sum(dim=0) combined = combined.reshape( h // m_size, m_size, w // m_size, m_size, hidden_dim From d381eb967f171ea8824357075b15bf2895619609 Mon Sep 17 00:00:00 2001 From: Ido Segev Date: Tue, 11 Nov 2025 13:06:04 +0200 Subject: [PATCH 07/98] Multi turn benchmark progress bar for synthetic conversation generation (#28394) Signed-off-by: Ido Segev --- benchmarks/multi_turn/bench_dataset.py | 18 +++++++++++++++--- benchmarks/multi_turn/requirements.txt | 3 ++- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/benchmarks/multi_turn/bench_dataset.py b/benchmarks/multi_turn/bench_dataset.py index 2674899d1cc56..8cb8a2f386a97 100644 --- a/benchmarks/multi_turn/bench_dataset.py +++ b/benchmarks/multi_turn/bench_dataset.py @@ -11,6 +11,7 @@ from bench_utils import ( Color, logger, ) +from tqdm import tqdm from transformers import AutoTokenizer # type: ignore # Conversation ID is a string (e.g: "UzTK34D") @@ -417,6 +418,10 @@ def generate_conversations( data = file.read() tokens_in_file = tokenizer.encode(data, add_special_tokens=False) list_of_tokens.extend(tokens_in_file) + logger.info( + f"Loaded {len(tokens_in_file)} tokens from file {filename}, " + f"total tokens so far: {len(list_of_tokens)}" + ) conversations: ConversationsMap = {} conv_id = 0 @@ -449,18 +454,25 @@ def generate_conversations( ) base_offset += common_prefix_tokens - for conv_id in range(args.num_conversations): + for conv_id in tqdm( + range(args.num_conversations), + total=args.num_conversations, + desc="Generating conversations", + unit="conv", + ): # Generate a single conversation messages: MessagesList = [] nturns = turn_count[conv_id] # User prompt token count per turn (with lower limit) - input_token_count: np.ndarray = args.input_num_tokens.sample(nturns) + input_token_count: np.ndarray = args.input_num_tokens.sample(nturns).astype(int) input_token_count = np.maximum(input_token_count, base_prompt_token_count) # Assistant answer token count per turn (with lower limit) - output_token_count: np.ndarray = args.output_num_tokens.sample(nturns) + output_token_count: np.ndarray = args.output_num_tokens.sample(nturns).astype( + int + ) output_token_count = np.maximum(output_token_count, 1) user_turn = True diff --git a/benchmarks/multi_turn/requirements.txt b/benchmarks/multi_turn/requirements.txt index f0e1935914a14..bae656a5c5c4b 100644 --- a/benchmarks/multi_turn/requirements.txt +++ b/benchmarks/multi_turn/requirements.txt @@ -2,4 +2,5 @@ numpy>=1.24 pandas>=2.0.0 aiohttp>=3.10 transformers>=4.46 -xlsxwriter>=3.2.1 \ No newline at end of file +xlsxwriter>=3.2.1 +tqdm>=4.66 From 2e78150d24e339bf6420a623cdae655051127d8f Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 11 Nov 2025 05:28:28 -0700 Subject: [PATCH 08/98] [CI] Add mergify rules for `nvidia` label (#28417) Signed-off-by: mgoin --- .github/mergify.yml | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/.github/mergify.yml b/.github/mergify.yml index 18d4a2e83144b..997a40e18e588 100644 --- a/.github/mergify.yml +++ b/.github/mergify.yml @@ -151,6 +151,23 @@ pull_request_rules: add: - gpt-oss +- name: label-nvidia + description: Automatically apply nvidia label + conditions: + - label != stale + - or: + - files~=cuda + - files~=cutlass + - files~=flashinfer + - files~=trtllm + - title~=(?i)NVIDIA + - title~=(?i)CUDA + - title~=(?i)CUTLASS + actions: + label: + add: + - nvidia + - name: label-rocm description: Automatically apply rocm label conditions: From b30dfa03c564ce51c56bf2dd16283f074253c27c Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 11 Nov 2025 06:40:44 -0600 Subject: [PATCH 09/98] [Attention] Refactor CUDA attention backend selection logic (#24794) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Matthew Bonanni Signed-off-by: Matthew Bonanni Co-authored-by: Luka Govedič --- .buildkite/test-pipeline.yaml | 5 + tests/compile/test_fusion_attn.py | 31 +- tests/compile/test_fusions_e2e.py | 24 +- tests/config/test_multimodal_config.py | 6 +- .../attention/test_attention_selector.py | 77 ++-- tests/kernels/attention/test_mha_attn.py | 12 +- tests/models/test_initialization.py | 11 + tests/v1/attention/test_attention_backends.py | 47 ++- tests/v1/attention/test_mla_backends.py | 29 +- tests/v1/attention/utils.py | 10 +- tests/v1/spec_decode/test_eagle.py | 18 +- tests/v1/spec_decode/test_mtp.py | 6 +- tests/v1/spec_decode/test_tree_attention.py | 8 +- tests/v1/worker/test_gpu_model_runner.py | 25 +- vllm/attention/backends/abstract.py | 149 ++++++- vllm/attention/backends/registry.py | 256 ++++++++---- vllm/attention/layer.py | 68 ++-- vllm/attention/selector.py | 124 +++--- vllm/config/cache.py | 10 +- vllm/config/model.py | 8 +- vllm/config/multimodal.py | 32 +- .../kv_connector/v1/nixl_connector.py | 8 +- vllm/engine/arg_utils.py | 4 +- vllm/envs.py | 6 +- vllm/model_executor/models/dots_ocr.py | 37 +- vllm/model_executor/models/ernie45_vl.py | 37 +- vllm/model_executor/models/glm4_1v.py | 35 +- vllm/model_executor/models/keye.py | 24 +- vllm/model_executor/models/ovis2_5.py | 6 +- vllm/model_executor/models/paddleocr_vl.py | 47 +-- vllm/model_executor/models/qwen2_5_vl.py | 42 +- vllm/model_executor/models/qwen2_vl.py | 38 +- .../models/qwen3_omni_moe_thinker.py | 15 +- vllm/model_executor/models/qwen3_vl.py | 26 +- vllm/model_executor/models/siglip2navit.py | 26 +- vllm/model_executor/models/vision.py | 8 +- vllm/platforms/cpu.py | 12 +- vllm/platforms/cuda.py | 366 +++++++++--------- vllm/platforms/interface.py | 42 +- vllm/platforms/rocm.py | 49 ++- vllm/platforms/tpu.py | 15 +- vllm/platforms/xpu.py | 34 +- vllm/v1/attention/backends/cpu_attn.py | 32 +- vllm/v1/attention/backends/flash_attn.py | 71 ++-- vllm/v1/attention/backends/flashinfer.py | 63 +-- vllm/v1/attention/backends/flex_attention.py | 21 +- vllm/v1/attention/backends/mla/common.py | 22 +- vllm/v1/attention/backends/mla/cutlass_mla.py | 16 +- .../attention/backends/mla/flashattn_mla.py | 27 ++ .../attention/backends/mla/flashinfer_mla.py | 26 +- vllm/v1/attention/backends/mla/flashmla.py | 37 +- .../attention/backends/mla/flashmla_sparse.py | 30 +- vllm/v1/attention/backends/mla/indexer.py | 6 +- vllm/v1/attention/backends/mla/triton_mla.py | 10 + vllm/v1/attention/backends/rocm_aiter_fa.py | 25 +- vllm/v1/attention/backends/rocm_attn.py | 10 +- vllm/v1/attention/backends/tree_attn.py | 26 +- vllm/v1/attention/backends/triton_attn.py | 47 ++- vllm/v1/attention/backends/xformers.py | 26 +- vllm/v1/spec_decode/eagle.py | 8 +- vllm/v1/worker/gpu_model_runner.py | 4 +- 61 files changed, 1338 insertions(+), 1002 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index a0d2076199b14..83a7df3b093fc 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -890,11 +890,16 @@ steps: - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/v1/attention/backends/flashinfer.py + - vllm/v1/attention/backends/mla/cutlass_mla.py + - vllm/v1/attention/backends/mla/flashinfer_mla.py + - vllm/platforms/cuda.py + - vllm/attention/selector.py commands: - nvidia-smi - python3 examples/offline_inference/basic/chat.py # Attention # num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353 + - pytest -v -s tests/kernels/attention/test_attention_selector.py - pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2' - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py - pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index fecb1e2e918fe..ea61c94953a77 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -10,7 +10,7 @@ from tests.utils import flat_product from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.attention import Attention, AttentionMetadata -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fx_utils import find_op_nodes @@ -104,7 +104,7 @@ class AttentionQuantPatternModel(torch.nn.Module): # TODO(luka) use get_kv_cache_stride_order # Create dummy KV cache for the selected backend - if backend == _Backend.ROCM_ATTN: + if backend == AttentionBackendEnum.ROCM_ATTN: # k/v as 1st dimention # HND: [num_blocks, num_kv_heads, block_size, head_size] kv_cache = torch.zeros( @@ -116,7 +116,7 @@ class AttentionQuantPatternModel(torch.nn.Module): dtype=self.kv_cache_dtype, device=self.device, ) - elif backend == _Backend.ROCM_AITER_UNIFIED_ATTN: + elif backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN: # k/v as 1st dimention # NHD: [num_blocks, block_size, num_kv_heads, head_size] kv_cache = torch.zeros( @@ -128,7 +128,7 @@ class AttentionQuantPatternModel(torch.nn.Module): dtype=self.kv_cache_dtype, device=self.device, ) - elif backend == _Backend.TRITON_ATTN: + elif backend == AttentionBackendEnum.TRITON_ATTN: # k/v as 2nd dimention # NHD: [num_blocks, block_size, num_kv_heads, head_size] kv_cache = torch.zeros( @@ -140,7 +140,7 @@ class AttentionQuantPatternModel(torch.nn.Module): dtype=self.kv_cache_dtype, device=self.device, ) - elif backend == _Backend.FLASHINFER: + elif backend == AttentionBackendEnum.FLASHINFER: kv_cache = torch.zeros( num_blocks, 2, @@ -244,8 +244,8 @@ MODELS_FP8: list[tuple[str, type]] = [] MODELS_FP4: list[tuple[str, type]] = [] HEADS: list[tuple[int, int]] = [] SPLIT_ATTENTION: list[bool] = [] -BACKENDS_FP8: list[_Backend] = [] -BACKENDS_FP4: list[_Backend] = [] +BACKENDS_FP8: list[AttentionBackendEnum] = [] +BACKENDS_FP4: list[AttentionBackendEnum] = [] if current_platform.is_cuda(): HEADS = [(64, 8), (40, 8)] @@ -261,8 +261,8 @@ if current_platform.is_cuda(): TestAttentionNvfp4QuantPatternModel, ) ] - BACKENDS_FP8 = [_Backend.TRITON_ATTN, _Backend.FLASHINFER] - BACKENDS_FP4 = [_Backend.FLASHINFER] + BACKENDS_FP8 = [AttentionBackendEnum.TRITON_ATTN, AttentionBackendEnum.FLASHINFER] + BACKENDS_FP4 = [AttentionBackendEnum.FLASHINFER] elif current_platform.is_rocm(): HEADS = [(32, 8), (40, 8)] @@ -270,9 +270,9 @@ elif current_platform.is_rocm(): ("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel) ] BACKENDS = [ - _Backend.ROCM_AITER_UNIFIED_ATTN, - _Backend.ROCM_ATTN, - _Backend.TRITON_ATTN, + AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN, + AttentionBackendEnum.ROCM_ATTN, + AttentionBackendEnum.TRITON_ATTN, ] @@ -302,11 +302,11 @@ def test_attention_quant_pattern( custom_ops: str, model_name: str, model_class: type[AttentionQuantPatternModel], - backend: _Backend, + backend: AttentionBackendEnum, dist_init, ): """Test AttentionStaticQuantPattern fusion pass""" - if backend == _Backend.FLASHINFER and ( + if backend == AttentionBackendEnum.FLASHINFER and ( not current_platform.is_device_capability((10, 0)) or not has_flashinfer() ): pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer") @@ -314,6 +314,7 @@ def test_attention_quant_pattern( custom_ops_list = custom_ops.split(",") if custom_ops else [] device = torch.device("cuda:0") + torch.set_default_dtype(dtype) torch.manual_seed(42) vllm_config = VllmConfig( @@ -402,7 +403,7 @@ def test_attention_quant_pattern( result_fused_1 = model_compiled(q, k, v) - if backend == _Backend.FLASHINFER: + if backend == AttentionBackendEnum.FLASHINFER: # With the Flashinfer backend after the 1st round of the forward # pass, output quant scale should be loaded into the attn layer's # _o_scale_float, the 2nd round should reuse the loaded diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index 4b910bc285797..f67063cdf42ea 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -11,7 +11,7 @@ from typing import Any, NamedTuple import pytest import regex as re -from tests.v1.attention.utils import _Backend +from tests.v1.attention.utils import AttentionBackendEnum from vllm import LLM, SamplingParams from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig from vllm.platforms import current_platform @@ -24,7 +24,7 @@ from ..utils import flat_product, multi_gpu_test class ModelBackendTestCase(NamedTuple): model_name: str model_kwargs: dict[str, Any] - backend: _Backend + backend: AttentionBackendEnum attention_fusions: int allreduce_fusions: int | None = None @@ -39,14 +39,14 @@ if current_platform.is_cuda(): # Use smaller model for L40s in CI model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", model_kwargs=dict(max_model_len=1024), - backend=_Backend.TRITON_ATTN, + backend=AttentionBackendEnum.TRITON_ATTN, attention_fusions=32, allreduce_fusions=65, ), ModelBackendTestCase( model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), - backend=_Backend.FLASHINFER, + backend=AttentionBackendEnum.FLASHINFER, attention_fusions=48, allreduce_fusions=96, ), @@ -56,7 +56,7 @@ if current_platform.is_cuda(): ModelBackendTestCase( model_name="nvidia/Llama-3.1-8B-Instruct-FP4", model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), - backend=_Backend.FLASHINFER, + backend=AttentionBackendEnum.FLASHINFER, attention_fusions=32, allreduce_fusions=65, ), @@ -67,7 +67,7 @@ if current_platform.is_cuda(): ModelBackendTestCase( model_name="meta-llama/Llama-3.1-8B-Instruct", model_kwargs=dict(max_model_len=1024), - backend=_Backend.TRITON_ATTN, + backend=AttentionBackendEnum.TRITON_ATTN, attention_fusions=0, allreduce_fusions=65, ), @@ -85,19 +85,19 @@ elif current_platform.is_rocm(): ModelBackendTestCase( model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_kwargs=dict(max_model_len=1024), - backend=_Backend.TRITON_ATTN, + backend=AttentionBackendEnum.TRITON_ATTN, attention_fusions=32, ), ModelBackendTestCase( model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_kwargs=dict(max_model_len=1024), - backend=_Backend.ROCM_ATTN, + backend=AttentionBackendEnum.ROCM_ATTN, attention_fusions=32, ), ModelBackendTestCase( model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_kwargs=dict(max_model_len=1024), - backend=_Backend.ROCM_AITER_UNIFIED_ATTN, + backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN, attention_fusions=32, ), ] @@ -117,7 +117,7 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"] def test_attn_quant( model_name: str, model_kwargs: dict[str, Any], - backend: _Backend, + backend: AttentionBackendEnum, attention_fusions: int, allreduce_fusions: int, custom_ops: str, @@ -125,7 +125,7 @@ def test_attn_quant( caplog_mp_spawn, monkeypatch, ): - if backend == _Backend.FLASHINFER and ( + if backend == AttentionBackendEnum.FLASHINFER and ( not current_platform.is_device_capability((10, 0)) or not has_flashinfer() ): pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer") @@ -208,7 +208,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: def test_tp2_attn_quant_allreduce_rmsnorm( model_name: str, model_kwargs: dict, - backend: _Backend, + backend: AttentionBackendEnum, attention_fusions: int, allreduce_fusions: int, custom_ops: str, diff --git a/tests/config/test_multimodal_config.py b/tests/config/test_multimodal_config.py index b1a09d88ed9d6..3d02893e52f1e 100644 --- a/tests/config/test_multimodal_config.py +++ b/tests/config/test_multimodal_config.py @@ -3,13 +3,13 @@ import pytest -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config.multimodal import MultiModalConfig def test_mm_encoder_attn_backend_str_conversion(): config = MultiModalConfig(mm_encoder_attn_backend="FLASH_ATTN") - assert config.mm_encoder_attn_backend == _Backend.FLASH_ATTN + assert config.mm_encoder_attn_backend == AttentionBackendEnum.FLASH_ATTN def test_mm_encoder_attn_backend_invalid(): @@ -20,6 +20,6 @@ def test_mm_encoder_attn_backend_invalid(): def test_mm_encoder_attn_backend_hash_updates(): base_hash = MultiModalConfig().compute_hash() overridden_hash = MultiModalConfig( - mm_encoder_attn_backend=_Backend.FLASH_ATTN + mm_encoder_attn_backend=AttentionBackendEnum.FLASH_ATTN ).compute_hash() assert base_hash != overridden_hash diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index 8149ce7672cdc..29cc81be12e45 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -120,12 +120,13 @@ def test_env( elif device == "cuda": with patch("vllm.platforms.current_platform", CudaPlatform()): + capability = torch.cuda.get_device_capability() if use_mla: # CUDA MLA backend logic: # - CUTLASS_MLA: only supported with block_size == 128 - # and Blackwell GPUs (SM 10.0), V1 only + # and Blackwell GPUs (SM 10.x), V1 only # - FLASHINFER_MLA: only supported on Blackwell GPUs - # (SM 10.0+), V1 only + # (SM 10.x), V1 only # - FLASHMLA: only supported with block_size == 64 # - FLASH_ATTN_MLA: V1 only # - TRITON_MLA: fallback for other cases @@ -134,58 +135,72 @@ def test_env( if block_size != 128: # CUTLASS_MLA only supports block_size == 128 pytest.skip("CUTLASS_MLA only supports block_size 128") - else: - backend = get_attn_backend( - 16, torch.float16, None, block_size, use_mla=use_mla - ) - expected = "CUTLASS_MLA" - assert backend.get_name() == expected + if capability[0] != 10: + pytest.skip("CUTLASS MLA is not supported on this platform") + backend = get_attn_backend( + 576, torch.float16, None, block_size, use_mla=use_mla + ) + expected = "CUTLASS_MLA" + assert backend.get_name() == expected elif name == "FLASHINFER_MLA": + if capability[0] != 10: + pytest.skip( + "FlashInfer MLA is not supported on this platform" + ) if block_size not in [32, 64]: # FlashInfer MLA only supports block_size 32 or 64 pytest.skip( "FlashInfer MLA only supports block_size 32 or 64" ) - else: - backend = get_attn_backend( - 16, torch.float16, None, block_size, use_mla=use_mla - ) - expected = "FLASHINFER_MLA" - assert backend.get_name() == expected + backend = get_attn_backend( + 576, torch.float16, None, block_size, use_mla=use_mla + ) + expected = "FLASHINFER_MLA" + assert backend.get_name() == expected elif name == "FLASHMLA": if block_size != 64: # FlashMLA only supports block_size == 64 pytest.skip("FlashMLA only supports block_size 64") - else: - from vllm.v1.attention.backends.mla.flashmla import ( - is_flashmla_dense_supported, - ) + from vllm.v1.attention.backends.mla.flashmla import ( + is_flashmla_dense_supported, + ) - is_supported, _ = is_flashmla_dense_supported() - if not is_supported: - pytest.skip("FlashMLA not supported on this platform") - else: - backend = get_attn_backend( - 16, torch.float16, None, block_size, use_mla=use_mla - ) - expected = name - assert backend.get_name() == expected - elif name == "FLASH_ATTN_MLA": + is_supported, _ = is_flashmla_dense_supported() + if not is_supported: + pytest.skip("FlashMLA not supported on this platform") backend = get_attn_backend( - 16, torch.float16, None, block_size, use_mla=use_mla + 576, + torch.float16, + None, + block_size, + use_mla=use_mla, + ) + expected = name + assert backend.get_name() == expected + elif name == "FLASH_ATTN_MLA": + from vllm.attention.utils.fa_utils import ( + flash_attn_supports_mla, + ) + + if not flash_attn_supports_mla(): + pytest.skip( + "FlashAttention MLA not supported on this platform" + ) + backend = get_attn_backend( + 576, torch.float16, None, block_size, use_mla=use_mla ) expected = "FLASH_ATTN_MLA" assert backend.get_name() == expected else: # TRITON_MLA or other fallback backend = get_attn_backend( - 16, torch.float16, None, block_size, use_mla=use_mla + 576, torch.float16, None, block_size, use_mla=use_mla ) expected = "TRITON_MLA" assert backend.get_name() == expected elif name == "FLASHINFER": backend = get_attn_backend( - 16, torch.float16, None, block_size, use_mla=use_mla + 64, torch.float16, None, block_size, use_mla=use_mla ) expected = "FLASHINFER" assert backend.get_name() == expected diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py index 14d1618bca3c5..183bbf3bf4e03 100644 --- a/tests/kernels/attention/test_mha_attn.py +++ b/tests/kernels/attention/test_mha_attn.py @@ -11,7 +11,7 @@ from unittest.mock import patch import pytest import torch -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import MultiHeadAttention from vllm.attention.selector import _cached_get_attn_backend from vllm.platforms import current_platform @@ -43,14 +43,14 @@ def test_mha_attn_platform(device: str): patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()), ): attn = MultiHeadAttention(16, 64, scale=1) - assert attn.attn_backend == _Backend.TORCH_SDPA + assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA elif device == "hip": with ( patch("vllm.attention.layer.current_platform", RocmPlatform()), patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()), ): attn = MultiHeadAttention(16, 64, scale=1) - assert attn.attn_backend == _Backend.TORCH_SDPA + assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA else: # Test CUDA with head_size=64 (divisible by 32) # - should use vLLM's FlashAttention @@ -59,7 +59,7 @@ def test_mha_attn_platform(device: str): patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()), ): attn = MultiHeadAttention(16, 64, scale=1) - assert attn.attn_backend == _Backend.FLASH_ATTN + assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN # Test CUDA with head_size=72 (not divisible by 32) # - with upstream FA not available @@ -73,7 +73,7 @@ def test_mha_attn_platform(device: str): ), ): attn = MultiHeadAttention(16, 72, scale=1) - assert attn.attn_backend == _Backend.XFORMERS + assert attn.attn_backend == AttentionBackendEnum.XFORMERS # Test CUDA with head_size=72 (not divisible by 32) # - with upstream FA available @@ -96,7 +96,7 @@ def test_mha_attn_platform(device: str): ), ): attn = MultiHeadAttention(16, 72, scale=1) - assert attn.attn_backend == _Backend.FLASH_ATTN + assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN def ref_attention( diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index 48a6f34366cff..8c4bd6eaa2dd8 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -93,6 +93,17 @@ def can_initialize( "pickle error when loading `transformers.models.auto.CONFIG_MAPPING`" ) + if model_arch == "DeepseekV32ForCausalLM": + from vllm.platforms import current_platform + + capability = current_platform.get_device_capability() + if capability and capability.major < 9: + pytest.skip( + f"DeepseekV32 requires Hopper (9.0+) or Blackwell (10.0+) " + f"for FLASHMLA_SPARSE backend. Current device has compute " + f"capability {capability.major}.{capability.minor}" + ) + with ( patch.object(V1EngineCore, "_initialize_kv_caches", _initialize_kv_caches_v1), monkeypatch.context() as m, diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index 08aeb6f298f61..b46002c5fa8ff 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -15,7 +15,7 @@ from tests.v1.attention.utils import ( create_vllm_config, try_get_attention_backend, ) -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ModelConfig from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv @@ -27,11 +27,11 @@ from vllm.v1.attention.backends.utils import ( from vllm.v1.kv_cache_interface import FullAttentionSpec BACKENDS_TO_TEST = [ - _Backend.FLASH_ATTN, - _Backend.FLASHINFER, - _Backend.FLEX_ATTENTION, - _Backend.TRITON_ATTN, - _Backend.TREE_ATTN, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.FLASHINFER, + AttentionBackendEnum.FLEX_ATTENTION, + AttentionBackendEnum.TRITON_ATTN, + AttentionBackendEnum.TREE_ATTN, "FLEX_ATTENTION_SLOW", ] @@ -39,7 +39,7 @@ BACKENDS_TO_TEST = [ try: import flashinfer # noqa: F401 except ImportError: - BACKENDS_TO_TEST.remove(_Backend.FLASHINFER) + BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHINFER) def _convert_dtype_to_torch(dtype): @@ -192,7 +192,7 @@ class MockAttentionLayer: def run_attention_backend( - backend: _Backend, + backend: AttentionBackendEnum, kv_cache_spec: FullAttentionSpec, layer_names: list[str], vllm_config, @@ -211,13 +211,13 @@ def run_attention_backend( use_direct_block_mask = is_torch_equal_or_newer("2.9.0.dev0") if backend == "FLEX_ATTENTION_SLOW": - actual_backend = _Backend.FLEX_ATTENTION + actual_backend = AttentionBackendEnum.FLEX_ATTENTION use_direct_block_mask = False builder_cls, impl_cls = try_get_attention_backend(actual_backend) # Mock flashinfer's get_per_layer_parameters if needed - if actual_backend == _Backend.FLASHINFER: + if actual_backend == AttentionBackendEnum.FLASHINFER: import unittest.mock from vllm.v1.attention.backends.utils import PerLayerParameters @@ -246,7 +246,7 @@ def run_attention_backend( else: # Build metadata builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) - if actual_backend == _Backend.FLEX_ATTENTION: + if actual_backend == AttentionBackendEnum.FLEX_ATTENTION: builder.direct_build = use_direct_block_mask attn_metadata = builder.build( common_prefix_len=0, @@ -289,7 +289,7 @@ def run_attention_backend( def _test_backend_correctness( batch_spec: BatchSpec, model: str, - backend_to_test: list[_Backend | str], + backend_to_test: list[AttentionBackendEnum | str], mask_mod, *, block_size: int = 16, @@ -455,17 +455,20 @@ def _test_backend_correctness( # Select the appropriate KV cache format for each backend kv_cache_for_backend = kv_cache reset_kv_cache_layout = False - if backend_name in (_Backend.FLASHINFER, _Backend.TRITON_ATTN): + if backend_name in ( + AttentionBackendEnum.FLASHINFER, + AttentionBackendEnum.TRITON_ATTN, + ): kv_cache_for_backend = kv_cache.transpose(0, 1) - if backend_name == _Backend.FLASHINFER: + if backend_name == AttentionBackendEnum.FLASHINFER: # For FlashInfer default to HND layout and kv_cache_for_backend = ( kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3) ) set_kv_cache_layout("HND") reset_kv_cache_layout = True - elif backend_name == _Backend.TRITON_ATTN: + elif backend_name == AttentionBackendEnum.TRITON_ATTN: kv_cache_for_backend = kv_cache_for_backend.contiguous() try: @@ -547,7 +550,9 @@ def test_causal_backend_correctness( batch_spec = BATCH_SPECS[batch_spec_name] LARGE_BLOCK_BACKENDS = ( - [_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else [] + [AttentionBackendEnum.FLEX_ATTENTION] + if is_torch_equal_or_newer("2.9.0.dev0") + else [] ) SMALL_BLOCK_BACKENDS = [ x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS @@ -573,9 +578,9 @@ def test_causal_backend_correctness( SLIDING_WINDOW_BACKENDS_TO_TEST = [ - _Backend.FLASH_ATTN, - _Backend.FLEX_ATTENTION, - _Backend.TRITON_ATTN, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.FLEX_ATTENTION, + AttentionBackendEnum.TRITON_ATTN, "FLEX_ATTENTION_SLOW", ] @@ -612,7 +617,9 @@ def test_sliding_window_backend_correctness( ) LARGE_BLOCK_BACKENDS = ( - [_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else [] + [AttentionBackendEnum.FLEX_ATTENTION] + if is_torch_equal_or_newer("2.9.0.dev0") + else [] ) SMALL_BLOCK_BACKENDS = [ x for x in SLIDING_WINDOW_BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index 5679fafe63ee8..1bd05e6183dc2 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -18,12 +18,11 @@ from tests.v1.attention.utils import ( try_get_attention_backend, ) from vllm import _custom_ops as ops -from vllm.attention.backends.registry import _Backend, backend_to_class_str +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.ops.flashmla import is_flashmla_dense_supported from vllm.attention.utils.fa_utils import flash_attn_supports_mla from vllm.config.vllm import set_current_vllm_config from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.attention.backends.mla.common import QueryLenSupport @@ -31,25 +30,25 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import FullAttentionSpec BACKENDS_TO_TEST = [ - _Backend.CUTLASS_MLA, - _Backend.FLASHMLA, - _Backend.FLASH_ATTN_MLA, - _Backend.FLASHINFER_MLA, - _Backend.TRITON_MLA, + AttentionBackendEnum.CUTLASS_MLA, + AttentionBackendEnum.FLASHMLA, + AttentionBackendEnum.FLASH_ATTN_MLA, + AttentionBackendEnum.FLASHINFER_MLA, + AttentionBackendEnum.TRITON_MLA, ] # Remove sm100 backends from the list if not using sm100 if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10: - BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA) - BACKENDS_TO_TEST.remove(_Backend.FLASHINFER_MLA) + BACKENDS_TO_TEST.remove(AttentionBackendEnum.CUTLASS_MLA) + BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHINFER_MLA) # Remove FLASH_ATTN_MLA from the list if not supported if not flash_attn_supports_mla(): - BACKENDS_TO_TEST.remove(_Backend.FLASH_ATTN_MLA) + BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASH_ATTN_MLA) # Remove FLASHMLA from the list if not supported if not is_flashmla_dense_supported()[0]: - BACKENDS_TO_TEST.remove(_Backend.FLASHMLA) + BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHMLA) SPEC_DECODE_BACKENDS = [] for backend in BACKENDS_TO_TEST: @@ -62,9 +61,7 @@ for backend in BACKENDS_TO_TEST: BACKEND_BLOCK_SIZES = {} for backend in BACKENDS_TO_TEST: - backend_class_str = backend_to_class_str(backend) - backend_class = resolve_obj_by_qualname(backend_class_str) - supported_sizes = backend_class.get_supported_kernel_block_size() + supported_sizes = backend.get_class().supported_kernel_block_sizes if supported_sizes: default_size = supported_sizes[0] block_size = ( @@ -291,7 +288,7 @@ class MockMLAAttentionLayer(AttentionLayerBase): def run_attention_backend( - backend: _Backend, + backend: AttentionBackendEnum, kv_cache_spec: FullAttentionSpec, layer_names: list[str], vllm_config, @@ -813,7 +810,7 @@ def test_backend_correctness( # Create a summary for the single-line failure message backend_names = [] for f in failures: - if "[_Backend." in f: + if "[AttentionBackendEnum." in f: backend_name = f.split("[")[1].split("]")[0] backend_names.append(backend_name) diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index b166d9d4ff688..dea89babd4b47 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -8,7 +8,7 @@ import pytest import torch from vllm.attention.backends.abstract import AttentionImpl -from vllm.attention.backends.registry import _Backend, backend_to_class_str +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ( CacheConfig, CompilationConfig, @@ -20,7 +20,6 @@ from vllm.config import ( VllmConfig, ) from vllm.config.model import ModelDType -from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, @@ -120,15 +119,14 @@ def create_common_attn_metadata( def try_get_attention_backend( - backend: _Backend, + backend: AttentionBackendEnum, ) -> tuple[type[AttentionMetadataBuilder], type[AttentionImpl]]: """Try to get the attention backend class, skipping test if not found.""" - backend_class_str = backend_to_class_str(backend) try: - backend_class = resolve_obj_by_qualname(backend_class_str) + backend_class = backend.get_class() return backend_class.get_builder_cls(), backend_class.get_impl_cls() except ImportError as e: - pytest.skip(f"{backend_class_str} not available: {e}") + pytest.skip(f"{backend.name} not available: {e}") raise AssertionError("unreachable") from None diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 47d05a20a65df..89d0ec769ac09 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -13,7 +13,7 @@ from tests.v1.attention.utils import ( create_standard_kv_cache_spec, try_get_attention_backend, ) -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ( CacheConfig, DeviceConfig, @@ -534,11 +534,17 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): sampling_metadata = mock.MagicMock() if attn_backend == "FLASH_ATTN": - attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN) + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.FLASH_ATTN + ) elif attn_backend == "TRITON_ATTN": - attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TRITON_ATTN) + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.TRITON_ATTN + ) elif attn_backend == "TREE_ATTN": - attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN) + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.TREE_ATTN + ) else: raise ValueError(f"Unsupported attention backend: {attn_backend}") @@ -673,7 +679,9 @@ def test_propose_tree(spec_token_tree): proposer.attn_layer_names = ["layer.0"] # Get the tree attention metadata builder. - attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN) + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.TREE_ATTN + ) attn_metadata_builder = attn_metadata_builder_cls( kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), layer_names=proposer.attn_layer_names, diff --git a/tests/v1/spec_decode/test_mtp.py b/tests/v1/spec_decode/test_mtp.py index 9ca7cf9e3e0e1..6d59b58e739eb 100644 --- a/tests/v1/spec_decode/test_mtp.py +++ b/tests/v1/spec_decode/test_mtp.py @@ -12,7 +12,7 @@ from tests.v1.attention.utils import ( create_standard_kv_cache_spec, try_get_attention_backend, ) -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ( CacheConfig, DeviceConfig, @@ -177,7 +177,9 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch): sampling_metadata = mock.MagicMock() # Setup attention metadata - attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN) + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.FLASH_ATTN + ) attn_metadata_builder = attn_metadata_builder_cls( kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py index b365e75d5514c..6958d62dc7e90 100644 --- a/tests/v1/spec_decode/test_tree_attention.py +++ b/tests/v1/spec_decode/test_tree_attention.py @@ -10,7 +10,7 @@ from tests.v1.attention.utils import ( create_vllm_config, try_get_attention_backend, ) -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ParallelConfig, SpeculativeConfig from vllm.v1.attention.backends.utils import CommonAttentionMetadata @@ -35,7 +35,7 @@ def forward_attention( block_table: torch.Tensor, slot_mapping: torch.Tensor, seqlen_k: int, - backend: _Backend, + backend: AttentionBackendEnum, spec_token_tree: str | None = None, num_spec_tokens: int = 0, ) -> torch.Tensor: @@ -241,7 +241,7 @@ def test_tree_attn_correctness() -> None: block_table=block_table, slot_mapping=tree_slot_mapping, seqlen_k=seqlen_k, - backend=_Backend.TREE_ATTN, + backend=AttentionBackendEnum.TREE_ATTN, spec_token_tree=spec_token_tree, num_spec_tokens=tree_size_q - 1, ).view(batch_size, -1, num_heads, dim_per_head) @@ -278,7 +278,7 @@ def test_tree_attn_correctness() -> None: block_table=block_table, slot_mapping=branch_slot_mapping, seqlen_k=sequence_position + q_len, - backend=_Backend.FLASH_ATTN, + backend=AttentionBackendEnum.FLASH_ATTN, ).view(batch_size, -1, num_heads, dim_per_head) # Compare the outputs. diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index bc624658308bf..b02d9a657407b 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -185,9 +185,7 @@ def _make_mock_backend_for_kernel_block_size( supported_sizes: list[int | MultipleOf], ): class _MockBackend: - @staticmethod - def get_supported_kernel_block_size(): - return supported_sizes + supported_kernel_block_sizes = supported_sizes return _MockBackend() @@ -466,13 +464,20 @@ def test_kv_cache_stride_order(monkeypatch, model_runner): # This test checks if GPUModelRunner initializes correctly when an attention # backend enforces a non-default KV cache stride order. n_heads = model_runner.model_config.get_num_kv_heads(model_runner.parallel_config) - expected_kv_cache_shape = [ - 2, - NUM_BLOCKS, - BLOCK_SIZE, - n_heads, - model_runner.model_config.get_head_size(), - ] + head_size = model_runner.model_config.get_head_size() + + # Get the expected shape from the backend's get_kv_cache_shape method + # to ensure compatibility with different backends (triton vs flexattention) + attn_backend = None + for attn_group in model_runner._attn_group_iterator(): + attn_backend = attn_group.backend + break + + assert attn_backend is not None, "No attention backend found" + expected_kv_cache_shape = list( + attn_backend.get_kv_cache_shape(NUM_BLOCKS, BLOCK_SIZE, n_heads, head_size) + ) + # TODO mla test default_stride = tuple(range(5)) # Permutation that gets you back to expected kv shape diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index b54eaf4e2872d..697beed918693 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -2,13 +2,18 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from typing import Generic, Protocol, TypeVar +from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args import torch from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey +if TYPE_CHECKING: + from vllm.config.cache import CacheDType + from vllm.platforms.interface import DeviceCapability + from vllm.v1.attention.backends.utils import KVCacheLayoutType + class AttentionType: """ @@ -40,6 +45,9 @@ class AttentionBackend(ABC): # calling the custom op. When piecewise cudagraph is enabled, this # makes sure the output tensor is allocated inside the cudagraph. accept_output_buffer: bool = False + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(1)] + supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"] @staticmethod @abstractmethod @@ -51,10 +59,6 @@ class AttentionBackend(ABC): def get_impl_cls() -> type["AttentionImpl"]: raise NotImplementedError - @classmethod - def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]: - return cls.get_impl_cls().get_supported_kernel_block_size() - @staticmethod @abstractmethod def get_builder_cls(): # -> Type["AttentionMetadataBuilder"]: @@ -79,6 +83,136 @@ class AttentionBackend(ABC): def full_cls_name(cls) -> tuple[str, str]: return (cls.__module__, cls.__qualname__) + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [] + + @classmethod + def supports_head_size(cls, head_size: int) -> bool: + supported_head_sizes = cls.get_supported_head_sizes() + return (not supported_head_sizes) or head_size in supported_head_sizes + + @classmethod + def supports_dtype(cls, dtype: torch.dtype) -> bool: + return dtype in cls.supported_dtypes + + @classmethod + def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool: + if kv_cache_dtype is None: + return True + return (not cls.supported_kv_cache_dtypes) or ( + kv_cache_dtype in cls.supported_kv_cache_dtypes + ) + + @classmethod + def supports_block_size(cls, block_size: int | None) -> bool: + from vllm.config.cache import BlockSize + + if block_size is None: + return True + + valid_sizes = get_args(BlockSize) + if block_size not in valid_sizes: + return False + + if not cls.supported_kernel_block_sizes: + return True + + for supported_size in cls.supported_kernel_block_sizes: + is_multiple_of = ( + isinstance(supported_size, MultipleOf) + and block_size % supported_size.base == 0 + ) + is_int_equal = ( + isinstance(supported_size, int) and block_size == supported_size + ) + if is_multiple_of or is_int_equal: + return True + return False + + @classmethod + def is_mla(cls) -> bool: + return False + + @classmethod + def supports_sink(cls) -> bool: + return False + + @classmethod + def is_sparse(cls) -> bool: + return False + + @classmethod + def supports_compute_capability(cls, capability: "DeviceCapability") -> bool: + return True + + @classmethod + def supports_combination( + cls, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: "CacheDType | None", + block_size: int | None, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + device_capability: "DeviceCapability", + ) -> str | None: + return None + + @classmethod + def validate_configuration( + cls, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: "CacheDType | None", + block_size: int | None, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + device_capability: "DeviceCapability", + ) -> list[str]: + invalid_reasons = [] + if not cls.supports_head_size(head_size): + invalid_reasons.append("head_size not supported") + if not cls.supports_dtype(dtype): + invalid_reasons.append("dtype not supported") + if not cls.supports_kv_cache_dtype(kv_cache_dtype): + invalid_reasons.append("kv_cache_dtype not supported") + if not cls.supports_block_size(block_size): + invalid_reasons.append("block_size not supported") + if use_mla != cls.is_mla(): + if use_mla: + invalid_reasons.append("MLA not supported") + else: + invalid_reasons.append("non-MLA not supported") + if has_sink and not cls.supports_sink(): + invalid_reasons.append("sink setting not supported") + if use_sparse != cls.is_sparse(): + if use_sparse: + invalid_reasons.append("sparse not supported") + else: + invalid_reasons.append("non-sparse not supported") + if not cls.supports_compute_capability(device_capability): + invalid_reasons.append("compute capability not supported") + combination_reason = cls.supports_combination( + head_size, + dtype, + kv_cache_dtype, + block_size, + use_mla, + has_sink, + use_sparse, + device_capability, + ) + if combination_reason is not None: + invalid_reasons.append(combination_reason) + return invalid_reasons + + @classmethod + def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None": + return None + class AttentionMetadata: pass @@ -151,11 +285,6 @@ class AttentionImpl(ABC, Generic[T]): ) -> None: raise NotImplementedError - @staticmethod - def get_supported_kernel_block_size() -> list[int | MultipleOf]: - # TODO: implement this function for all backends. - return [MultipleOf(1)] - @abstractmethod def forward( self, diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py index 05d0159d08615..768d15cb9c82b 100644 --- a/vllm/attention/backends/registry.py +++ b/vllm/attention/backends/registry.py @@ -3,108 +3,192 @@ """Attention backend registry""" import enum +from collections.abc import Callable +from typing import TYPE_CHECKING, cast +from vllm.logger import init_logger from vllm.utils.import_utils import resolve_obj_by_qualname +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend -class _Backend(enum.Enum): - FLASH_ATTN = enum.auto() - TRITON_ATTN = enum.auto() - XFORMERS = enum.auto() - ROCM_ATTN = enum.auto() - ROCM_AITER_MLA = enum.auto() - ROCM_AITER_FA = enum.auto() # used for ViT attn backend - TORCH_SDPA = enum.auto() - FLASHINFER = enum.auto() - FLASHINFER_MLA = enum.auto() - TRITON_MLA = enum.auto() - CUTLASS_MLA = enum.auto() - FLASHMLA = enum.auto() - FLASHMLA_SPARSE = enum.auto() - FLASH_ATTN_MLA = enum.auto() - PALLAS = enum.auto() - IPEX = enum.auto() - NO_ATTENTION = enum.auto() - FLEX_ATTENTION = enum.auto() - TREE_ATTN = enum.auto() - ROCM_AITER_UNIFIED_ATTN = enum.auto() +logger = init_logger(__name__) -BACKEND_MAP = { - _Backend.FLASH_ATTN: "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend", # noqa: E501 - _Backend.TRITON_ATTN: "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend", # noqa: E501 - _Backend.XFORMERS: "vllm.v1.attention.backends.xformers.XFormersAttentionBackend", # noqa: E501 - _Backend.ROCM_ATTN: "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend", # noqa: E501 - _Backend.ROCM_AITER_MLA: "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend", # noqa: E501 - _Backend.ROCM_AITER_FA: "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend", # noqa: E501 - _Backend.TORCH_SDPA: "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend", # noqa: E501 - _Backend.FLASHINFER: "vllm.v1.attention.backends.flashinfer.FlashInferBackend", # noqa: E501 - _Backend.FLASHINFER_MLA: "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend", # noqa: E501 - _Backend.TRITON_MLA: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend", # noqa: E501 - _Backend.CUTLASS_MLA: "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend", # noqa: E501 - _Backend.FLASHMLA: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend", # noqa: E501 - _Backend.FLASHMLA_SPARSE: "vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend", # noqa: E501 - _Backend.FLASH_ATTN_MLA: "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend", # noqa: E501 - _Backend.PALLAS: "vllm.v1.attention.backends.pallas.PallasAttentionBackend", # noqa: E501 - _Backend.FLEX_ATTENTION: "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend", # noqa: E501 - _Backend.TREE_ATTN: "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend", # noqa: E501 - _Backend.ROCM_AITER_UNIFIED_ATTN: "vllm.v1.attention.backends.rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend", # noqa: E501 -} +class _AttentionBackendEnumMeta(enum.EnumMeta): + """Metaclass for AttentionBackendEnum to provide better error messages.""" + + def __getitem__(cls, name: str): + """Get backend by name with helpful error messages.""" + try: + return super().__getitem__(name) + except KeyError: + members = cast("dict[str, AttentionBackendEnum]", cls.__members__).values() + valid_backends = ", ".join(m.name for m in members) + raise ValueError( + f"Unknown attention backend: '{name}'. " + f"Valid options are: {valid_backends}" + ) from None -def register_attn_backend(backend: _Backend, class_path: str | None = None): +class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta): + """Enumeration of all supported attention backends. + + The enum value is the default class path, but this can be overridden + at runtime using register_backend(). + + To get the actual backend class (respecting overrides), use: + backend.get_class() """ - Decorator: register a custom attention backend into BACKEND_MAPPING. - - If class_path is provided, use it. - - Otherwise, auto-generate from the class object. - Validation: only checks if 'backend' is a valid _Backend enum member. - Overwriting existing mappings is allowed. This enables other hardware - platforms to plug in custom out-of-tree backends. - """ - if not isinstance(backend, _Backend): - raise ValueError(f"{backend} is not a valid _Backend enum value.") - def decorator(cls): - path = class_path or f"{cls.__module__}.{cls.__qualname__}" - BACKEND_MAP[backend] = path + FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" + XFORMERS = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" + ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend" + ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" + ROCM_AITER_FA = ( + "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend" + ) + TORCH_SDPA = "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend" + FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" + FLASHINFER_MLA = ( + "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend" + ) + TRITON_MLA = "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend" + CUTLASS_MLA = "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend" + FLASHMLA = "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend" + FLASHMLA_SPARSE = ( + "vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend" + ) + FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend" + PALLAS = "vllm.v1.attention.backends.pallas.PallasAttentionBackend" + IPEX = "vllm.v1.attention.backends.ipex.IpexAttentionBackend" + NO_ATTENTION = "vllm.v1.attention.backends.no_attention.NoAttentionBackend" + FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" + TREE_ATTN = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" + ROCM_AITER_UNIFIED_ATTN = ( + "vllm.v1.attention.backends.rocm_aiter_unified_attn." + "RocmAiterUnifiedAttentionBackend" + ) + # Placeholder for third-party/custom backends - must be registered before use + CUSTOM = "" + + def get_path(self, include_classname: bool = True) -> str: + """Get the class path for this backend (respects overrides). + + Returns: + The fully qualified class path string + + Raises: + ValueError: If Backend.CUSTOM is used without being registered + """ + path = _OVERRIDES.get(self, self.value) + if not path: + raise ValueError( + f"Backend {self.name} must be registered before use. " + f"Use register_backend(Backend.{self.name}, 'your.module.YourClass')" + ) + if not include_classname: + path = path.rsplit(".", 1)[0] + return path + + def get_class(self) -> "type[AttentionBackend]": + """Get the backend class (respects overrides). + + Returns: + The backend class + + Raises: + ImportError: If the backend class cannot be imported + ValueError: If Backend.CUSTOM is used without being registered + """ + return resolve_obj_by_qualname(self.get_path()) + + def is_overridden(self) -> bool: + """Check if this backend has been overridden. + + Returns: + True if the backend has a registered override + """ + return self in _OVERRIDES + + def clear_override(self) -> None: + """Clear any override for this backend, reverting to the default.""" + _OVERRIDES.pop(self, None) + + +_OVERRIDES: dict[AttentionBackendEnum, str] = {} + + +def register_backend( + backend: AttentionBackendEnum, class_path: str | None = None +) -> Callable[[type], type]: + """Register or override a backend implementation. + + Args: + backend: The AttentionBackendEnum member to register + class_path: Optional class path. If not provided and used as + decorator, will be auto-generated from the class. + + Returns: + Decorator function if class_path is None, otherwise a no-op + + Examples: + # Override an existing backend + @register_backend(AttentionBackendEnum.FLASH_ATTN) + class MyCustomFlashAttn: + ... + + # Register a custom third-party backend + @register_backend(AttentionBackendEnum.CUSTOM) + class MyCustomBackend: + ... + + # Direct registration + register_backend( + AttentionBackendEnum.CUSTOM, + "my.module.MyCustomBackend" + ) + """ + + def decorator(cls: type) -> type: + _OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}" return cls + if class_path is not None: + _OVERRIDES[backend] = class_path + return lambda x: x + return decorator -def backend_to_class_str(backend: _Backend) -> str: - """Get the backend class string +# Backwards compatibility alias for plugins +class _BackendMeta(type): + """Metaclass to provide deprecation warnings when accessing _Backend.""" - Args: - backend: The backend enum value + def __getattribute__(cls, name: str): + if name not in ("__class__", "__mro__", "__name__"): + logger.warning( + "_Backend has been renamed to AttentionBackendEnum. " + "Please update your code to use AttentionBackendEnum instead. " + "_Backend will be removed in a future release." + ) + return getattr(AttentionBackendEnum, name) - Returns: - The backend class string + def __getitem__(cls, name: str): + logger.warning( + "_Backend has been renamed to AttentionBackendEnum. " + "Please update your code to use AttentionBackendEnum instead. " + "_Backend will be removed in a future release." + ) + return AttentionBackendEnum[name] + + +class _Backend(metaclass=_BackendMeta): + """Deprecated: Use AttentionBackendEnum instead. + + This class is provided for backwards compatibility with plugins + and will be removed in a future release. """ - return BACKEND_MAP[backend] - -def backend_to_class(backend: _Backend) -> type: - """Get the backend class. - - Args: - backend: The backend enum value - - Returns: - The backend class - """ - backend_class_name = backend_to_class_str(backend) - return resolve_obj_by_qualname(backend_class_name) - - -def backend_name_to_enum(backend_name: str) -> _Backend | None: - """ - Convert a string backend name to a _Backend enum value. - - Returns: - _Backend: enum value if backend_name is a valid in-tree type - None: otherwise it's an invalid in-tree type or an out-of-tree platform - is loaded. - """ - assert backend_name is not None - return _Backend[backend_name] if backend_name in _Backend.__members__ else None + pass diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index acab0529f3520..ec705126c710d 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -12,7 +12,7 @@ import torch.nn.functional as F import vllm.envs as envs from vllm.attention import AttentionType from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl -from vllm.attention.backends.registry import _Backend, backend_name_to_enum +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.selector import get_attn_backend from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.config import CacheConfig, get_current_vllm_config @@ -99,40 +99,44 @@ def check_upstream_fa_availability(dtype: torch.dtype): def maybe_get_vit_flash_attn_backend( - attn_backend: _Backend, + attn_backend: AttentionBackendEnum, use_upstream_fa: bool, - attn_backend_override: _Backend | None = None, -) -> tuple[_Backend, Callable | None]: + attn_backend_override: AttentionBackendEnum | None = None, +) -> tuple[AttentionBackendEnum, Callable | None]: if current_platform.is_rocm(): if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): - attn_backend = _Backend.ROCM_AITER_FA + attn_backend = AttentionBackendEnum.ROCM_AITER_FA elif ( check_upstream_fa_availability(torch.get_default_dtype()) and on_gfx9() and attn_backend_override is None ): - attn_backend = _Backend.FLASH_ATTN + attn_backend = AttentionBackendEnum.FLASH_ATTN use_upstream_fa = True else: - return _Backend.TORCH_SDPA, None + return AttentionBackendEnum.TORCH_SDPA, None elif current_platform.is_cuda(): - if attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + attn_backend != AttentionBackendEnum.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype()) ): - attn_backend = _Backend.FLASH_ATTN + attn_backend = AttentionBackendEnum.FLASH_ATTN use_upstream_fa = True elif current_platform.is_xpu(): - assert attn_backend == _Backend.FLASH_ATTN, ( + assert attn_backend == AttentionBackendEnum.FLASH_ATTN, ( "XPU platform only supports FLASH_ATTN as vision attention backend." ) use_upstream_fa = False else: - return _Backend.TORCH_SDPA, None + return AttentionBackendEnum.TORCH_SDPA, None - if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: - if attn_backend == _Backend.ROCM_AITER_FA: + if attn_backend in { + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, + }: + if attn_backend == AttentionBackendEnum.ROCM_AITER_FA: from aiter import flash_attn_varlen_func else: if use_upstream_fa: @@ -309,7 +313,7 @@ class Attention(nn.Module, AttentionLayerBase): kv_sharing_target_layer_name, **extra_impl_args, ) - self.backend = backend_name_to_enum(self.attn_backend.get_name()) + self.backend = AttentionBackendEnum[self.attn_backend.get_name()] self.dtype = dtype # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how @@ -530,13 +534,13 @@ class MultiHeadAttention(nn.Module): backend if backend in { - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.PALLAS, - _Backend.ROCM_AITER_FA, - _Backend.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.PALLAS, + AttentionBackendEnum.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, } - else _Backend.TORCH_SDPA + else AttentionBackendEnum.TORCH_SDPA ) self.attn_backend, self._flash_attn_varlen_func = ( @@ -547,17 +551,23 @@ class MultiHeadAttention(nn.Module): ) ) - if self.attn_backend == _Backend.XFORMERS and not check_xformers_availability(): - self.attn_backend = _Backend.TORCH_SDPA + if ( + self.attn_backend == AttentionBackendEnum.XFORMERS + and not check_xformers_availability() + ): + self.attn_backend = AttentionBackendEnum.TORCH_SDPA self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } # this condition is just to make sure that the # use_upstream_fa in the log is correct - if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN: + if ( + current_platform.is_rocm() + and self.attn_backend == AttentionBackendEnum.FLASH_ATTN + ): use_upstream_fa = True logger.info_once( @@ -606,17 +616,17 @@ class MultiHeadAttention(nn.Module): max_seqlen_k=kv_len, softmax_scale=self.scale, ) - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: from xformers import ops as xops out = xops.memory_efficient_attention_forward( query, key, value, scale=self.scale ) - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: query, key, value = (x.transpose(1, 2) for x in (query, key, value)) out = F.scaled_dot_product_attention(query, key, value, scale=self.scale) out = out.transpose(1, 2) - elif self.attn_backend == _Backend.PALLAS: + elif self.attn_backend == AttentionBackendEnum.PALLAS: query, key, value = (x.transpose(1, 2) for x in (query, key, value)) from torch_xla.experimental.custom_kernel import flash_attention diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 9c26a8d40edaf..6e5fa854d35f5 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -4,14 +4,15 @@ import os from collections.abc import Generator from contextlib import contextmanager -from dataclasses import dataclass from functools import cache +from typing import cast, get_args import torch import vllm.envs as envs from vllm.attention.backends.abstract import AttentionBackend -from vllm.attention.backends.registry import _Backend, backend_name_to_enum +from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.utils import STR_BACKEND_ENV_VAR from vllm.utils.import_utils import resolve_obj_by_qualname @@ -19,18 +20,18 @@ from vllm.utils.import_utils import resolve_obj_by_qualname logger = init_logger(__name__) -def get_env_variable_attn_backend() -> _Backend | None: +def get_env_variable_attn_backend() -> AttentionBackendEnum | None: """ Get the backend override specified by the vLLM attention backend environment variable, if one is specified. Returns: - * _Backend enum value if an override is specified + * AttentionBackendEnum value if an override is specified * None otherwise """ backend_name = os.environ.get(STR_BACKEND_ENV_VAR) - return None if backend_name is None else backend_name_to_enum(backend_name) + return None if backend_name is None else AttentionBackendEnum[backend_name] # Global state allows a particular choice of backend @@ -40,10 +41,10 @@ def get_env_variable_attn_backend() -> _Backend | None: # # THIS SELECTION TAKES PRECEDENCE OVER THE # VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE -forced_attn_backend: _Backend | None = None +forced_attn_backend: AttentionBackendEnum | None = None -def global_force_attn_backend(attn_backend: _Backend | None) -> None: +def global_force_attn_backend(attn_backend: AttentionBackendEnum | None) -> None: """ Force all attention operations to use a specified backend. @@ -58,7 +59,7 @@ def global_force_attn_backend(attn_backend: _Backend | None) -> None: forced_attn_backend = attn_backend -def get_global_forced_attn_backend() -> _Backend | None: +def get_global_forced_attn_backend() -> AttentionBackendEnum | None: """ Get the currently-forced choice of attention backend, or None if auto-selection is currently enabled. @@ -66,78 +67,28 @@ def get_global_forced_attn_backend() -> _Backend | None: return forced_attn_backend -@dataclass(frozen=True) -class _IsSupported: - can_import: bool - head_size: bool - dtype: bool - - def __bool__(self) -> bool: - return self.can_import and self.head_size and self.dtype - - -def is_attn_backend_supported( - attn_backend: str | type[AttentionBackend], - head_size: int, - dtype: torch.dtype, - *, - allow_import_error: bool = True, -) -> _IsSupported: - if isinstance(attn_backend, str): - try: - attn_backend = resolve_obj_by_qualname(attn_backend) - except ImportError: - if not allow_import_error: - raise - - return _IsSupported(can_import=False, head_size=False, dtype=False) - - assert isinstance(attn_backend, type) - - # TODO: Update the interface once V0 is removed - if get_supported_head_sizes := getattr( - attn_backend, "get_supported_head_sizes", None - ): - is_head_size_supported = head_size in get_supported_head_sizes() - elif validate_head_size := getattr(attn_backend, "validate_head_size", None): - try: - validate_head_size(head_size) - is_head_size_supported = True - except Exception: - is_head_size_supported = False - else: - raise NotImplementedError( - f"{attn_backend.__name__} does not support head size validation" - ) - - if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes", None): - is_dtype_supported = dtype in get_supported_dtypes() - else: - raise NotImplementedError( - f"{attn_backend.__name__} does not support dtype validation" - ) - - return _IsSupported( - can_import=True, - head_size=is_head_size_supported, - dtype=is_dtype_supported, - ) - - def get_attn_backend( head_size: int, dtype: torch.dtype, kv_cache_dtype: str | None, - block_size: int, + block_size: int | None, use_mla: bool = False, has_sink: bool = False, use_sparse: bool = False, ) -> type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" + + if kv_cache_dtype is not None: + valid_cache_dtypes = get_args(CacheDType) + assert kv_cache_dtype in valid_cache_dtypes, ( + f"Invalid kv_cache_dtype: {kv_cache_dtype}. " + f"Valid values are: {valid_cache_dtypes}" + ) + return _cached_get_attn_backend( head_size=head_size, dtype=dtype, - kv_cache_dtype=kv_cache_dtype, + kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype), block_size=block_size, use_mla=use_mla, has_sink=has_sink, @@ -149,8 +100,8 @@ def get_attn_backend( def _cached_get_attn_backend( head_size: int, dtype: torch.dtype, - kv_cache_dtype: str | None, - block_size: int, + kv_cache_dtype: CacheDType | None, + block_size: int | None, use_mla: bool = False, has_sink: bool = False, use_sparse: bool = False, @@ -161,7 +112,9 @@ def _cached_get_attn_backend( # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND # ENVIRONMENT VARIABLE. selected_backend = None - backend_by_global_setting: _Backend | None = get_global_forced_attn_backend() + backend_by_global_setting: AttentionBackendEnum | None = ( + get_global_forced_attn_backend() + ) if backend_by_global_setting is not None: selected_backend = backend_by_global_setting else: @@ -177,12 +130,13 @@ def _cached_get_attn_backend( STR_BACKEND_ENV_VAR, ) backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1") - selected_backend = backend_name_to_enum(backend_by_env_var) - if selected_backend is None: + try: + selected_backend = AttentionBackendEnum[backend_by_env_var] + except KeyError as e: raise ValueError( - f"Invalid attention backend: '{backend_by_env_var}'. " - f"Valid backends are: {list(_Backend.__members__.keys())}" - ) + f"Invalid attention backend: '{backend_by_env_var}'. Valid " + f"backends are: {list(AttentionBackendEnum.__members__.keys())}" + ) from e # get device-specific attn_backend from vllm.platforms import current_platform @@ -202,12 +156,26 @@ def _cached_get_attn_backend( raise ValueError( f"Invalid attention backend for {current_platform.device_name}" ) - return resolve_obj_by_qualname(attention_cls) + backend = resolve_obj_by_qualname(attention_cls) + + # Adjust kv cache layout if the selected backend requires a specific one + required_layout = backend.get_required_kv_cache_layout() + if required_layout is not None: + from vllm.v1.attention.backends.utils import set_kv_cache_layout + + set_kv_cache_layout(required_layout) + logger.info( + "Using %s KV cache layout for %s backend.", + required_layout, + backend.get_name(), + ) + + return backend @contextmanager def global_force_attn_backend_context_manager( - attn_backend: _Backend, + attn_backend: AttentionBackendEnum, ) -> Generator[None, None, None]: """ Globally force a vLLM attention backend override within a diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 031df3091f1c6..864cf1be81b20 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -21,7 +21,15 @@ else: logger = init_logger(__name__) BlockSize = Literal[1, 8, 16, 32, 64, 128, 256] -CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] +CacheDType = Literal[ + "auto", + "bfloat16", + "fp8", + "fp8_e4m3", + "fp8_e5m2", + "fp8_inc", + "fp8_ds_mla", +] MambaDType = Literal["auto", "float32"] PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"] KVOffloadingBackend = Literal["native", "lmcache"] diff --git a/vllm/config/model.py b/vllm/config/model.py index 44c044c76168d..6ce91ebb87b90 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -45,7 +45,7 @@ if TYPE_CHECKING: import vllm.model_executor.layers.quantization as me_quant import vllm.model_executor.models as me_models - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config.load import LoadConfig from vllm.config.parallel import ParallelConfig from vllm.model_executor.layers.quantization import QuantizationMethods @@ -53,7 +53,7 @@ if TYPE_CHECKING: else: PretrainedConfig = Any - _Backend = Any + AttentionBackendEnum = Any me_quant = LazyLoader( "model_executor", globals(), "vllm.model_executor.layers.quantization" ) @@ -302,7 +302,7 @@ class ModelConfig: mm_processor_cache_type: InitVar[MMCacheType | None] = None mm_shm_cache_max_object_size_mb: InitVar[int | None] = None mm_encoder_tp_mode: InitVar[MMEncoderTPMode | None] = None - mm_encoder_attn_backend: InitVar[_Backend | str | None] = None + mm_encoder_attn_backend: InitVar[AttentionBackendEnum | str | None] = None interleave_mm_strings: InitVar[bool | None] = None skip_mm_profiling: InitVar[bool | None] = None video_pruning_rate: InitVar[float | None] = None @@ -420,7 +420,7 @@ class ModelConfig: mm_processor_cache_type: MMCacheType | None, mm_shm_cache_max_object_size_mb: int | None, mm_encoder_tp_mode: MMEncoderTPMode | None, - mm_encoder_attn_backend: _Backend | str | None, + mm_encoder_attn_backend: AttentionBackendEnum | str | None, interleave_mm_strings: bool | None, skip_mm_profiling: bool | None, video_pruning_rate: float | None, diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py index ef73720efe099..9348c1b2af8cc 100644 --- a/vllm/config/multimodal.py +++ b/vllm/config/multimodal.py @@ -11,9 +11,9 @@ from pydantic.dataclasses import dataclass from vllm.config.utils import config if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum else: - _Backend = Any + AttentionBackendEnum = Any @dataclass @@ -125,10 +125,10 @@ class MultiModalConfig: DP (which is controlled by `--data-parallel-size`). This is only supported on a per-model basis and falls back to `"weights"` if the encoder does not support DP.""" - mm_encoder_attn_backend: _Backend | None = None + mm_encoder_attn_backend: AttentionBackendEnum | None = None """Optional override for the multi-modal encoder attention backend when using vision transformers. Accepts any value from - `vllm.attention.backends.registry._Backend` (e.g. `FLASH_ATTN`).""" + `vllm.attention.backends.registry.AttentionBackendEnum` (e.g. `FLASH_ATTN`).""" interleave_mm_strings: bool = False """Enable fully interleaved support for multimodal prompts, while using --chat-template-content-format=string.""" @@ -167,26 +167,16 @@ class MultiModalConfig: @field_validator("mm_encoder_attn_backend", mode="before") @classmethod - def _validate_mm_encoder_attn_backend(cls, value: object) -> _Backend | None: - from vllm.attention.backends.registry import ( - _Backend as BackendEnum, - ) - from vllm.attention.backends.registry import ( - backend_name_to_enum, - ) - - if value is None or isinstance(value, BackendEnum): + def _validate_mm_encoder_attn_backend( + cls, value: str | AttentionBackendEnum | None + ) -> AttentionBackendEnum | None: + if value is None or isinstance(value, AttentionBackendEnum): return value - if isinstance(value, str): - candidate = backend_name_to_enum(value.upper()) - if candidate is not None: - return candidate - - valid_backends = ", ".join(sorted(BackendEnum.__members__.keys())) - raise ValueError( - f"Invalid mm encoder attention backend. Expected one of: {valid_backends}." + assert isinstance(value, str), ( + "mm_encoder_attn_backend must be a string or an AttentionBackendEnum." ) + return AttentionBackendEnum[value.upper()] @model_validator(mode="after") def _validate_multimodal_config(self): diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index ff9770b72bd38..6c20eee1ecbf9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -21,7 +21,7 @@ import torch import zmq from vllm import envs -from vllm.attention.backends.registry import _Backend, backend_name_to_enum +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.selector import get_attn_backend from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( @@ -876,9 +876,9 @@ class NixlConnectorWorker: use_mla=self.use_mla, ) self.backend_name = backend.get_name() - attn_backend = backend_name_to_enum(self.backend_name) - self._use_flashinfer = attn_backend == _Backend.FLASHINFER - self._use_pallas = attn_backend == _Backend.PALLAS + attn_backend = AttentionBackendEnum[self.backend_name] + self._use_flashinfer = attn_backend == AttentionBackendEnum.FLASHINFER + self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS self.kv_cache_layout = get_kv_cache_layout() self.host_buffer_kv_cache_layout = self.kv_cache_layout logger.debug("Detected attention backend %s", self.backend_name) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b12b7082af627..d3913553320fd 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -32,7 +32,7 @@ from pydantic.fields import FieldInfo from typing_extensions import TypeIs, deprecated import vllm.envs as envs -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ( CacheConfig, CompilationConfig, @@ -462,7 +462,7 @@ class EngineArgs: MultiModalConfig.mm_shm_cache_max_object_size_mb ) mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode - mm_encoder_attn_backend: _Backend | str | None = ( + mm_encoder_attn_backend: AttentionBackendEnum | str | None = ( MultiModalConfig.mm_encoder_attn_backend ) io_processor_plugin: str | None = None diff --git a/vllm/envs.py b/vllm/envs.py index 52178e5f52500..52a9671bc46e2 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -626,14 +626,14 @@ environment_variables: dict[str, Callable[[], Any]] = { # - "FLASH_ATTN_MLA": use FlashAttention for MLA # - "FLASHINFER_MLA": use FlashInfer for MLA # - "CUTLASS_MLA": use CUTLASS for MLA - # All possible options loaded dynamically from _Backend enum + # All possible options loaded dynamically from AttentionBackendEnum "VLLM_ATTENTION_BACKEND": env_with_choices( "VLLM_ATTENTION_BACKEND", None, lambda: list( __import__( - "vllm.attention.backends.registry", fromlist=["_Backend"] - )._Backend.__members__.keys() + "vllm.attention.backends.registry", fromlist=["AttentionBackendEnum"] + ).AttentionBackendEnum.__members__.keys() ), ), # If set, vllm will use flashinfer sampler diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py index 6d462ad8ae620..1b2bb60a17c16 100644 --- a/vllm/model_executor/models/dots_ocr.py +++ b/vllm/model_executor/models/dots_ocr.py @@ -9,7 +9,7 @@ import torch.nn.functional as F from torch.nn import LayerNorm from transformers.models.qwen2_vl import Qwen2VLProcessor -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import ( check_upstream_fa_availability, maybe_get_vit_flash_attn_backend, @@ -256,7 +256,7 @@ class DotsVisionAttention(nn.Module): quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -303,17 +303,17 @@ class DotsVisionAttention(nn.Module): ) ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"Unsupported vision attention backend: {self.attn_backend}" ) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def forward( @@ -361,7 +361,7 @@ class DotsVisionAttention(nn.Module): self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, ) - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: outputs = [] for i in range(1, len(cu_seqlens)): s = int(cu_seqlens[i - 1]) @@ -373,7 +373,7 @@ class DotsVisionAttention(nn.Module): out_i = out_i.permute(0, 2, 1, 3) outputs.append(out_i) context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0] - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -514,7 +514,7 @@ class DotsVisionBlock(nn.Module): quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() @@ -567,7 +567,7 @@ class DotsVisionTransformer(nn.Module): require_post_norm: bool | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() self.config = config @@ -582,10 +582,11 @@ class DotsVisionTransformer(nn.Module): dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype()) ): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = AttentionBackendEnum.FLASH_ATTN self.out_hidden_size = config.hidden_size # Keep blocks for compatibility with other vision towers num_layers = ( @@ -666,11 +667,11 @@ class DotsVisionTransformer(nn.Module): ) -> tuple[int | None, list[int] | None]: max_seqlen, seqlens = None, None if ( - self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA + self.attn_backend == AttentionBackendEnum.FLASH_ATTN + or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() return max_seqlen, seqlens diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index f287cff12086b..97182a25f82b8 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -36,7 +36,7 @@ import torch.nn.functional as F from einops import rearrange, repeat from transformers import BatchFeature, PretrainedConfig -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import ( check_upstream_fa_availability, maybe_get_vit_flash_attn_backend, @@ -164,7 +164,7 @@ class Ernie4_5_VisionAttention(nn.Module): projection_size: int, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() # Per attention head and per partition values. @@ -211,17 +211,17 @@ class Ernie4_5_VisionAttention(nn.Module): ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"Ernie45-VL does not support {self.attn_backend} backend now." ) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -291,7 +291,7 @@ class Ernie4_5_VisionAttention(nn.Module): context_layer = rearrange( output, "(b s) h d -> s b (h d)", b=batch_size ).contiguous() - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. outputs = [] for i in range(1, len(cu_seqlens)): @@ -310,7 +310,7 @@ class Ernie4_5_VisionAttention(nn.Module): context_layer = rearrange( context_layer, "b s h d -> s b (h d)" ).contiguous() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -370,7 +370,7 @@ class Ernie4_5_VisionBlock(nn.Module): norm_layer: Callable[[int], nn.Module] | None = None, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -463,7 +463,7 @@ class Ernie4_5_VisionTransformer(nn.Module): norm_eps: float = 1e-6, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() patch_size = vision_config.patch_size @@ -515,10 +515,11 @@ class Ernie4_5_VisionTransformer(nn.Module): dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype()) ): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = AttentionBackendEnum.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -565,11 +566,11 @@ class Ernie4_5_VisionTransformer(nn.Module): ) -> tuple[int | None, list[int] | None]: max_seqlen, seqlens = None, None if ( - self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA + self.attn_backend == AttentionBackendEnum.FLASH_ATTN + or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() return max_seqlen, seqlens diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index b9cd3545ec453..776527fdd973a 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -46,7 +46,7 @@ from transformers.models.glm4v.image_processing_glm4v import ( from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor from transformers.video_utils import VideoMetadata -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import ( check_upstream_fa_availability, maybe_get_vit_flash_attn_backend, @@ -252,7 +252,7 @@ class Glm4vVisionAttention(nn.Module): quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() # Per attention head and per partition values. @@ -306,18 +306,18 @@ class Glm4vVisionAttention(nn.Module): ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"GLM-4V does not support {self.attn_backend} backend now." ) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -377,7 +377,7 @@ class Glm4vVisionAttention(nn.Module): context_layer = rearrange( output, "(b s) h d -> s b (h d)", b=batch_size ).contiguous() - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. outputs = [] for i in range(1, len(cu_seqlens)): @@ -396,7 +396,7 @@ class Glm4vVisionAttention(nn.Module): context_layer = rearrange( context_layer, "b s h d -> s b (h d)" ).contiguous() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -425,7 +425,7 @@ class Glm4vVisionBlock(nn.Module): quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() if norm_layer is None: @@ -703,7 +703,7 @@ class Glm4vVisionTransformer(nn.Module): quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -772,10 +772,11 @@ class Glm4vVisionTransformer(nn.Module): dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype()) ): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = AttentionBackendEnum.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -824,8 +825,8 @@ class Glm4vVisionTransformer(nn.Module): max_seqlen, seqlens = None, None seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() if ( - self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA + self.attn_backend == AttentionBackendEnum.FLASH_ATTN + or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() return max_seqlen, seqlens diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 42f16ad9f3b3a..80d7e6c5b0cd0 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -16,7 +16,7 @@ from transformers.feature_extraction_utils import BatchFeature from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.utils import torch_int -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import ( maybe_get_vit_flash_attn_backend, ) @@ -360,7 +360,7 @@ class KeyeSiglipAttention(nn.Module): config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -414,17 +414,17 @@ class KeyeSiglipAttention(nn.Module): ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"Keye-VL does not support {self.attn_backend} backend now." ) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def forward( @@ -489,7 +489,7 @@ class KeyeSiglipAttention(nn.Module): softmax_scale=self.scale, ) context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size) - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -536,7 +536,7 @@ class KeyeSiglipEncoderLayer(nn.Module): config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.embed_dim = config.hidden_size @@ -590,7 +590,7 @@ class KeyeSiglipEncoder(nn.Module): config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -685,7 +685,7 @@ class KeyeSiglipVisionTransformer(nn.Module): config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -768,7 +768,7 @@ class KeyeSiglipVisionModel(nn.Module): config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() diff --git a/vllm/model_executor/models/ovis2_5.py b/vllm/model_executor/models/ovis2_5.py index f6461ae9a412e..9a4d69dea0968 100644 --- a/vllm/model_executor/models/ovis2_5.py +++ b/vllm/model_executor/models/ovis2_5.py @@ -10,7 +10,7 @@ import torch import torch.nn as nn from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.linear import ReplicatedLinear @@ -106,7 +106,7 @@ class VisualTokenizer(torch.nn.Module): quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -135,7 +135,7 @@ class VisualTokenizer(torch.nn.Module): quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): model_type = config.model_type if model_type == "siglip2_navit": diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py index 12ae15699e7d2..86d7d1c11ffe8 100644 --- a/vllm/model_executor/models/paddleocr_vl.py +++ b/vllm/model_executor/models/paddleocr_vl.py @@ -31,7 +31,7 @@ from transformers.modeling_outputs import ( ) from transformers.utils import torch_int -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import ( check_upstream_fa_availability, maybe_get_vit_flash_attn_backend, @@ -580,8 +580,8 @@ class SiglipAttention(nn.Module): projection_size: int, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend: _Backend = _Backend.TORCH_SDPA, - attn_backend_override: _Backend | None = None, + attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, + attn_backend_override: AttentionBackendEnum | None = None, use_upstream_fa: bool = False, ) -> None: super().__init__() @@ -621,8 +621,8 @@ class SiglipAttention(nn.Module): ) ) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -680,10 +680,10 @@ class SiglipAttention(nn.Module): cu_seqlens, max_seqlen, batch_size, - self.attn_backend == _Backend.ROCM_AITER_FA, + self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA, self.use_upstream_fa, ) - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: outputs = [] for i in range(1, len(cu_seqlens)): start_idx = cu_seqlens[i - 1] @@ -702,7 +702,7 @@ class SiglipAttention(nn.Module): context_layer = rearrange( context_layer, "b s h d -> s b (h d)" ).contiguous() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: if seqlens is None: raise ValueError("xFormers attention backend requires seqlens tensor.") context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens) @@ -786,8 +786,8 @@ class SiglipEncoderLayer(nn.Module): quant_config: QuantizationConfig | None = None, prefix: str = "", *, - attn_backend: _Backend = _Backend.TORCH_SDPA, - attn_backend_override: _Backend | None = None, + attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, + attn_backend_override: AttentionBackendEnum | None = None, use_upstream_fa: bool = False, ): super().__init__() @@ -847,7 +847,7 @@ class SiglipEncoder(nn.Module): config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -861,16 +861,16 @@ class SiglipEncoder(nn.Module): ) self.use_upstream_fa = False if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } and check_upstream_fa_availability(torch.get_default_dtype()): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = AttentionBackendEnum.FLASH_ATTN self.use_upstream_fa = True if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"PaddleOCR-VL does not support {self.attn_backend} backend now." @@ -943,9 +943,12 @@ class SiglipEncoder(nn.Module): max_seqlen = None seqlens = None - if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: + if self.attn_backend in { + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, + }: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: seqlens = cu_seqlens[1:] - cu_seqlens[:-1] hidden_states = inputs_embeds @@ -966,7 +969,7 @@ class SiglipVisionTransformer(nn.Module): config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -1016,7 +1019,7 @@ class SiglipVisionModel(nn.Module): config, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 48834ba699e4c..3292cf8220ffe 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -42,7 +42,7 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( Qwen2_5_VLVisionConfig, ) -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import maybe_get_vit_flash_attn_backend from vllm.attention.ops.vit_attn_wrappers import ( vit_flash_attn_wrapper, @@ -315,9 +315,9 @@ class Qwen2_5_VisionAttention(nn.Module): quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend: _Backend = _Backend.TORCH_SDPA, + attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, use_upstream_fa: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() # Per attention head and per partition values. @@ -364,13 +364,16 @@ class Qwen2_5_VisionAttention(nn.Module): # On ROCm with FLASH_ATTN backend, upstream flash_attn is used from vllm.platforms import current_platform - if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN: + if ( + current_platform.is_rocm() + and self.attn_backend == AttentionBackendEnum.FLASH_ATTN + ): self.use_upstream_fa = True if current_platform.is_xpu(): self.use_upstream_fa = False self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -431,10 +434,10 @@ class Qwen2_5_VisionAttention(nn.Module): cu_seqlens, max_seqlen, batch_size, - self.attn_backend == _Backend.ROCM_AITER_FA, + self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA, self.use_upstream_fa, ) - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. from vllm.platforms import current_platform @@ -450,7 +453,7 @@ class Qwen2_5_VisionAttention(nn.Module): v, cu_seqlens, ) - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens) output, _ = self.proj(context_layer) @@ -478,9 +481,9 @@ class Qwen2_5_VisionBlock(nn.Module): quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend: _Backend = _Backend.TORCH_SDPA, + attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, use_upstream_fa: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() if norm_layer is None: @@ -656,7 +659,7 @@ class Qwen2_5_VisionTransformer(nn.Module): quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -708,10 +711,10 @@ class Qwen2_5_VisionTransformer(nn.Module): ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"Qwen2.5-VL does not support {self.attn_backend} backend now." @@ -850,9 +853,12 @@ class Qwen2_5_VisionTransformer(nn.Module): ) -> tuple[torch.Tensor, torch.Tensor]: max_seqlen = torch.zeros([], device=cu_seqlens.device) seqlens = torch.zeros(1, device=cu_seqlens.device) - if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: + if self.attn_backend in { + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, + }: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: seqlens = cu_seqlens[1:] - cu_seqlens[:-1] return max_seqlen, seqlens diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index b3999e6c934e3..61057fa145f47 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -43,7 +43,7 @@ from transformers.models.qwen2_vl.configuration_qwen2_vl import ( from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import ( check_upstream_fa_availability, maybe_get_vit_flash_attn_backend, @@ -329,7 +329,7 @@ class Qwen2VisionAttention(nn.Module): quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() # Per attention head and per partition values. @@ -378,18 +378,18 @@ class Qwen2VisionAttention(nn.Module): ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"Qwen2-VL does not support {self.attn_backend} backend now." ) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -460,7 +460,7 @@ class Qwen2VisionAttention(nn.Module): context_layer = rearrange( output, "(b s) h d -> s b (h d)", b=batch_size ).contiguous() - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. from vllm.platforms import current_platform @@ -485,7 +485,7 @@ class Qwen2VisionAttention(nn.Module): context_layer = rearrange( context_layer, "b s h d -> s b (h d)" ).contiguous() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -515,7 +515,7 @@ class Qwen2VisionBlock(nn.Module): quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() if norm_layer is None: @@ -679,7 +679,7 @@ class Qwen2VisionTransformer(nn.Module): quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -739,10 +739,11 @@ class Qwen2VisionTransformer(nn.Module): dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype()) ): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = AttentionBackendEnum.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -789,9 +790,12 @@ class Qwen2VisionTransformer(nn.Module): self, cu_seqlens: torch.Tensor ) -> tuple[int | None, list[int] | None]: max_seqlen, seqlens = None, None - if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: + if self.attn_backend in { + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, + }: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() return max_seqlen, seqlens diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index da489a812f55d..468b25220154b 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -47,7 +47,7 @@ from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import ( ) from transformers.models.whisper import WhisperFeatureExtractor -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import check_upstream_fa_availability from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig @@ -301,7 +301,7 @@ class Qwen3Omni_VisionTransformer(nn.Module): norm_eps: float = 1e-6, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() self.hidden_size = vision_config.hidden_size @@ -377,10 +377,11 @@ class Qwen3Omni_VisionTransformer(nn.Module): dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype()) ): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = AttentionBackendEnum.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -490,9 +491,9 @@ class Qwen3Omni_VisionTransformer(nn.Module): ) -> tuple[torch.Tensor, torch.Tensor]: max_seqlen = torch.zeros([], device=cu_seqlens.device) seqlens = torch.zeros(1, device=cu_seqlens.device) - if self.attn_backend == _Backend.FLASH_ATTN: + if self.attn_backend == AttentionBackendEnum.FLASH_ATTN: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: seqlens = cu_seqlens[1:] - cu_seqlens[:-1] return max_seqlen, seqlens diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 1cd34bf54a35f..1be35cde7dbdc 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -49,7 +49,7 @@ from transformers.models.qwen3_vl.video_processing_qwen3_vl import ( ) from transformers.video_utils import VideoMetadata -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import check_upstream_fa_availability from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig @@ -198,7 +198,7 @@ class Qwen3_VisionBlock(nn.Module): quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend: _Backend = _Backend.TORCH_SDPA, + attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, use_upstream_fa: bool = False, ) -> None: super().__init__() @@ -306,7 +306,7 @@ class Qwen3_VisionTransformer(nn.Module): quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() self.hidden_size = vision_config.hidden_size @@ -372,18 +372,18 @@ class Qwen3_VisionTransformer(nn.Module): ) use_upstream_fa = False if ( - self.attn_backend != _Backend.FLASH_ATTN - and self.attn_backend != _Backend.ROCM_AITER_FA + self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and self.attn_backend != AttentionBackendEnum.ROCM_AITER_FA and check_upstream_fa_availability(torch.get_default_dtype()) ): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = AttentionBackendEnum.FLASH_ATTN use_upstream_fa = True if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"Qwen3-VL does not support {self.attn_backend} backend now." @@ -510,11 +510,11 @@ class Qwen3_VisionTransformer(nn.Module): max_seqlen = torch.zeros([], device=cu_seqlens.device) seqlens = torch.zeros(1, device=cu_seqlens.device) if ( - self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA + self.attn_backend == AttentionBackendEnum.FLASH_ATTN + or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: seqlens = cu_seqlens[1:] - cu_seqlens[:-1] return max_seqlen, seqlens diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py index bab5c1d82deda..c20bcd975ca30 100644 --- a/vllm/model_executor/models/siglip2navit.py +++ b/vllm/model_executor/models/siglip2navit.py @@ -12,7 +12,7 @@ from torch.nn import functional as F from transformers import Siglip2VisionConfig from transformers.configuration_utils import PretrainedConfig -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import maybe_get_vit_flash_attn_backend from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn @@ -208,7 +208,7 @@ class Siglip2Attention(nn.Module): quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -264,14 +264,14 @@ class Siglip2Attention(nn.Module): ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.ROCM_AITER_FA, }: - self.attn_backend = _Backend.TORCH_SDPA + self.attn_backend = AttentionBackendEnum.TORCH_SDPA self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def forward( @@ -308,7 +308,7 @@ class Siglip2Attention(nn.Module): attn_output = self.flash_attn_varlen_func( queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen ).reshape(seq_length, -1) - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. batch_size = cu_seqlens.shape[0] - 1 outputs = [] @@ -376,7 +376,7 @@ class Siglip2EncoderLayer(nn.Module): quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.embed_dim = config.hidden_size @@ -440,7 +440,7 @@ class Siglip2Encoder(nn.Module): quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -626,7 +626,7 @@ class Siglip2VisionTransformer(nn.Module): quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -667,7 +667,7 @@ class Siglip2NavitModel(torch.nn.Module): quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index 9f94387c700d6..0e814e5c86ad4 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -10,7 +10,7 @@ from typing import Final, Generic, Literal, Protocol, TypeAlias, TypeVar import torch from transformers import PretrainedConfig -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig from vllm.distributed import ( get_tensor_model_parallel_rank, @@ -83,8 +83,8 @@ def get_vit_attn_backend( head_size: int, dtype: torch.dtype, *, - attn_backend_override: _Backend | None = None, -) -> _Backend: + attn_backend_override: AttentionBackendEnum | None = None, +) -> AttentionBackendEnum: """ Get the available attention backend for Vision Transformer. """ @@ -94,7 +94,7 @@ def get_vit_attn_backend( # Lazy import to avoid circular dependency from vllm.attention.selector import get_env_variable_attn_backend - selected_backend: _Backend | None = get_env_variable_attn_backend() + selected_backend: AttentionBackendEnum | None = get_env_variable_attn_backend() if selected_backend is not None: return selected_backend diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index ee904535ffe8d..3dec6da897025 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -23,10 +23,10 @@ from .interface import CpuArchEnum, Platform, PlatformEnum logger = init_logger(__name__) if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig else: - _Backend = None + AttentionBackendEnum = None VllmConfig = None @@ -127,7 +127,7 @@ class CpuPlatform(Platform): @classmethod def get_attn_backend_cls( cls, - selected_backend: "_Backend", + selected_backend: "AttentionBackendEnum", head_size: int, dtype: torch.dtype, kv_cache_dtype: str | None, @@ -137,9 +137,9 @@ class CpuPlatform(Platform): has_sink: bool, use_sparse: bool, ) -> str: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum - if selected_backend and selected_backend != _Backend.TORCH_SDPA: + if selected_backend and selected_backend != AttentionBackendEnum.TORCH_SDPA: logger.info("Cannot use %s backend on CPU.", selected_backend) if use_mla: raise NotImplementedError("MLA is not supported on CPU.") @@ -148,7 +148,7 @@ class CpuPlatform(Platform): logger.info("Using Torch SDPA backend.") if not use_v1: raise ValueError("CPU backend only supports V1.") - return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend" + return AttentionBackendEnum.TORCH_SDPA.get_path() @classmethod def get_device_total_memory(cls, device_id: int = 0) -> int: diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 32734c3aba5ef..43daf5e75b665 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -22,10 +22,13 @@ from vllm.utils.torch_utils import cuda_device_count_stateless from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig + from vllm.config.cache import CacheDType else: - _Backend = None + AttentionBackendEnum = None + VllmConfig = None + CacheDType = None logger = init_logger(__name__) @@ -39,6 +42,49 @@ pynvml = import_pynvml() torch.backends.cuda.enable_cudnn_sdp(False) +@cache +def _get_backend_priorities( + use_mla: bool, + device_capability: DeviceCapability, +) -> list[AttentionBackendEnum]: + """Get backend priorities with lazy import to avoid circular dependency.""" + from vllm.attention.backends.registry import AttentionBackendEnum + + if use_mla: + if device_capability.major == 10: + return [ + AttentionBackendEnum.CUTLASS_MLA, + AttentionBackendEnum.FLASHINFER_MLA, + AttentionBackendEnum.FLASHMLA, + AttentionBackendEnum.FLASH_ATTN_MLA, + AttentionBackendEnum.TRITON_MLA, + AttentionBackendEnum.FLASHMLA_SPARSE, + ] + else: + return [ + AttentionBackendEnum.FLASHMLA, + AttentionBackendEnum.FLASH_ATTN_MLA, + AttentionBackendEnum.FLASHINFER_MLA, + AttentionBackendEnum.TRITON_MLA, + AttentionBackendEnum.FLASHMLA_SPARSE, + ] + else: + if device_capability.major == 10: + return [ + AttentionBackendEnum.FLASHINFER, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TRITON_ATTN, + AttentionBackendEnum.FLEX_ATTENTION, + ] + else: + return [ + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.FLASHINFER, + AttentionBackendEnum.TRITON_ATTN, + AttentionBackendEnum.FLEX_ATTENTION, + ] + + def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]: @wraps(fn) def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: @@ -216,217 +262,171 @@ class CudaPlatformBase(Platform): return torch.cuda.max_memory_allocated(device) @classmethod - def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": - from vllm.attention.backends.registry import _Backend + def get_vit_attn_backend( + cls, head_size: int, dtype: torch.dtype + ) -> "AttentionBackendEnum": + from vllm.attention.backends.registry import AttentionBackendEnum # For Blackwell GPUs, force TORCH_SDPA for now. # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501 if cls.has_device_capability(100): - return _Backend.TORCH_SDPA + return AttentionBackendEnum.TORCH_SDPA if dtype not in (torch.float16, torch.bfloat16): - return _Backend.XFORMERS + return AttentionBackendEnum.XFORMERS if cls.has_device_capability(80): - FLASH_ATTN_V1 = ( - "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 - ) - from vllm.attention.selector import is_attn_backend_supported - - is_default_fa_supported = is_attn_backend_supported( - FLASH_ATTN_V1, head_size, dtype, allow_import_error=False - ) - if is_default_fa_supported: - return _Backend.FLASH_ATTN + backend_class = AttentionBackendEnum.FLASH_ATTN.get_class() + if backend_class.supports_head_size( + head_size + ) and backend_class.supports_dtype(dtype): + return AttentionBackendEnum.FLASH_ATTN else: - # Fallback to XFORMERS - return _Backend.XFORMERS + return AttentionBackendEnum.XFORMERS else: # Fallback for Volta/Turing GPUs or FA not supported - return _Backend.XFORMERS + return AttentionBackendEnum.XFORMERS @classmethod - def get_attn_backend_cls( + def get_valid_backends( cls, - selected_backend, head_size, dtype, kv_cache_dtype, block_size, - use_v1, use_mla, has_sink, use_sparse, + device_capability, + ) -> tuple[ + list[tuple["AttentionBackendEnum", int]], + dict["AttentionBackendEnum", list[str]], + ]: + valid_backends_priorities = [] + invalid_reasons = {} + + backend_priorities = _get_backend_priorities(use_mla, device_capability) + for priority, backend in enumerate(backend_priorities): + try: + backend_class = backend.get_class() + invalid_reasons_i = backend_class.validate_configuration( + head_size, + dtype, + kv_cache_dtype, + block_size, + use_mla, + has_sink, + use_sparse, + device_capability, + ) + except ImportError: + invalid_reasons_i = ["ImportError"] + if invalid_reasons_i: + invalid_reasons[backend] = invalid_reasons_i + else: + valid_backends_priorities.append((backend, priority)) + + return valid_backends_priorities, invalid_reasons + + @classmethod + def get_attn_backend_cls( + cls, + selected_backend: "AttentionBackendEnum", + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: "CacheDType | None", + block_size: int | None, + use_v1: bool, + use_mla: bool, + has_sink: bool, + use_sparse: bool, ) -> str: - from vllm.attention.backends.registry import _Backend + if not use_v1: + raise RuntimeError( + "V0 attention backends have been removed. Set VLLM_USE_V1=1 " + "to select a supported backend." + ) - if use_mla: - # explicitly reject non-MLA backends when MLA is enabled to avoid - # silently selecting an incompatible backend (e.g., FLASHINFER). - if selected_backend in { - _Backend.FLASHINFER, - _Backend.FLASH_ATTN, - _Backend.TRITON_ATTN, - _Backend.TREE_ATTN, - _Backend.XFORMERS, - }: + device_capability = cls.get_device_capability() + assert device_capability is not None + + # First try checking just the selected backend, if there is one. + if selected_backend is not None: + try: + backend_class = selected_backend.get_class() + invalid_reasons = backend_class.validate_configuration( + head_size, + dtype, + kv_cache_dtype, + None, + use_mla, + has_sink, + use_sparse, + device_capability, + ) + except ImportError: + invalid_reasons = ["ImportError"] + if invalid_reasons: raise ValueError( - f"Attention backend {selected_backend} incompatible with MLA. " - "Please use one of the MLA backends: FLASHINFER_MLA, CUTLASS_MLA, " - "FLASHMLA, FLASH_ATTN_MLA, or TRITON_MLA. Alternatively, set " - "VLLM_MLA_DISABLE=1 to disable MLA for this model." + f"Selected backend {selected_backend} is not valid for " + f"this configuration. Reason: {invalid_reasons}" ) + else: + logger.info("Using %s backend.", selected_backend) + return selected_backend.get_path() - from vllm.attention.ops.flashmla import is_flashmla_dense_supported - from vllm.attention.utils.fa_utils import flash_attn_supports_mla - - if use_sparse: - logger.info_once("Using Sparse MLA backend.") - return ( - "vllm.v1.attention.backends.mla.flashmla_sparse." - "FlashMLASparseBackend" - ) - - use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or ( - selected_backend is None - and cls.is_device_capability(100) - and block_size % 128 == 0 - ) - use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or ( - selected_backend is None - and cls.is_device_capability(100) - and (block_size == 32 or block_size % 64 == 0) - ) - use_flashmla = selected_backend == _Backend.FLASHMLA or ( - selected_backend is None and is_flashmla_dense_supported()[0] - ) - use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or ( - selected_backend is None and flash_attn_supports_mla() - ) - use_triton = selected_backend == _Backend.TRITON_MLA or ( - selected_backend is None - ) - - if use_cutlassmla: - logger.info_once("Using Cutlass MLA backend.", scope="local") - return "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend" - if use_flashinfermla: - from vllm.v1.attention.backends.utils import set_kv_cache_layout - - set_kv_cache_layout("HND") - logger.info_once("Using FlashInfer MLA backend.") - return ( - "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend" - ) - if use_flashmla: - if block_size % 64 != 0: - logger.warning( - "FlashMLA backend is not supported for block size %d" - " (currently only supports block size 64).", - block_size, - ) - else: - logger.info_once("Using FlashMLA backend.") - return "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend" - if use_flashattn: - logger.info_once("Using FlashAttention MLA backend.") - return ( - "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend" - ) - if use_triton: - logger.info_once("Using Triton MLA backend.") - return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend" - - FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501 - FLEX_ATTENTION_V1 = ( - "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 + # No selected backend or the selected backend is invalid, + # so we try finding a valid backend. + valid_backends_priorities, invalid_reasons = cls.get_valid_backends( + head_size, + dtype, + kv_cache_dtype, + None, + use_mla, + has_sink, + use_sparse, + device_capability, ) - TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 - FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 - TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501 - XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501 + reasons_str = ( + "{" + + ", ".join( + f"{backend.name}: [{', '.join(reasons)}]" + for backend, reasons in invalid_reasons.items() + ) + + "}" + ) + config_str = ( + f"head_size: {head_size}, dtype: {dtype}, " + f"kv_cache_dtype: {kv_cache_dtype}, block_size: {block_size}, " + f"use_mla: {use_mla}, has_sink: {has_sink}, use_sparse: {use_sparse}" + ) + logger.debug_once( + f"Some attention backends are not valid for {cls.device_name} with " + f"{config_str}. Reasons: {reasons_str}." + ) + if len(valid_backends_priorities) == 0: + raise ValueError( + f"No valid attention backend found for {cls.device_name} " + f"with {config_str}. Reasons: {reasons_str}." + ) - use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith( - "fp8" + # We have found some valid backends. Select the one with the + # highest priority. + logger.info( + "Valid backends: %s", [b[0].name for b in valid_backends_priorities] + ) + sorted_indices = sorted( + range(len(valid_backends_priorities)), + key=lambda i: valid_backends_priorities[i][1], + ) + selected_index = sorted_indices[0] + selected_backend = valid_backends_priorities[selected_index][0] + logger.info( + "Using %s backend.", + selected_backend.name, ) - if selected_backend == _Backend.FLASHINFER: - logger.info_once("Using FlashInfer backend.") - if cls.has_device_capability(100): - from vllm.v1.attention.backends.utils import set_kv_cache_layout - - set_kv_cache_layout("HND") - return FLASHINFER_V1 - elif selected_backend == _Backend.FLEX_ATTENTION: - logger.info_once("Using FlexAttention backend.") - return FLEX_ATTENTION_V1 - elif selected_backend == _Backend.TRITON_ATTN: - logger.info_once("Using Triton backend.") - return TRITON_ATTN - elif selected_backend == _Backend.FLASH_ATTN: - logger.info_once("Using Flash Attention backend.") - return FLASH_ATTN_V1 - elif selected_backend == _Backend.TREE_ATTN: - logger.info_once("Using Tree Attention backend.") - return TREE_ATTN_V1 - elif selected_backend == _Backend.XFORMERS: - logger.info_once("Using XFormers backend.") - return XFORMERS_V1 - - from vllm.attention.selector import is_attn_backend_supported - - # Default backends for V1 engine - # Prefer FlashInfer for Blackwell GPUs if installed - if cls.is_device_capability(100): - if is_default_backend_supported := is_attn_backend_supported( - FLASHINFER_V1, head_size, dtype - ): - from vllm.v1.attention.backends.utils import set_kv_cache_layout - - logger.info_once( - "Using FlashInfer backend with HND KV cache layout on " - "V1 engine by default for Blackwell (SM 10.0) GPUs." - ) - set_kv_cache_layout("HND") - - return FLASHINFER_V1 - - if not is_default_backend_supported.can_import: - logger.warning_once( - "FlashInfer failed to import on Blackwell (SM 10.0) GPUs; " - "it is recommended to install FlashInfer for better " - "performance." - ) - - # FlashAttention is the default for SM 8.0+ GPUs - if cls.has_device_capability(80): - if (has_sink or use_fp8_kv_cache) and not cls.is_device_capability(90): - logger.info_once("Using Triton backend.") - return TRITON_ATTN - elif is_default_backend_supported := is_attn_backend_supported( - FLASH_ATTN_V1, head_size, dtype, allow_import_error=False - ): - logger.info_once("Using Flash Attention backend.") - return FLASH_ATTN_V1 - - # FlexAttention is the default for older GPUs - else: - logger.info_once("Using FlexAttention backend.") - return FLEX_ATTENTION_V1 - - assert not is_default_backend_supported - - use_flex_attention_reason = {} - if not is_default_backend_supported.head_size: - use_flex_attention_reason["head_size"] = head_size - if not is_default_backend_supported.dtype: - use_flex_attention_reason["dtype"] = dtype - - logger.info_once( - "Using FlexAttention backend for %s.", - ", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()), - ) - return FLEX_ATTENTION_V1 + return selected_backend.get_path() @classmethod def get_punica_wrapper(cls) -> str: diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 15e3b3a22bdee..4969bcf116a49 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -17,8 +17,9 @@ from vllm.logger import init_logger if TYPE_CHECKING: from torch.distributed import PrefixStore, ProcessGroup - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig + from vllm.config.cache import CacheDType from vllm.inputs import ProcessorInputs, PromptType from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams @@ -58,6 +59,31 @@ class DeviceCapability(NamedTuple): major: int minor: int + def __lt__(self, other: Any) -> bool: + if not isinstance(other, DeviceCapability): + return NotImplemented + return (self.major, self.minor) < (other.major, other.minor) + + def __le__(self, other: Any) -> bool: + if not isinstance(other, DeviceCapability): + return NotImplemented + return (self.major, self.minor) <= (other.major, other.minor) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, DeviceCapability): + return NotImplemented + return (self.major, self.minor) == (other.major, other.minor) + + def __ge__(self, other: Any) -> bool: + if not isinstance(other, DeviceCapability): + return NotImplemented + return (self.major, self.minor) >= (other.major, other.minor) + + def __gt__(self, other: Any) -> bool: + if not isinstance(other, DeviceCapability): + return NotImplemented + return (self.major, self.minor) > (other.major, other.minor) + def as_version_str(self) -> str: return f"{self.major}.{self.minor}" @@ -173,19 +199,21 @@ class Platform: import vllm._moe_C # noqa: F401 @classmethod - def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": - # Import _Backend here to avoid circular import. - from vllm.attention.backends.registry import _Backend + def get_vit_attn_backend( + cls, head_size: int, dtype: torch.dtype + ) -> "AttentionBackendEnum": + # Import AttentionBackendEnum here to avoid circular import. + from vllm.attention.backends.registry import AttentionBackendEnum - return _Backend.TORCH_SDPA + return AttentionBackendEnum.TORCH_SDPA @classmethod def get_attn_backend_cls( cls, - selected_backend: "_Backend", + selected_backend: "AttentionBackendEnum", head_size: int, dtype: torch.dtype, - kv_cache_dtype: str | None, + kv_cache_dtype: "CacheDType | None", block_size: int, use_v1: bool, use_mla: bool, diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index e6536a02a73dd..5318bdb8b36c0 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -14,10 +14,10 @@ from vllm.utils.torch_utils import cuda_device_count_stateless from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig else: - _Backend = None + AttentionBackendEnum = None logger = init_logger(__name__) @@ -204,21 +204,23 @@ class RocmPlatform(Platform): ] @classmethod - def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend: + def get_vit_attn_backend( + cls, head_size: int, dtype: torch.dtype + ) -> AttentionBackendEnum: from importlib.util import find_spec from vllm._aiter_ops import rocm_aiter_ops - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum if rocm_aiter_ops.is_mha_enabled(): # Note: AITER FA is only supported for Qwen-VL models. # TODO: Add support for other VL models in their model class. - return _Backend.ROCM_AITER_FA + return AttentionBackendEnum.ROCM_AITER_FA if on_gfx9() and find_spec("flash_attn") is not None: - return _Backend.FLASH_ATTN + return AttentionBackendEnum.FLASH_ATTN - return _Backend.TORCH_SDPA + return AttentionBackendEnum.TORCH_SDPA @classmethod def get_attn_backend_cls( @@ -234,7 +236,7 @@ class RocmPlatform(Platform): use_sparse, ) -> str: from vllm._aiter_ops import rocm_aiter_ops - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum if use_sparse: raise NotImplementedError("Sparse Attention is not supported on ROCm.") @@ -248,55 +250,52 @@ class RocmPlatform(Platform): if use_mla: if selected_backend is None: selected_backend = ( - _Backend.ROCM_AITER_MLA + AttentionBackendEnum.ROCM_AITER_MLA if rocm_aiter_ops.is_mla_enabled() or block_size == 1 - else _Backend.TRITON_MLA + else AttentionBackendEnum.TRITON_MLA ) - if selected_backend == _Backend.TRITON_MLA: + if selected_backend == AttentionBackendEnum.TRITON_MLA: if block_size != 1: logger.info_once("Using Triton MLA backend.") - return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend" + return AttentionBackendEnum.TRITON_MLA.get_path() raise ValueError( f" The selected backend, {selected_backend.name}," f"does not support block size {block_size}." ) - if selected_backend == _Backend.ROCM_AITER_MLA: + if selected_backend == AttentionBackendEnum.ROCM_AITER_MLA: logger.info("Using AITER MLA backend.") - return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501 + return AttentionBackendEnum.ROCM_AITER_MLA.get_path() raise ValueError( f" The selected backend, {selected_backend.name}," f"is not MLA type while requested for MLA backend." ) - if selected_backend == _Backend.FLEX_ATTENTION: + if selected_backend == AttentionBackendEnum.FLEX_ATTENTION: logger.info("Using FlexAttention backend.") return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" if ( rocm_aiter_ops.is_mha_enabled() - ) or selected_backend == _Backend.ROCM_AITER_FA: + ) or selected_backend == AttentionBackendEnum.ROCM_AITER_FA: logger.info("Using Aiter Flash Attention backend.") - return "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend" + return AttentionBackendEnum.ROCM_AITER_FA.get_path() if ( rocm_aiter_ops.is_triton_unified_attn_enabled() - ) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN: + ) or selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN: logger.info("Using Aiter Unified Attention backend.") - return ( - "vllm.v1.attention.backends." - "rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend" - ) + return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path() if ( envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION - or selected_backend == _Backend.ROCM_ATTN + or selected_backend == AttentionBackendEnum.ROCM_ATTN ): # rocm specific backend, with aiter and/or # triton prefix-prefill logger.info("Using Rocm Attention backend.") - return "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend" + return AttentionBackendEnum.ROCM_ATTN.get_path() # default case, using triton unified attention logger.info("Using Triton Attention backend.") - return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" + return AttentionBackendEnum.TRITON_ATTN.get_path() @classmethod def set_device(cls, device: torch.device) -> None: diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 1a4b67a1762f3..575a9892c2118 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -15,16 +15,15 @@ from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS from .interface import Platform, PlatformEnum if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend - from vllm.config import ModelConfig, VllmConfig + from vllm.attention.backends.registry import AttentionBackendEnum + from vllm.config import VllmConfig from vllm.config.cache import BlockSize from vllm.pooling_params import PoolingParams else: BlockSize = None - ModelConfig = None VllmConfig = None PoolingParams = None - _Backend = None + AttentionBackendEnum = None logger = init_logger(__name__) @@ -54,7 +53,7 @@ class TpuPlatform(Platform): @classmethod def get_attn_backend_cls( cls, - selected_backend: "_Backend", + selected_backend: "AttentionBackendEnum", head_size: int, dtype: torch.dtype, kv_cache_dtype: str | None, @@ -64,17 +63,17 @@ class TpuPlatform(Platform): has_sink, use_sparse, ) -> str: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum if use_sparse: raise NotImplementedError("Sparse Attention is not supported on TPU.") - if selected_backend != _Backend.PALLAS: + if selected_backend != AttentionBackendEnum.PALLAS: logger.info("Cannot use %s backend on TPU.", selected_backend) if not use_v1: raise ValueError("TPU backend only supports V1.") logger.info("Using Pallas V1 backend.") - return "vllm.v1.attention.backends.pallas.PallasAttentionBackend" + return AttentionBackendEnum.PALLAS.get_path() @classmethod def set_device(cls, device: torch.device) -> None: diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index e4ecd0c807dac..684d6d9a6b57d 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -14,12 +14,11 @@ from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend - from vllm.config import ModelConfig, VllmConfig + from vllm.attention.backends.registry import AttentionBackendEnum + from vllm.config import VllmConfig else: - ModelConfig = None VllmConfig = None - _Backend = None + AttentionBackendEnum = None logger = init_logger(__name__) @@ -44,7 +43,7 @@ class XPUPlatform(Platform): @classmethod def get_attn_backend_cls( cls, - selected_backend: "_Backend", + selected_backend: "AttentionBackendEnum", head_size: int, dtype: torch.dtype, kv_cache_dtype: str | None, @@ -62,18 +61,19 @@ class XPUPlatform(Platform): "only NHD layout is supported by XPU attention kernels." ) - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum if use_sparse: raise NotImplementedError("Sparse Attention is not supported on XPU.") - TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 - FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 - if selected_backend == _Backend.TRITON_ATTN: + use_v1 = envs.VLLM_USE_V1 + if not use_v1: + raise ValueError("XPU backend only supports V1.") + if selected_backend == AttentionBackendEnum.TRITON_ATTN: logger.info_once("Using Triton backend.") - return TRITON_ATTN - elif selected_backend == _Backend.FLASH_ATTN: + return AttentionBackendEnum.TRITON_ATTN.get_path() + elif selected_backend == AttentionBackendEnum.FLASH_ATTN: logger.info_once("Using Flash Attention backend.") - return FLASH_ATTN + return AttentionBackendEnum.FLASH_ATTN.get_path() elif selected_backend: raise ValueError( f"Invalid attention backend for {cls.device_name}, " @@ -81,7 +81,7 @@ class XPUPlatform(Platform): ) logger.info("Using Flash Attention backend.") - return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + return AttentionBackendEnum.FLASH_ATTN.get_path() @classmethod def set_device(cls, device: torch.device) -> None: @@ -113,10 +113,10 @@ class XPUPlatform(Platform): return device_props.total_memory @classmethod - def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend: - from vllm.attention.backends.registry import _Backend - - return _Backend.FLASH_ATTN + def get_vit_attn_backend( + cls, head_size: int, dtype: torch.dtype + ) -> AttentionBackendEnum: + return AttentionBackendEnum.FLASH_ATTN @classmethod def inference_mode(cls): diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 20d987fa2de3b..0057a7e22882b 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Optional +from typing import ClassVar, Optional import numpy as np import torch @@ -40,23 +40,16 @@ logger = init_logger(__name__) class TorchSDPABackend(AttentionBackend): accept_output_buffer: bool = False + supported_dtypes: ClassVar[list[torch.dtype]] = [ + torch.float16, + torch.bfloat16, + torch.float32, + ] @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16, torch.float32] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: + def get_supported_head_sizes(cls) -> list[int]: attn_impl = _get_paged_attn_impl() - is_valid, supported_head_sizes = attn_impl.validate_head_size(head_size) - if not is_valid: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." - ) + return attn_impl.get_supported_head_sizes() @staticmethod def get_name() -> str: @@ -759,9 +752,8 @@ def _make_sliding_window_bias( class _PagedAttention: @staticmethod - def validate_head_size(head_size: int) -> tuple[bool, list[int]]: - SUPPORT_HS = [32, 64, 80, 96, 112, 128, 192, 256] - return head_size in SUPPORT_HS, SUPPORT_HS + def get_supported_head_sizes() -> list[int]: + return [32, 64, 80, 96, 112, 128, 192, 256] @staticmethod def get_kv_cache_shape( @@ -861,8 +853,8 @@ class _PagedAttention: class _IPEXPagedAttention(_PagedAttention): @staticmethod - def validate_head_size(head_size: int) -> tuple[bool, list[int]]: - return True, [] + def get_supported_head_sizes() -> list[int]: + return [] @staticmethod def split_kv_cache( diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 15bb2f4a40acb..9cec623814c9f 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -3,6 +3,7 @@ """Attention layer with FlashAttention.""" from dataclasses import dataclass +from typing import ClassVar import numpy as np import torch @@ -32,11 +33,13 @@ if is_flash_attn_varlen_func_available(): reshape_and_cache_flash, ) from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.config.cache import CacheDType from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) +from vllm.platforms.interface import DeviceCapability from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.utils import ( AttentionCGSupport, @@ -52,34 +55,12 @@ logger = init_logger(__name__) class FlashAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] - - @classmethod - def get_supported_head_sizes(cls) -> list[int]: - return [32, 64, 96, 128, 160, 192, 224, 256] - - @staticmethod - def get_supported_kernel_block_size() -> list[int | MultipleOf]: - # NOTE(tdoublep): while in principle, FA supports - # MultipleOf(16), these are the block sizes that do not - # suffer from the NaN propagation problem described here: - # https://github.com/Dao-AILab/flash-attention/issues/1974 - return [16, 32, 64] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." - ) + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + # NOTE(tdoublep): while in principle, FA supports + # MultipleOf(16), these are the block sizes that do not + # suffer from the NaN propagation problem described here: + # https://github.com/Dao-AILab/flash-attention/issues/1974 + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64] @staticmethod def get_name() -> str: @@ -125,6 +106,38 @@ class FlashAttentionBackend(AttentionBackend): else: raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @classmethod + def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool: + if kv_cache_dtype is None: + return True + if kv_cache_dtype.startswith("fp8"): + return flash_attn_supports_fp8() + return kv_cache_dtype in ["auto"] + + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability >= DeviceCapability(8, 0) + + @classmethod + def supports_combination( + cls, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: CacheDType | None, + block_size: int, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + device_capability: DeviceCapability, + ) -> str | None: + if has_sink and device_capability < DeviceCapability(9, 0): + return "sink not supported on compute capability < 9.0" + return None + @dataclass class FlashAttentionMetadata: @@ -481,8 +494,6 @@ class FlashAttentionImpl(AttentionImpl): self.num_queries_per_kv = self.num_heads // self.num_kv_heads - FlashAttentionBackend.validate_head_size(head_size) - self.attn_type = attn_type self.vllm_flash_attn_version = get_flash_attn_version() # Cache the batch invariant result for use in forward passes diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 683725b95819f..07a0ab41a9e05 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -23,6 +23,7 @@ from vllm.attention.backends.abstract import ( MultipleOf, ) from vllm.config import CUDAGraphMode, VllmConfig +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, @@ -33,6 +34,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( kNvfp4Quant, ) from vllm.platforms import current_platform +from vllm.platforms.interface import DeviceCapability from vllm.triton_utils import tl, triton from vllm.utils.flashinfer import ( can_use_trtllm_attention, @@ -45,6 +47,7 @@ from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + KVCacheLayoutType, get_kv_cache_layout, get_per_layer_parameters, infer_global_hyperparameters, @@ -158,34 +161,17 @@ def trtllm_prefill_attn_kvfp8_dequant( class FlashInferBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] - - @classmethod - def get_supported_head_sizes(cls) -> list[int]: - # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 - return [64, 128, 256] - - @staticmethod - def get_supported_kernel_block_size() -> list[int | MultipleOf]: - # Note: Not sure for all platforms, - # but on Blackwell, only support a page size of - # 16, 32, 64 - return [16, 32, 64] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." - ) + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + # Note: Not sure for all platforms, + # but on Blackwell, only support a page size of + # 16, 32, 64 + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "fp8", + "fp8_e4m3", + "fp8_e5m2", + ] @staticmethod def get_name() -> str: @@ -231,6 +217,26 @@ class FlashInferBackend(AttentionBackend): else: raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 + return [64, 128, 256] + + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability >= DeviceCapability(7, 5) and capability <= DeviceCapability( + 12, 1 + ) + + @classmethod + def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None: + from vllm.platforms import current_platform + + capability = current_platform.get_device_capability() + if capability is not None and capability.major == 10: + return "HND" + return None + @dataclass class FlashInferMetadata: @@ -328,7 +334,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ) self.num_kv_heads = self.kv_cache_spec.num_kv_heads self.head_dim = self.kv_cache_spec.head_size - FlashInferBackend.validate_head_size(self.head_dim) self.page_size = self.kv_cache_spec.block_size self.cache_dtype = self.cache_config.cache_dtype diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index 9af63831cecba..e53cd0d8af4f2 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -4,6 +4,7 @@ import math from dataclasses import dataclass +from typing import ClassVar import torch import torch._dynamo.decorators @@ -24,6 +25,7 @@ from vllm.attention.backends.abstract import ( is_quantized_kv_cache, ) from vllm.config import VllmConfig +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, @@ -71,14 +73,12 @@ def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int): class FlexAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16, torch.float32] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: - return # FlexAttention supports any head size + supported_dtypes: ClassVar[list[torch.dtype]] = [ + torch.float16, + torch.bfloat16, + torch.float32, + ] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"] @staticmethod def get_name() -> str: @@ -106,6 +106,10 @@ class FlexAttentionBackend(AttentionBackend): def use_cascade_attention(*args, **kwargs) -> bool: return False + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [] + # @torch.compile(fullgraph=True, mode="reduce-overhead") def physical_to_logical_mapping( @@ -720,7 +724,6 @@ class FlexAttentionImpl(AttentionImpl): if kv_sharing_target_layer_name is not None: raise NotImplementedError("FlexAttention does not support kv sharing yet.") - FlexAttentionBackend.validate_head_size(head_size) if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( "FlexAttention does not support quantized kv-cache. Yet" diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index e38f7bcfa44e1..b4cb5c200da38 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -308,25 +308,13 @@ class MLACommonBackend(AttentionBackend): ) -> tuple[int, ...]: return (num_blocks, block_size, head_size) - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] - @classmethod def get_supported_head_sizes(cls) -> list[int]: return [576] @classmethod - def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." - ) + def is_mla(cls) -> bool: + return True @dataclass @@ -425,8 +413,10 @@ class MLACommonMetadata(Generic[D]): ) = None def __post_init__(self): - if self.head_dim is not None: - MLACommonBackend.validate_head_size(self.head_dim) + if self.head_dim is not None and not MLACommonBackend.supports_head_size( + self.head_dim + ): + raise ValueError(f"Head dimension {self.head_dim} is not supported by MLA.") M = TypeVar("M", bound=MLACommonMetadata) diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index c35e238eac4c0..0a10ce74cd1d4 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -13,7 +13,9 @@ from vllm.attention.backends.abstract import ( MultipleOf, is_quantized_kv_cache, ) +from vllm.config.cache import CacheDType from vllm.logger import init_logger +from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, MLACommonImpl, @@ -33,6 +35,14 @@ class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]): class CutlassMLABackend(MLACommonBackend): + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [128] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "fp8", + "fp8_e4m3", + ] + @staticmethod def get_name() -> str: return "CUTLASS_MLA" @@ -45,9 +55,9 @@ class CutlassMLABackend(MLACommonBackend): def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]: return CutlassMLAMetadataBuilder - @staticmethod - def get_supported_kernel_block_size() -> list[int | MultipleOf]: - return [128] + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability.major == 10 class SM100Workspace: diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 79b89c7890a28..5662acbe32c29 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -10,6 +10,7 @@ from vllm import envs from vllm.attention.backends.abstract import ( AttentionLayer, AttentionType, + MultipleOf, is_quantized_kv_cache, ) from vllm.attention.utils.fa_utils import ( @@ -17,10 +18,12 @@ from vllm.attention.utils.fa_utils import ( get_flash_attn_version, ) from vllm.config import VllmConfig +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) +from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, MLACommonDecodeMetadata, @@ -37,6 +40,10 @@ logger = init_logger(__name__) class FlashAttnMLABackend(MLACommonBackend): + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"] + @staticmethod def get_name() -> str: return "FLASH_ATTN_MLA" @@ -49,6 +56,26 @@ class FlashAttnMLABackend(MLACommonBackend): def get_impl_cls() -> type["FlashAttnMLAImpl"]: return FlashAttnMLAImpl + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability.major == 9 + + @classmethod + def supports_combination( + cls, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: CacheDType | None, + block_size: int, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + device_capability: DeviceCapability, + ) -> str | None: + if not flash_attn_supports_mla(): + return "FlashAttention MLA not supported on this device" + return None + @dataclass class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata): diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index ebbcfd0eaa2fb..b0f514ba44513 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -6,8 +6,14 @@ from typing import ClassVar import torch from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla -from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf +from vllm.attention.backends.abstract import ( + AttentionLayer, + AttentionType, + MultipleOf, +) +from vllm.config.cache import CacheDType from vllm.logger import init_logger +from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, MLACommonImpl, @@ -15,7 +21,7 @@ from vllm.v1.attention.backends.mla.common import ( MLACommonMetadataBuilder, QueryLenSupport, ) -from vllm.v1.attention.backends.utils import AttentionCGSupport +from vllm.v1.attention.backends.utils import AttentionCGSupport, KVCacheLayoutType logger = init_logger(__name__) @@ -28,6 +34,14 @@ class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]): class FlashInferMLABackend(MLACommonBackend): + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [32, 64] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "fp8", + "fp8_e4m3", + ] + @staticmethod def get_name() -> str: return "FLASHINFER_MLA" @@ -41,8 +55,12 @@ class FlashInferMLABackend(MLACommonBackend): return FlashInferMLAMetadataBuilder @classmethod - def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]: - return [32, 64] + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability.major == 10 + + @classmethod + def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None": + return "HND" g_fi_workspace = torch.zeros( diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 708bb9d63839d..8f0364cd58def 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -13,10 +13,12 @@ from vllm.attention.ops.flashmla import ( is_flashmla_dense_supported, ) from vllm.config import VllmConfig +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) +from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, MLACommonDecodeMetadata, @@ -36,6 +38,14 @@ logger = init_logger(__name__) class FlashMLABackend(MLACommonBackend): + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "fp8", + "fp8_e4m3", + ] + @staticmethod def get_name() -> str: return "FLASHMLA" @@ -48,9 +58,30 @@ class FlashMLABackend(MLACommonBackend): def get_impl_cls() -> type["FlashMLAImpl"]: return FlashMLAImpl - @staticmethod - def get_supported_kernel_block_size() -> list[int | MultipleOf]: - return [64] + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability.major in [9, 10] + + @classmethod + def supports_combination( + cls, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: CacheDType | None, + block_size: int, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + device_capability: DeviceCapability, + ) -> str | None: + if use_sparse: + from vllm.attention.ops.flashmla import is_flashmla_sparse_supported + + return is_flashmla_sparse_supported()[1] + else: + from vllm.attention.ops.flashmla import is_flashmla_dense_supported + + return is_flashmla_dense_supported()[1] @dataclass diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index bf76549de1ce8..4794312eb96ef 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -10,6 +10,7 @@ from vllm import _custom_ops as ops from vllm.attention.backends.abstract import ( AttentionBackend, AttentionLayer, + MultipleOf, ) from vllm.attention.backends.utils import get_mla_dims from vllm.attention.ops.flashmla import ( @@ -18,8 +19,10 @@ from vllm.attention.ops.flashmla import ( get_mla_metadata, ) from vllm.config import VllmConfig +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.platforms.interface import DeviceCapability from vllm.triton_utils import tl, triton from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl @@ -51,6 +54,9 @@ structured as: class FlashMLASparseBackend(AttentionBackend): accept_output_buffer: bool = True + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "fp8_ds_mla"] @staticmethod def get_name() -> str: @@ -64,6 +70,22 @@ class FlashMLASparseBackend(AttentionBackend): def get_impl_cls() -> type["FlashMLASparseImpl"]: return FlashMLASparseImpl + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [576] + + @classmethod + def is_mla(cls) -> bool: + return True + + @classmethod + def is_sparse(cls) -> bool: + return True + + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability.major in [9, 10] + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -79,14 +101,6 @@ class FlashMLASparseBackend(AttentionBackend): else: return (num_blocks, block_size, head_size) - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.bfloat16] - - @classmethod - def get_supported_head_sizes(cls) -> list[int]: - return [576] - @dataclass class FlashMLASparseMetadata: diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index f3c5bb7328712..4f071145625fc 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -23,6 +23,8 @@ logger = init_logger(__name__) class DeepseekV32IndexerBackend(AttentionBackend): + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64] + @classmethod def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 128] @@ -46,10 +48,6 @@ class DeepseekV32IndexerBackend(AttentionBackend): def get_kv_cache_stride_order() -> tuple[int, ...]: return (0, 1, 2) - @classmethod - def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]: - return [64] - @dataclass class DeepseekV32IndexerPrefillChunkMetadata: diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 781f77e96319a..0149639e8c0b3 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import ClassVar import torch @@ -12,11 +13,13 @@ from vllm.attention.backends.abstract import ( ) from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_flash_attention import triton_attention +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) from vllm.platforms import current_platform +from vllm.platforms.interface import DeviceCapability from vllm.triton_utils import HAS_TRITON from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, @@ -28,6 +31,9 @@ logger = init_logger(__name__) class TritonMLABackend(MLACommonBackend): + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"] + @staticmethod def get_name() -> str: return "TRITON_MLA" @@ -36,6 +42,10 @@ class TritonMLABackend(MLACommonBackend): def get_impl_cls() -> type["TritonMLAImpl"]: return TritonMLAImpl + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return True + class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): can_return_lse_for_decode: bool = True diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index e8d3758a6395a..81991244f5d90 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -3,6 +3,7 @@ """Attention layer with AiterFlashAttention.""" from dataclasses import dataclass +from typing import ClassVar import torch @@ -445,31 +446,13 @@ class AiterFlashAttentionMetadataBuilder( class AiterFlashAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)] @classmethod def get_supported_head_sizes(cls) -> list[int]: return [64, 128, 256] - @staticmethod - def get_supported_kernel_block_size() -> list[int | MultipleOf]: - return [MultipleOf(16)] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." - ) - @staticmethod def get_name() -> str: return "FLASH_ATTN" @@ -531,8 +514,6 @@ class AiterFlashAttentionImpl(AttentionImpl): assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - AiterFlashAttentionBackend.validate_head_size(head_size) - if attn_type != AttentionType.DECODER: raise NotImplementedError( "Encoder self-attention and " diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 57ba4dc78d9fd..1d2c70f65d0f5 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -152,10 +152,7 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat class RocmAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] @classmethod def get_supported_head_sizes(cls) -> list[int]: @@ -163,12 +160,11 @@ class RocmAttentionBackend(AttentionBackend): @classmethod def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: + if not cls.supports_head_size(head_size): attn_type = cls.__name__.removesuffix("Backend") raise ValueError( f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " + f"Supported head sizes are: {cls.get_supported_head_sizes()}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " "FlexAttention backend which supports all head sizes." ) diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index 0c0222d6152fb..1bf38ed225a4c 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -4,7 +4,7 @@ import ast from dataclasses import dataclass -from typing import Optional +from typing import ClassVar, Optional import torch @@ -30,31 +30,13 @@ logger = init_logger(__name__) class TreeAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)] @classmethod def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] - @staticmethod - def get_supported_kernel_block_size() -> list[int | MultipleOf]: - return [MultipleOf(16)] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." - ) - @staticmethod def get_name() -> str: return "TREE_ATTN" @@ -331,8 +313,6 @@ class TreeAttentionImpl(AttentionImpl): else: self.sliding_window = (sliding_window - 1, 0) - TreeAttentionBackend.validate_head_size(head_size) - if attn_type != AttentionType.DECODER: raise NotImplementedError( "Encoder self-attention and " diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 0590a87bf8e5f..37c0ae61e65d0 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -18,12 +18,14 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import ( ) from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8StaticTensorSym, ) from vllm.platforms import current_platform +from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, @@ -147,25 +149,18 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet class TritonAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16, torch.float32] - - @staticmethod - def get_supported_kernel_block_size() -> list[int | MultipleOf]: - return [MultipleOf(16)] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: - # Triton Attention supports any head size above 32 - if head_size < 32: - raise ValueError( - f"Head size {head_size} is not supported by TritonAttention." - f"Head sizes need to be larger or equal 32 for this backend. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." - ) + supported_dtypes: ClassVar[list[torch.dtype]] = [ + torch.float16, + torch.bfloat16, + torch.float32, + ] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "fp8", + "fp8_e4m3", + "fp8_e5m2", + ] @staticmethod def get_name() -> str: @@ -195,6 +190,18 @@ class TritonAttentionBackend(AttentionBackend): def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]: return TritonAttentionMetadataBuilder + @classmethod + def supports_head_size(cls, head_size: int) -> bool: + return head_size >= 32 + + @classmethod + def supports_sink(cls) -> bool: + return True + + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return True + class TritonAttentionImpl(AttentionImpl): def fused_output_quant_supported(self, quant_key: QuantKey): @@ -237,8 +244,6 @@ class TritonAttentionImpl(AttentionImpl): self.num_queries_per_kv = self.num_heads // self.num_kv_heads - TritonAttentionBackend.validate_head_size(head_size) - if attn_type != AttentionType.DECODER: raise NotImplementedError( "Encoder self-attention and " diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index 81bdbd641429a..d15d79417cc61 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -3,7 +3,7 @@ """Attention layer with XFormersAttention.""" from dataclasses import dataclass -from typing import Optional +from typing import ClassVar, Optional import torch @@ -41,10 +41,8 @@ logger = init_logger(__name__) class XFormersAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)] @classmethod def get_supported_head_sizes(cls) -> list[int]: @@ -80,22 +78,6 @@ class XFormersAttentionBackend(AttentionBackend): 256, ] - @staticmethod - def get_supported_kernel_block_size() -> list[int | MultipleOf]: - return [MultipleOf(16)] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." - ) - @staticmethod def get_name() -> str: return "XFORMERS" @@ -305,8 +287,6 @@ class XFormersAttentionImpl(AttentionImpl): logits_soft_cap = 0 self.logits_soft_cap = logits_soft_cap - XFormersAttentionBackend.validate_head_size(head_size) - if attn_type != AttentionType.DECODER: raise NotImplementedError( "Encoder self-attention and " diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 75a4140fd6552..55b04949ceb2a 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -150,11 +150,15 @@ class EagleProposer: ) # Determine allowed attention backends once during initialization. + from vllm.attention.backends.registry import AttentionBackendEnum + self.allowed_attn_types: tuple | None = None if current_platform.is_rocm(): rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata] - # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend - if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"): + # ROCM_AITER_FA is an optional backend + if find_spec( + AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False) + ): from vllm.v1.attention.backends.rocm_aiter_fa import ( AiterFlashAttentionMetadata, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6fccf2ea2f47c..790649b69e5c9 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4371,7 +4371,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): """ for backend in backends: is_supported = False - for supported_size in backend.get_supported_kernel_block_size(): + for supported_size in backend.supported_kernel_block_sizes: if isinstance(supported_size, int): if block_size == supported_size: is_supported = True @@ -4402,7 +4402,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): all_int_supported_sizes = set( supported_size for backend in backends - for supported_size in backend.get_supported_kernel_block_size() + for supported_size in backend.supported_kernel_block_sizes if isinstance(supported_size, int) ) From 7dbe6d81d6f17abe93389d97d417e4886467546f Mon Sep 17 00:00:00 2001 From: Chaojun Zhang Date: Tue, 11 Nov 2025 20:46:47 +0800 Subject: [PATCH 10/98] Fix Fused MoE LoRA Triton kernel bug (#28450) Signed-off-by: chaojun-zhang --- vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index 6d6de2529de3d..893972144e99a 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -26,7 +26,7 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device): tensor_ptrs = [] for lora_weight in lora_weights: tensor_ptrs.append(lora_weight.data_ptr()) - ptr_tensor = torch.tensor(tensor_ptrs, device=device) + ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64) _LORA_PTR_DICT[key] = ptr_tensor return _LORA_PTR_DICT.get(key) @@ -85,6 +85,7 @@ def _fused_moe_lora_kernel( GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr, USE_GDC: tl.constexpr, + launch_pdl: tl.constexpr, IS_PRIMARY: tl.constexpr, ): pid = tl.program_id(axis=0) From afffd3cc8a99ce1cf0f6f1687852e5519d725a3b Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 11 Nov 2025 21:14:48 +0800 Subject: [PATCH 11/98] [Model] Pass `mm_features` directly into `get_mrope_input_positions` (#28399) Signed-off-by: DarkLight1337 --- vllm/model_executor/models/ernie45_vl.py | 35 +++++------- vllm/model_executor/models/glm4_1v.py | 32 +++++------ vllm/model_executor/models/glm4v.py | 32 +++++------ vllm/model_executor/models/interfaces.py | 22 ++------ vllm/model_executor/models/keye.py | 29 ++++------ vllm/model_executor/models/keye_vl1_5.py | 29 ++++------ vllm/model_executor/models/paddleocr_vl.py | 29 ++++------ .../models/qwen2_5_omni_thinker.py | 46 +++++++++------- vllm/model_executor/models/qwen2_5_vl.py | 36 ++++++------ vllm/model_executor/models/qwen2_vl.py | 37 +++++-------- .../models/qwen3_omni_moe_thinker.py | 55 +++++++++++-------- vllm/model_executor/models/qwen3_vl.py | 30 ++++------ .../models/transformers/multimodal.py | 39 +++++++++---- vllm/multimodal/inputs.py | 13 +++++ vllm/v1/worker/gpu_model_runner.py | 33 ++--------- 15 files changed, 225 insertions(+), 272 deletions(-) diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 97182a25f82b8..c040b19bba20e 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -34,7 +34,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat -from transformers import BatchFeature, PretrainedConfig +from transformers import BatchFeature from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import ( @@ -58,6 +58,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( MultiModalDataDict, + MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, ) @@ -1433,15 +1434,16 @@ class Ernie4_5_VLMoeForConditionalGeneration( def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor, - video_grid_thw: list[list[int]] | torch.Tensor, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value for Ernie VL.""" + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + {"image_grid_thw", "video_grid_thw"}, + ) + image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] + video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])] + hf_config = self.config image_token_id = hf_config.im_patch_id video_start_token_id = hf_config.video_start_token_id video_end_token_id = hf_config.video_end_token_id @@ -1449,10 +1451,7 @@ class Ernie4_5_VLMoeForConditionalGeneration( temporal_conv_size = hf_config.temporal_conv_size llm_pos_ids_list: list = [] - if not (image_grid_thw is None and video_grid_thw is None): - if isinstance(image_grid_thw, torch.Tensor): - image_grid_thw = image_grid_thw.tolist() - + if image_grid_thw or video_grid_thw: input_token_type: list[str] = [] video_check_flg = False for token in input_tokens: @@ -1484,11 +1483,7 @@ class Ernie4_5_VLMoeForConditionalGeneration( llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 ) if modality_type == "image": - t, h, w = ( - image_grid_thw[mm_data_idx][0], - image_grid_thw[mm_data_idx][1], - image_grid_thw[mm_data_idx][2], - ) + t, h, w = image_grid_thw[mm_data_idx] llm_grid_t, llm_grid_h, llm_grid_w = ( t, h // spatial_conv_size, @@ -1519,11 +1514,7 @@ class Ernie4_5_VLMoeForConditionalGeneration( mm_data_idx += 1 elif modality_type == "video": - t, h, w = ( - video_grid_thw[mm_data_idx][0], - video_grid_thw[mm_data_idx][1], - video_grid_thw[mm_data_idx][2], - ) + t, h, w = video_grid_thw[mm_data_idx] llm_grid_t, llm_grid_h, llm_grid_w = ( t // temporal_conv_size, h // spatial_conv_size, diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 776527fdd973a..60cad2e2907f2 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -37,7 +37,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from transformers import BatchFeature, PretrainedConfig +from transformers import BatchFeature from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig from transformers.models.glm4v.image_processing_glm4v import ( Glm4vImageProcessor, @@ -70,6 +70,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( MultiModalDataDict, + MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, VideoItem, @@ -1619,25 +1620,23 @@ class Glm4vForConditionalGeneration( def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: "PretrainedConfig", - image_grid_thw: list[list[int]] | torch.Tensor | None, - video_grid_thw: list[list[int]] | torch.Tensor | None, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value for GLM4V.""" + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + {"image_grid_thw", "video_grid_thw"}, + ) + image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] + video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])] + hf_config = self.config image_token_id = hf_config.image_token_id video_start_token_id = hf_config.video_start_token_id video_end_token_id = hf_config.video_end_token_id spatial_merge_size = hf_config.vision_config.spatial_merge_size llm_pos_ids_list: list = [] - if not (image_grid_thw is None and video_grid_thw is None): - if isinstance(image_grid_thw, torch.Tensor): - image_grid_thw = image_grid_thw.tolist() - + if image_grid_thw or video_grid_thw: input_token_type: list[str] = [] video_check_flg = False for token in input_tokens: @@ -1669,11 +1668,7 @@ class Glm4vForConditionalGeneration( llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 ) if modality_type == "image": - t, h, w = ( - image_grid_thw[mm_data_idx][0], - image_grid_thw[mm_data_idx][1], - image_grid_thw[mm_data_idx][2], - ) + t, h, w = image_grid_thw[mm_data_idx] llm_grid_t, llm_grid_h, llm_grid_w = ( t, h // spatial_merge_size, @@ -1706,8 +1701,7 @@ class Glm4vForConditionalGeneration( elif modality_type == "video": t, h, w = ( video_frame_num, - image_grid_thw[mm_data_idx][1], - image_grid_thw[mm_data_idx][2], + *image_grid_thw[mm_data_idx][1:], ) llm_grid_t, llm_grid_h, llm_grid_w = ( t, diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index ebf6934dddead..899797a510539 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -15,7 +15,7 @@ from torch import nn from torch.nn import LayerNorm from torchvision import transforms from torchvision.transforms import InterpolationMode -from transformers import BatchFeature, PretrainedConfig, PreTrainedTokenizer, TensorType +from transformers import BatchFeature, PreTrainedTokenizer, TensorType from transformers.image_utils import ImageInput from transformers.tokenization_utils_base import TextInput @@ -36,6 +36,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( MultiModalDataDict, + MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, ) @@ -622,25 +623,23 @@ class GLM4VForCausalLM( def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor, - video_grid_thw: list[list[int]] | torch.Tensor, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value for GLM4V.""" + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + {"image_grid_thw", "video_grid_thw"}, + ) + image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] + video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])] + hf_config = self.config image_token_id = hf_config.image_token_id video_start_token_id = hf_config.video_start_token_id video_end_token_id = hf_config.video_end_token_id spatial_merge_size = hf_config.vision_config.spatial_merge_size llm_pos_ids_list: list = [] - if not (image_grid_thw is None and video_grid_thw is None): - if isinstance(image_grid_thw, torch.Tensor): - image_grid_thw = image_grid_thw.tolist() - + if image_grid_thw or video_grid_thw: input_token_type: list[str] = [] video_check_flg = False for token in input_tokens: @@ -672,11 +671,7 @@ class GLM4VForCausalLM( llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 ) if modality_type == "image": - t, h, w = ( - image_grid_thw[mm_data_idx][0], - image_grid_thw[mm_data_idx][1], - image_grid_thw[mm_data_idx][2], - ) + t, h, w = image_grid_thw[mm_data_idx] llm_grid_t, llm_grid_h, llm_grid_w = ( t, h // spatial_merge_size, @@ -709,8 +704,7 @@ class GLM4VForCausalLM( elif modality_type == "video": t, h, w = ( video_frame_num, - image_grid_thw[mm_data_idx][1], - image_grid_thw[mm_data_idx][2], + *image_grid_thw[mm_data_idx][1:], ) llm_grid_t, llm_grid_h, llm_grid_w = ( t, diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index d6a8f86d998bb..88b45bf07c0d8 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -16,7 +16,6 @@ import numpy as np import torch import torch.nn as nn from torch import Tensor -from transformers import PretrainedConfig from transformers.models.whisper.tokenization_whisper import LANGUAGES from typing_extensions import Self, TypeIs @@ -32,10 +31,12 @@ from .interfaces_base import VllmModel, is_pooling_model if TYPE_CHECKING: from vllm.config import VllmConfig from vllm.model_executor.models.utils import WeightsMapper + from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.sequence import IntermediateTensors else: VllmConfig = object WeightsMapper = object + MultiModalFeatureSpec = object IntermediateTensors = object logger = init_logger(__name__) @@ -991,12 +992,7 @@ class SupportsMRoPE(Protocol): def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor | None, - video_grid_thw: list[list[int]] | torch.Tensor | None, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list["MultiModalFeatureSpec"], ) -> tuple[torch.Tensor, int]: """ Get M-RoPE input positions and delta value for this specific model. @@ -1006,17 +1002,11 @@ class SupportsMRoPE(Protocol): Args: input_tokens: List of input token IDs - hf_config: HuggingFace model configuration - image_grid_thw: Image grid dimensions (t, h, w) - video_grid_thw: Video grid dimensions (t, h, w) - second_per_grid_ts: Seconds per grid timestep for videos - audio_feature_lengths: Audio feature lengths for multimodal models - use_audio_in_video: Whether to use audio in video for interleaving + mm_features: Information about each multi-modal data item Returns: - Tuple of (llm_positions, mrope_position_delta) - - llm_positions: Tensor of shape [3, num_tokens] - with T/H/W positions + Tuple of `(llm_positions, mrope_position_delta)` + - llm_positions: Tensor of shape `[3, num_tokens]` with T/H/W positions - mrope_position_delta: Delta for position calculations """ ... diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 80d7e6c5b0cd0..aa0134badc402 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -40,6 +40,7 @@ from vllm.multimodal.inputs import ( ImageItem, ModalityData, MultiModalDataDict, + MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, VideoItem, @@ -1627,16 +1628,17 @@ class KeyeForConditionalGeneration( def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor, - video_grid_thw: list[list[int]] | torch.Tensor, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + {"image_grid_thw", "video_grid_thw"}, + ) + image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] + video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])] + if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0: video_grid_thw = video_grid_thw[0] - """Get mrope input positions and delta value (Keye series).""" def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]: """ @@ -1662,6 +1664,7 @@ class KeyeForConditionalGeneration( video_grid_thw = split_thw(video_grid_thw) + hf_config = self.config image_token_id = hf_config.image_token_id video_token_id = hf_config.video_token_id spatial_merge_size = hf_config.vision_config.spatial_merge_size @@ -1691,20 +1694,12 @@ class KeyeForConditionalGeneration( ed_video = len(input_tokens) + 1 if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) + t, h, w = image_grid_thw[image_index] image_index += 1 remain_images -= 1 ed = ed_image else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) + t, h, w = video_grid_thw[video_index] video_index += 1 remain_frames -= 1 ed = ed_video diff --git a/vllm/model_executor/models/keye_vl1_5.py b/vllm/model_executor/models/keye_vl1_5.py index 6f95a59d36d29..124e9c2afa217 100644 --- a/vllm/model_executor/models/keye_vl1_5.py +++ b/vllm/model_executor/models/keye_vl1_5.py @@ -21,6 +21,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( ImageItem, ModalityData, + MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, VideoItem, @@ -597,16 +598,17 @@ class KeyeVL1_5ForConditionalGeneration( def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor, - video_grid_thw: list[list[int]] | torch.Tensor, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + {"image_grid_thw", "video_grid_thw"}, + ) + image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] + video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])] + if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0: video_grid_thw = video_grid_thw[0] - """Get mrope input positions and delta value (Keye series).""" def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]: """ @@ -632,6 +634,7 @@ class KeyeVL1_5ForConditionalGeneration( video_grid_thw = split_thw(video_grid_thw) + hf_config = self.config image_token_id = hf_config.image_token_id video_token_id = hf_config.video_token_id spatial_merge_size = hf_config.vision_config.spatial_merge_size @@ -661,20 +664,12 @@ class KeyeVL1_5ForConditionalGeneration( ed_video = len(input_tokens) + 1 if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) + t, h, w = image_grid_thw[image_index] image_index += 1 remain_images -= 1 ed = ed_image else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) + t, h, w = video_grid_thw[video_index] video_index += 1 remain_frames -= 1 ed = ed_video diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py index 86d7d1c11ffe8..62994abe8e317 100644 --- a/vllm/model_executor/models/paddleocr_vl.py +++ b/vllm/model_executor/models/paddleocr_vl.py @@ -61,6 +61,7 @@ from vllm.model_executor.model_loader.weight_utils import ( from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( MultiModalDataDict, + MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargs, ) @@ -1184,15 +1185,17 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor, - video_grid_thw: list[list[int]] | torch.Tensor, - second_per_grid_ts: list[float], - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value.""" + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + {"image_grid_thw", "video_grid_thw", "second_per_grid_ts"}, + ) + image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] + video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])] + second_per_grid_ts = kwargs.get("second_per_grid_ts", []) + hf_config = self.config image_token_id = hf_config.image_token_id video_token_id = hf_config.video_token_id vision_start_token_id = hf_config.vision_start_token_id @@ -1229,20 +1232,12 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support else: ed_video = len(input_tokens) + 1 if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) + t, h, w = image_grid_thw[image_index] image_index += 1 remain_images -= 1 ed = ed_image else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) + t, h, w = video_grid_thw[video_index] video_second_per_grid_t = 1.0 if second_per_grid_ts: video_second_per_grid_t = second_per_grid_ts[video_index] diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index fac281d2caf49..8f74cab0534da 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -68,6 +68,7 @@ from vllm.multimodal.inputs import ( ImageItem, ModalityData, MultiModalDataDict, + MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, NestedTensors, @@ -923,21 +924,9 @@ class Qwen2_5OmniThinkerForConditionalGeneration( def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor, - video_grid_thw: list[list[int]] | torch.Tensor, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value (Qwen2.5-Omni version). - - Differences from MRotaryEmbedding: - 1. Add audio support (and related `audio_feature_lengths`). - 2. Add `use_audio_in_video` option to read audio from video inputs. - In this case, audio and vision position ids will be split into - chunks and interleaved. - + """ Example: (V_i are vision position ids, A_i are audio position ids) @@ -945,11 +934,33 @@ class Qwen2_5OmniThinkerForConditionalGeneration( |V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|... |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |... """ + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + { + "image_grid_thw", + "video_grid_thw", + "second_per_grid_ts", + "audio_feature_lengths", + "use_audio_in_video", + }, + ) + image_grid_thw = kwargs.get("image_grid_thw", []) + video_grid_thw = kwargs.get("video_grid_thw", []) + second_per_grid_ts = kwargs.get("second_per_grid_ts", []) + audio_feature_lengths = kwargs.get("audio_feature_lengths", []) + use_audio_in_video = any(kwargs.get("use_audio_in_video", [])) + + image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)( + image_grid_thw + ) + video_grid_thw = (torch.stack if video_grid_thw else torch.tensor)( + video_grid_thw + ) # TODO(fyabc): refactor and share more code with # _vl_get_input_positions_tensor. - thinker_config = hf_config.thinker_config + thinker_config = self.config audio_token_id = thinker_config.audio_token_index image_token_id = thinker_config.image_token_index video_token_id = thinker_config.video_token_index @@ -963,11 +974,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration( thinker_config.vision_config, "tokens_per_second", 25 ) - if isinstance(image_grid_thw, list): - image_grid_thw = torch.tensor(image_grid_thw) - if isinstance(video_grid_thw, list): - video_grid_thw = torch.tensor(video_grid_thw) - src_item = input_tokens audio_seqlens = audio_feature_lengths if not second_per_grid_ts: diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 3292cf8220ffe..4662176a1cc51 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -35,7 +35,7 @@ import einops import torch import torch.nn as nn import torch.nn.functional as F -from transformers import BatchFeature, PretrainedConfig +from transformers import BatchFeature from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( Qwen2_5_VLConfig, @@ -75,7 +75,11 @@ from vllm.multimodal.evs import ( compute_retention_mask, recompute_mrope_positions, ) -from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.inputs import ( + MultiModalFeatureSpec, + MultiModalFieldConfig, + MultiModalKwargs, +) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.sequence import IntermediateTensors @@ -1120,15 +1124,17 @@ class Qwen2_5_VLForConditionalGeneration( def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor, - video_grid_thw: list[list[int]] | torch.Tensor, - second_per_grid_ts: list[float], - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value.""" + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + {"image_grid_thw", "video_grid_thw", "second_per_grid_ts"}, + ) + image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] + video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])] + second_per_grid_ts = kwargs.get("second_per_grid_ts", []) + hf_config = self.config image_token_id = hf_config.image_token_id video_token_id = hf_config.video_token_id vision_start_token_id = hf_config.vision_start_token_id @@ -1165,20 +1171,12 @@ class Qwen2_5_VLForConditionalGeneration( else: ed_video = len(input_tokens) + 1 if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) + t, h, w = image_grid_thw[image_index] image_index += 1 remain_images -= 1 ed = ed_image else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) + t, h, w = video_grid_thw[video_index] video_second_per_grid_t = 1.0 if second_per_grid_ts: video_second_per_grid_t = second_per_grid_ts[video_index] diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 61057fa145f47..bbebe7c0f9289 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -34,7 +34,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat -from transformers import BatchFeature, PretrainedConfig +from transformers import BatchFeature from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor from transformers.models.qwen2_vl.configuration_qwen2_vl import ( Qwen2VLConfig, @@ -70,6 +70,7 @@ from vllm.multimodal.inputs import ( ImageItem, ModalityData, MultiModalDataDict, + MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, VideoItem, @@ -1240,21 +1241,17 @@ class Qwen2VLForConditionalGeneration( def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor | None, - video_grid_thw: list[list[int]] | torch.Tensor | None, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - """Get M-RoPE input positions for Qwen2-VL model.""" - if image_grid_thw is None: - image_grid_thw = [] - if video_grid_thw is None: - video_grid_thw = [] - if second_per_grid_ts is None: - second_per_grid_ts = [] + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + {"image_grid_thw", "video_grid_thw", "second_per_grid_ts"}, + ) + image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] + video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])] + second_per_grid_ts = kwargs.get("second_per_grid_ts", []) + hf_config = self.config image_token_id = hf_config.image_token_id video_token_id = hf_config.video_token_id vision_start_token_id = hf_config.vision_start_token_id @@ -1291,20 +1288,12 @@ class Qwen2VLForConditionalGeneration( else: ed_video = len(input_tokens) + 1 if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) + t, h, w = image_grid_thw[image_index] image_index += 1 remain_images -= 1 ed = ed_image else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) + t, h, w = video_grid_thw[video_index] video_second_per_grid_t = 1.0 if second_per_grid_ts: video_second_per_grid_t = second_per_grid_ts[video_index] diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index 468b25220154b..e6cb4442e2bef 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -65,7 +65,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.qwen2_audio import Qwen2AudioProcessingInfo from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalKwargsItems +from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataItems from vllm.multimodal.processing import ( BaseMultiModalProcessor, @@ -1414,39 +1414,48 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor | None, - video_grid_thw: list[list[int]] | torch.Tensor | None, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - config = hf_config.thinker_config - if isinstance(image_grid_thw, list): - image_grid_thw = torch.tensor(image_grid_thw) - if isinstance(video_grid_thw, list): - video_grid_thw = torch.tensor(video_grid_thw) + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + { + "image_grid_thw", + "video_grid_thw", + "second_per_grid_ts", + "audio_feature_lengths", + "use_audio_in_video", + }, + ) + image_grid_thw = kwargs.get("image_grid_thw", []) + video_grid_thw = kwargs.get("video_grid_thw", []) + second_per_grid_ts = kwargs.get("second_per_grid_ts", []) + audio_feature_lengths = kwargs.get("audio_feature_lengths", []) + use_audio_in_video = any(kwargs.get("use_audio_in_video", [])) + + image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)( + image_grid_thw + ) + video_grid_thw = (torch.stack if video_grid_thw else torch.tensor)( + video_grid_thw + ) + input_ids = torch.tensor(input_tokens) if input_ids is None or input_ids.ndim != 1: raise ValueError("_omni3_get_input_positions_tensor expects 1D input_ids") seq_len = input_ids.shape[0] - if audio_feature_lengths is not None and not isinstance( - audio_feature_lengths, torch.Tensor - ): - audio_feature_lengths = torch.as_tensor( + + if isinstance(audio_feature_lengths, list): + audio_feature_lengths = torch.tensor( audio_feature_lengths, dtype=torch.long ) - if second_per_grid_ts is None: - if video_grid_thw is not None and video_grid_thw.numel() > 0: - second_per_grids = torch.ones( - video_grid_thw.shape[0], dtype=torch.float32 - ) - else: - second_per_grids = torch.tensor([], dtype=torch.float32) + + if not len(second_per_grid_ts) and len(video_grid_thw): + second_per_grids = torch.ones(len(video_grid_thw), dtype=torch.float32) else: second_per_grids = torch.tensor(second_per_grid_ts, dtype=torch.float32) + config = self.config spatial_merge_size = config.vision_config.spatial_merge_size image_token_id = config.image_token_id video_token_id = config.video_token_id diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 1be35cde7dbdc..97d4667d82e99 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -34,7 +34,7 @@ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from transformers import BatchFeature, PretrainedConfig +from transformers import BatchFeature from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( smart_resize as image_smart_resize, @@ -70,6 +70,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( MultiModalDataDict, + MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItem, MultiModalKwargsItems, @@ -1416,17 +1417,18 @@ class Qwen3VLForConditionalGeneration( def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor, - video_grid_thw: list[list[int]] | torch.Tensor, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value.""" + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + {"image_grid_thw", "video_grid_thw"}, + ) + image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] + video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])] video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw for _ in range(t)] + hf_config = self.config image_token_id = hf_config.image_token_id video_token_id = hf_config.video_token_id vision_start_token_id = hf_config.vision_start_token_id @@ -1455,20 +1457,12 @@ class Qwen3VLForConditionalGeneration( else: ed_video = len(input_tokens) + 1 if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) + t, h, w = image_grid_thw[image_index] image_index += 1 remain_images -= 1 ed = ed_image else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) + t, h, w = video_grid_thw[video_index] video_index += 1 remain_videos -= 1 ed = ed_video diff --git a/vllm/model_executor/models/transformers/multimodal.py b/vllm/model_executor/models/transformers/multimodal.py index 476074542e6ae..2efcef68d1c72 100644 --- a/vllm/model_executor/models/transformers/multimodal.py +++ b/vllm/model_executor/models/transformers/multimodal.py @@ -27,6 +27,7 @@ from vllm.model_executor.models.utils import WeightsMapper from vllm.multimodal import MultiModalKwargsItems from vllm.multimodal.inputs import ( MultiModalDataDict, + MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalInputs, MultiModalUUIDDict, @@ -38,7 +39,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors if TYPE_CHECKING: - from transformers import BatchFeature, PretrainedConfig + from transformers import BatchFeature from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions @@ -367,20 +368,34 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE): def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: "PretrainedConfig", - image_grid_thw: list[list[int]] | torch.Tensor | None, - video_grid_thw: list[list[int]] | torch.Tensor | None, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - if any((second_per_grid_ts, audio_feature_lengths, use_audio_in_video)): + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + { + "image_grid_thw", + "video_grid_thw", + "second_per_grid_ts", + "audio_feature_lengths", + "use_audio_in_video", + }, + ) + if any( + v + for k, v in kwargs.items() + if k not in {"image_grid_thw", "video_grid_thw"} + ): raise NotImplementedError("Transformers backend only supports images.") - if isinstance(image_grid_thw, list): - image_grid_thw = torch.tensor(image_grid_thw) - if isinstance(video_grid_thw, list): - video_grid_thw = torch.tensor(video_grid_thw) + image_grid_thw = kwargs.get("image_grid_thw", []) + video_grid_thw = kwargs.get("video_grid_thw", []) + + image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)( + image_grid_thw + ) + video_grid_thw = (torch.stack if video_grid_thw else torch.tensor)( + video_grid_thw + ) mrope_positions, mrope_position_delta = self.model.get_rope_index( input_ids=torch.tensor(input_tokens).unsqueeze(0), diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index a05f54191f044..7518a023c5f50 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -249,6 +249,19 @@ class MultiModalFeatureSpec: mm_position: PlaceholderRange """e.g., PlaceholderRange(offset=2, length=336)""" + @staticmethod + def gather_kwargs(features: list["MultiModalFeatureSpec"], keys: set[str]): + kwargs = defaultdict[str, list[NestedTensors]](list) + + for f in features: + item = f.data + if item is not None: + for k in keys: + if k in item: + kwargs[k].append(item[k].data) + + return dict(kwargs) + @dataclass class MultiModalFieldElem: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 790649b69e5c9..fbd3e5f313167 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -892,38 +892,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.input_batch.num_accepted_tokens_cpu[i] = num_tokens def _init_mrope_positions(self, req_state: CachedRequestState): - image_grid_thw = [] - video_grid_thw = [] - second_per_grid_ts = [] - audio_feature_lengths = [] - use_audio_in_video = False - for mm_feature in req_state.mm_features: - mm_item = mm_feature.data - if mm_item is None: - continue - mm_input = mm_item.get_data() - if (t := mm_input.get("image_grid_thw")) is not None: - image_grid_thw.append(t.tolist()) - if (t := mm_input.get("video_grid_thw")) is not None: - video_grid_thw.append(t.tolist()) - if (t := mm_input.get("second_per_grid_ts")) is not None: - second_per_grid_ts.append(t) - if (t := mm_input.get("audio_feature_lengths")) is not None: - audio_feature_lengths.append(t) - if mm_input.get("use_audio_in_video") is True: - use_audio_in_video = True - - assert supports_mrope(self.get_model()), "M-RoPE support is not implemented." + model = self.get_model() + assert supports_mrope(model), "M-RoPE support is not implemented." req_state.mrope_positions, req_state.mrope_position_delta = ( - self.model.get_mrope_input_positions( + model.get_mrope_input_positions( req_state.prompt_token_ids, - hf_config=self.model_config.hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, + req_state.mm_features, ) ) From 3380543b2075abd6f3e6e283f4eacb307354e33a Mon Sep 17 00:00:00 2001 From: Ido Segev Date: Tue, 11 Nov 2025 15:41:18 +0200 Subject: [PATCH 12/98] Add request timeout override for multi-turn benchmarks (#28386) Signed-off-by: Ido Segev --- .../benchmark_serving_multi_turn.py | 40 ++++++++++++++++--- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/benchmarks/multi_turn/benchmark_serving_multi_turn.py b/benchmarks/multi_turn/benchmark_serving_multi_turn.py index 5d2ac66e5ab94..2c1a051cc9c97 100644 --- a/benchmarks/multi_turn/benchmark_serving_multi_turn.py +++ b/benchmarks/multi_turn/benchmark_serving_multi_turn.py @@ -63,6 +63,7 @@ class RequestArgs(NamedTuple): stream: bool limit_min_tokens: int # Use negative value for no limit limit_max_tokens: int # Use negative value for no limit + timeout_sec: int class BenchmarkArgs(NamedTuple): @@ -214,6 +215,7 @@ async def send_request( stream: bool = True, min_tokens: int | None = None, max_tokens: int | None = None, + timeout_sec: int = 120, ) -> ServerResponse: payload = { "model": model, @@ -235,10 +237,16 @@ async def send_request( headers = {"Content-Type": "application/json"} # Calculate the timeout for the request - timeout_sec = 120 if max_tokens is not None: # Assume TPOT of 200ms and use max_tokens to determine timeout - timeout_sec = max(timeout_sec, int(max_tokens * 0.2)) + token_based_timeout = int(max_tokens * 0.2) + if token_based_timeout > timeout_sec: + timeout_sec = token_based_timeout + logger.info( + "Using timeout of %ds based on max_tokens %d", + timeout_sec, + max_tokens, + ) timeout = aiohttp.ClientTimeout(total=timeout_sec) valid_response = True @@ -409,6 +417,7 @@ async def send_turn( req_args.stream, min_tokens, max_tokens, + req_args.timeout_sec, ) if response.valid is False: @@ -676,8 +685,18 @@ async def client_main( except asyncio.exceptions.TimeoutError: num_failures += 1 - logger.exception( - f"{Color.RED}Client {client_id} - Timeout during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501 + logger.error( + "%sClient %d - Timeout during conversation ID %s (turn: %d). " + "Base timeout is %ss (set with --request-timeout-sec), but the " + "effective timeout may be longer based on max_tokens. If this " + "is unexpected, consider increasing the timeout or checking " + "model performance.%s", + Color.RED, + client_id, + conv_id, + current_turn, + req_args.timeout_sec, + Color.RESET, ) break # Exit gracefully instead of raising an error @@ -815,6 +834,9 @@ def get_client_config( "Invalid min/max tokens limits (min should not be larger than max)" ) + if args.request_timeout_sec <= 0: + raise ValueError("Request timeout must be a positive number") + # Arguments for API requests chat_url = f"{args.url}/v1/chat/completions" model_name = args.served_model_name if args.served_model_name else args.model @@ -825,6 +847,7 @@ def get_client_config( stream=not args.no_stream, limit_min_tokens=args.limit_min_tokens, limit_max_tokens=args.limit_max_tokens, + timeout_sec=args.request_timeout_sec, ) return client_args, req_args @@ -968,7 +991,7 @@ async def main_mp( f"(is alive: {client.is_alive()}){Color.RESET}" ) - client.join(timeout=120) + client.join(timeout=req_args.timeout_sec + 1) if client.is_alive(): logger.warning( @@ -1351,6 +1374,13 @@ async def main() -> None: action="store_true", help="Verify the LLM output (compare to the answers in the input JSON file)", ) + parser.add_argument( + "--request-timeout-sec", + type=int, + default=120, + help="Timeout in seconds for each API request (default: 120). " + "Automatically increased if max tokens imply longer decoding.", + ) parser.add_argument( "--no-stream", From fa1970201d2efae6db48ca808ba50b63390457db Mon Sep 17 00:00:00 2001 From: Maryam Tahhan Date: Tue, 11 Nov 2025 14:01:11 +0000 Subject: [PATCH 13/98] [Docs] Fix grammar in CPU installation guide (#28461) Signed-off-by: Maryam Tahhan --- docs/getting_started/installation/cpu.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/getting_started/installation/cpu.md b/docs/getting_started/installation/cpu.md index 2369eaed1802e..dbfefa9a1fe5a 100644 --- a/docs/getting_started/installation/cpu.md +++ b/docs/getting_started/installation/cpu.md @@ -93,7 +93,7 @@ Currently, there are no pre-built CPU wheels. ## Related runtime environment variables -- `VLLM_CPU_KVCACHE_SPACE`: specify the KV Cache size (e.g, `VLLM_CPU_KVCACHE_SPACE=40` means 40 GiB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. Default value is `0`. +- `VLLM_CPU_KVCACHE_SPACE`: specify the KV Cache size (e.g, `VLLM_CPU_KVCACHE_SPACE=40` means 40 GiB space for KV cache), larger setting will allow vLLM to run more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. Default value is `0`. - `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads, can be set as CPU id lists, `auto` (by default), or `nobind` (to disable binding to individual CPU cores and to inherit user-defined OpenMP variables). For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node respectively. If set to `nobind`, the number of OpenMP threads is determined by the standard `OMP_NUM_THREADS` environment variable. - `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads for each rank. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. Default value is `None`. If the value is not set and use `auto` thread binding, no CPU will be reserved for `world_size == 1`, 1 CPU per rank will be reserved for `world_size > 1`. - `CPU_VISIBLE_MEMORY_NODES`: specify visible NUMA memory nodes for vLLM CPU workers, similar to ```CUDA_VISIBLE_DEVICES```. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. The variable provides more control for the auto thread-binding feature, such as masking nodes and changing nodes binding sequence. @@ -128,7 +128,7 @@ Note, it is recommended to manually reserve 1 CPU for vLLM front-end process whe ### How to decide `VLLM_CPU_OMP_THREADS_BIND`? -- Default `auto` thread-binding is recommended for most cases. Ideally, each OpenMP thread will be bound to a dedicated physical core respectively, threads of each rank will be bound to a same NUMA node respectively, and 1 CPU per rank will be reserved for other vLLM components when `world_size > 1`. If have any performance problems or unexpected binding behaviours, please try to bind threads as following. +- Default `auto` thread-binding is recommended for most cases. Ideally, each OpenMP thread will be bound to a dedicated physical core respectively, threads of each rank will be bound to the same NUMA node respectively, and 1 CPU per rank will be reserved for other vLLM components when `world_size > 1`. If you have any performance problems or unexpected binding behaviours, please try to bind threads as following. - On a hyper-threading enabled platform with 16 logical CPU cores / 8 physical CPU cores: @@ -156,12 +156,12 @@ Note, it is recommended to manually reserve 1 CPU for vLLM front-end process whe 14 0 0 6 6:6:6:0 yes 2401.0000 800.0000 800.000 15 0 0 7 7:7:7:0 yes 2401.0000 800.0000 800.000 - # On this platform, it is recommend to only bind openMP threads on logical CPU cores 0-7 or 8-15 + # On this platform, it is recommended to only bind openMP threads on logical CPU cores 0-7 or 8-15 $ export VLLM_CPU_OMP_THREADS_BIND=0-7 $ python examples/offline_inference/basic/basic.py ``` -- When deploy vLLM CPU backend on a multi-socket machine with NUMA and enable tensor parallel or pipeline parallel, each NUMA node is treated as a TP/PP rank. So be aware to set CPU cores of a single rank on a same NUMA node to avoid cross NUMA node memory access. +- When deploying vLLM CPU backend on a multi-socket machine with NUMA and enable tensor parallel or pipeline parallel, each NUMA node is treated as a TP/PP rank. So be aware to set CPU cores of a single rank on the same NUMA node to avoid cross NUMA node memory access. ### How to decide `VLLM_CPU_KVCACHE_SPACE`? @@ -171,7 +171,7 @@ This value is 4GB by default. Larger space can support more concurrent requests, First of all, please make sure the thread-binding and KV cache space are properly set and take effect. You can check the thread-binding by running a vLLM benchmark and observing CPU cores usage via `htop`. -Inference batch size is an important parameter for the performance. Larger batch usually provides higher throughput, smaller batch provides lower latency. Tuning max batch size starts from default value to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM: +Inference batch size is an important parameter for the performance. A larger batch usually provides higher throughput, a smaller batch provides lower latency. Tuning the max batch size starting from the default value to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM: - `--max-num-batched-tokens`, defines the limit of token numbers in a single batch, has more impacts on the first token performance. The default value is set as: - Offline Inference: `4096 * world_size` @@ -192,8 +192,8 @@ vLLM CPU supports data parallel (DP), tensor parallel (TP) and pipeline parallel ### (x86 only) What is the purpose of `VLLM_CPU_MOE_PREPACK` and `VLLM_CPU_SGL_KERNEL`? - Both of them require `amx` CPU flag. - - `VLLM_CPU_MOE_PREPACK` can provides better performance for MoE models - - `VLLM_CPU_SGL_KERNEL` can provides better performance for MoE models and small-batch scenarios. + - `VLLM_CPU_MOE_PREPACK` can provide better performance for MoE models + - `VLLM_CPU_SGL_KERNEL` can provide better performance for MoE models and small-batch scenarios. ### Why do I see `get_mempolicy: Operation not permitted` when running in Docker? From a1448b4b69b15c33b4fbc9a883c4f3b9559ee7db Mon Sep 17 00:00:00 2001 From: bnellnm <49004751+bnellnm@users.noreply.github.com> Date: Tue, 11 Nov 2025 09:29:02 -0500 Subject: [PATCH 14/98] [Kernels] Split up fused_moe/layer.py, isolate more modular kernel code (#28064) --- .../moe/modular_kernel_tools/mk_objects.py | 9 +- vllm/lora/layers/fused_moe.py | 4 +- .../layers/fused_moe/__init__.py | 4 +- .../layers/fused_moe/all2all_utils.py | 160 +++ .../layers/fused_moe/fused_moe_method_base.py | 112 +++ .../fused_moe/fused_moe_modular_method.py | 164 +++ vllm/model_executor/layers/fused_moe/layer.py | 950 +----------------- .../layers/fused_moe/shared_fused_moe.py | 2 +- .../fused_moe/unquantized_fused_moe_method.py | 578 +++++++++++ .../layers/quantization/mxfp4.py | 29 +- 10 files changed, 1064 insertions(+), 948 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/all2all_utils.py create mode 100644 vllm/model_executor/layers/fused_moe/fused_moe_method_base.py create mode 100644 vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py create mode 100644 vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index 21eeffb1c7264..d79fdfbe07af3 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -6,6 +6,10 @@ import torch # Fused experts and PrepareFinalize imports import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe import TritonExperts +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( BatchedDeepGemmExperts, ) @@ -21,7 +25,6 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedTritonExperts, NaiveBatchedExperts, ) -from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, TritonExperts from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP, ) @@ -399,9 +402,7 @@ def make_prepare_finalize( quant_config: FusedMoEQuantConfig, ) -> mk.FusedMoEPrepareAndFinalize: if backend != "naive" and backend is not None: - prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize( - moe, quant_config - ) + prepare_finalize = maybe_make_prepare_finalize(moe, quant_config) assert prepare_finalize is not None return prepare_finalize elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize: diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index dadb9e25ba2f1..8fb3efa220f6d 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -25,7 +25,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( modular_triton_fused_moe, try_get_optimal_moe_config, ) -from vllm.model_executor.layers.fused_moe.layer import FusedMoEModularMethod +from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import ( + FusedMoEModularMethod, +) class FusedMoEWithLoRA(BaseLayerWithLoRA): diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index cb31045971bd8..53d98d0650b43 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -5,9 +5,11 @@ from contextlib import contextmanager from typing import Any from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( + FusedMoEMethodBase, +) from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, - FusedMoEMethodBase, FusedMoeWeightScaleSupported, ) from vllm.model_executor.layers.fused_moe.modular_kernel import ( diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py new file mode 100644 index 0000000000000..2dd625054339c --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -0,0 +1,160 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import torch + +from vllm.distributed import ( + get_ep_group, +) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, +) +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEPrepareAndFinalize, +) +from vllm.platforms import current_platform +from vllm.utils.import_utils import has_deep_ep, has_pplx + +if current_platform.is_cuda_alike(): + if has_pplx(): + from .pplx_prepare_finalize import ( + PplxPrepareAndFinalize, + pplx_hidden_dim_scale_bytes, + ) + if has_deep_ep(): + from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize + from .deepep_ll_prepare_finalize import ( + DEEPEP_QUANT_BLOCK_SHAPE, + DeepEPLLPrepareAndFinalize, + ) + + +def maybe_roundup_layer_hidden_size( + hidden_size: int, + act_dtype: torch.dtype, + moe_parallel_config: FusedMoEParallelConfig, +) -> int: + """ + Given layer hidden size and MoE configurations, round up hidden_size + if necessary. + + Args: + hidden_size: Layer hidden-size + act_dtype: Data type of the layer activations. + moe_parallel_config: Fused MoE parallelization strategy configuration. + + Return: + Rounded up hidden_size if rounding up is required based on the configs + and all2all backend. + Original hidden size otherwise. + """ + if moe_parallel_config.use_deepep_ht_kernels: + hidden_size = DeepEPHTPrepareAndFinalize.maybe_roundup_layer_hidden_size( + hidden_size, act_dtype + ) + + if moe_parallel_config.use_deepep_ll_kernels: + hidden_size = DeepEPLLPrepareAndFinalize.maybe_roundup_layer_hidden_size( + hidden_size + ) + + return hidden_size + + +def maybe_make_prepare_finalize( + moe: FusedMoEConfig, + quant_config: FusedMoEQuantConfig | None, +) -> FusedMoEPrepareAndFinalize | None: + if not moe.moe_parallel_config.use_all2all_kernels: + return None + + all2all_manager = get_ep_group().device_communicator.all2all_manager + assert all2all_manager is not None + + prepare_finalize: FusedMoEPrepareAndFinalize | None = None + + # TODO: could allow this now + assert not moe.use_flashinfer_cutlass_kernels, "Must be created in modelopt.py" + + if moe.use_pplx_kernels: + assert quant_config is not None + + hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes( + moe.max_num_tokens, + moe.hidden_dim, + moe.in_dtype, + quant_config.quant_dtype, + per_act_token_quant=quant_config.per_act_token_quant, + block_shape=quant_config.block_shape, + ) + + all_to_all_args = dict( + max_num_tokens=moe.max_num_tokens, + num_experts=moe.num_experts, + experts_per_token=moe.experts_per_token, # topk + rank=all2all_manager.rank, + world_size=all2all_manager.world_size, + # dp_size actually means tp_size, bug in pplx kernels + dp_size=all2all_manager.tp_group.world_size, + hidden_dim=moe.hidden_dim, + hidden_dim_bytes=hidden_dim_bytes, + hidden_dim_scale_bytes=hidden_scale_bytes, + ) + + num_dispatchers = ( + all2all_manager.world_size // all2all_manager.tp_group.world_size + ) + + # Intranode pplx a2a takes a group name while internode does not. + if not all2all_manager.internode: + all_to_all_args["group_name"] = all2all_manager.cpu_group.group_name + + handle = all2all_manager.get_handle(all_to_all_args) + + prepare_finalize = PplxPrepareAndFinalize( + handle, + max_num_tokens=moe.max_num_tokens, + num_local_experts=moe.num_local_experts, + num_dispatchers=num_dispatchers, + ) + elif moe.use_deepep_ht_kernels: + assert moe.dp_size == all2all_manager.dp_world_size + + all_to_all_args = dict() + handle = all2all_manager.get_handle(all_to_all_args) + prepare_finalize = DeepEPHTPrepareAndFinalize( + handle, + num_dispatchers=all2all_manager.world_size, + dp_size=all2all_manager.dp_world_size, + rank_expert_offset=all2all_manager.rank * moe.num_local_experts, + ) + + elif moe.use_deepep_ll_kernels: + assert quant_config is not None + all_to_all_args = dict( + max_num_tokens_per_dp_rank=moe.max_num_tokens, + token_hidden_size=moe.hidden_dim, + num_ep_ranks=all2all_manager.world_size, + num_global_experts=moe.num_experts, + num_local_experts=moe.num_experts // all2all_manager.world_size, + ) + handle = all2all_manager.get_handle(all_to_all_args) + + # Note: We may want to use FP8 dispatch just to reduce + # data movement. + use_fp8_dispatch = ( + quant_config.quant_dtype == current_platform.fp8_dtype() + and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE + ) + + prepare_finalize = DeepEPLLPrepareAndFinalize( + handle, + max_tokens_per_rank=moe.max_num_tokens, + num_dispatchers=all2all_manager.world_size, + use_fp8_dispatch=use_fp8_dispatch, + ) + + return prepare_finalize diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py new file mode 100644 index 0000000000000..87f8c8d75a9b5 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py @@ -0,0 +1,112 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import abstractmethod +from collections.abc import Callable + +import torch + +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, +) +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEPermuteExpertsUnpermute, + FusedMoEPrepareAndFinalize, +) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizeMethodBase, +) + +logger = init_logger(__name__) + + +class FusedMoEMethodBase(QuantizeMethodBase): + def __init__(self, moe: FusedMoEConfig): + super().__init__() + self.moe: FusedMoEConfig = moe + self.moe_quant_config: FusedMoEQuantConfig | None = None + + @abstractmethod + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + raise NotImplementedError + + def uses_weight_scale_2_pattern(self) -> bool: + """ + Returns True if this quantization method uses 'weight_scale_2' pattern + for per-tensor weight scales (e.g., FP4 variants), False otherwise. + + This method should be overridden by subclasses that use the + 'weight_scale_2' pattern instead of the standard 'weight_scale' pattern. + """ + return False + + def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None: + from .all2all_utils import maybe_make_prepare_finalize + + return maybe_make_prepare_finalize(self.moe, self.moe_quant_config) + + def select_gemm_impl( + self, + prepare_finalize: FusedMoEPrepareAndFinalize, + layer: torch.nn.Module, + ) -> FusedMoEPermuteExpertsUnpermute: + # based on the all2all implementation, select the appropriate + # gemm implementation + raise NotImplementedError( + f"{self.__class__.__name__} must select appropriate gemm " + "implementation based on the prepare_finalize" + ) + + @abstractmethod + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> FusedMoEQuantConfig | None: + raise NotImplementedError + + @property + def topk_indices_dtype(self) -> torch.dtype | None: + return None + + @property + def supports_eplb(self) -> bool: + return False + + @property + def allow_inplace(self) -> bool: + return False + + @abstractmethod + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: int | None = None, + num_expert_group: int | None = None, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py new file mode 100644 index 0000000000000..43974ba917e42 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py @@ -0,0 +1,164 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Callable + +import torch + +from vllm.logger import init_logger +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, +) +from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( + FusedMoEMethodBase, +) +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEModularKernel, + FusedMoEPrepareAndFinalize, +) + +logger = init_logger(__name__) + + +@CustomOp.register("modular_fused_moe") +class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): + def __init__( + self, old_quant_method: FusedMoEMethodBase, experts: FusedMoEModularKernel + ): + super().__init__(old_quant_method.moe) + self.moe_quant_config = old_quant_method.moe_quant_config + self.fused_experts = experts + self.disable_expert_map = getattr( + old_quant_method, + "disable_expert_map", + not self.fused_experts.supports_expert_map(), + ) + self.old_quant_method = old_quant_method + logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__) + + @staticmethod + def make( + moe_layer: torch.nn.Module, + old_quant_method: FusedMoEMethodBase, + prepare_finalize: FusedMoEPrepareAndFinalize, + shared_experts: torch.nn.Module | None, + ) -> "FusedMoEModularMethod": + return FusedMoEModularMethod( + old_quant_method, + FusedMoEModularKernel( + prepare_finalize, + old_quant_method.select_gemm_impl(prepare_finalize, moe_layer), + shared_experts, + ), + ) + + @property + def topk_indices_dtype(self) -> torch.dtype | None: + return self.fused_experts.prepare_finalize.topk_indices_dtype() + + @property + def supports_eplb(self) -> bool: + return self.old_quant_method.supports_eplb + + @property + def allow_inplace(self) -> bool: + return self.old_quant_method.allow_inplace + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + raise NotImplementedError + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> FusedMoEQuantConfig | None: + return self.moe_quant_config + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: int | None = None, + num_expert_group: int | None = None, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + # Is getattr needed? + zero_expert_num = getattr(layer, "zero_expert_num", 0) + zero_expert_type = getattr(layer, "zero_expert_type", None) + + if enable_eplb: + if self.supports_eplb: + assert expert_load_view is not None + assert logical_to_physical_map is not None + assert logical_replica_count is not None + else: + raise NotImplementedError( + "EPLB is not supported for " + f"{self.old_quant_method.__class__.__name__}." + ) + + topk_weights, topk_ids, zero_expert_result = layer.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, + enable_eplb=enable_eplb, + expert_map=expert_map, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + global_num_experts=global_num_experts, + zero_expert_num=zero_expert_num, + zero_expert_type=zero_expert_type, + ) + + result = self.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=self.allow_inplace, + activation=activation, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=None if self.disable_expert_map else expert_map, + ) + + if zero_expert_num != 0 and zero_expert_type is not None: + assert not isinstance(result, tuple), ( + "Shared + zero experts are mutually exclusive not yet supported" + ) + return result, zero_expert_result + else: + return result diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 39547cc83c7b6..e198322ba7a89 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from abc import abstractmethod from collections.abc import Callable, Iterable from contextlib import nullcontext from enum import Enum @@ -27,17 +26,13 @@ from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.fused_moe.config import ( - FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, RoutingMethodType, - biased_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEActivationFormat, - FusedMoEModularKernel, FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, ) @@ -47,35 +42,17 @@ from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, - QuantizeMethodBase, ) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( is_flashinfer_supporting_global_sf, ) -from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.platforms.interface import CpuArchEnum -from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe -from vllm.utils.import_utils import has_deep_ep, has_pplx from vllm.utils.math_utils import cdiv, round_up from vllm.utils.torch_utils import current_stream, direct_register_custom_op from vllm.v1.worker.ubatching import dbo_current_ubatch_id if current_platform.is_cuda_alike(): - from .fused_batched_moe import BatchedTritonExperts - from .fused_moe import TritonExperts, eplb_map_to_physical_and_record, fused_experts - - if has_pplx(): - from .pplx_prepare_finalize import ( - PplxPrepareAndFinalize, - pplx_hidden_dim_scale_bytes, - ) - if has_deep_ep(): - from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize - from .deepep_ll_prepare_finalize import ( - DEEPEP_QUANT_BLOCK_SHAPE, - DeepEPLLPrepareAndFinalize, - ) + from .fused_moe import eplb_map_to_physical_and_record, fused_experts else: fused_experts = None # type: ignore FusedMoEPermuteExpertsUnpermute = object # type: ignore @@ -102,6 +79,16 @@ if current_platform.is_tpu(): else: fused_moe_pallas = None # type: ignore +from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( + FusedMoEMethodBase, +) +from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import ( + FusedMoEModularMethod, +) +from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import ( + UnquantizedFusedMoEMethod, +) + logger = init_logger(__name__) @@ -112,885 +99,6 @@ class FusedMoeWeightScaleSupported(Enum): BLOCK = "block" -class FusedMoEMethodBase(QuantizeMethodBase): - def __init__(self, moe: FusedMoEConfig): - super().__init__() - self.moe: FusedMoEConfig = moe - self.moe_quant_config: FusedMoEQuantConfig | None = None - - @abstractmethod - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - raise NotImplementedError - - def uses_weight_scale_2_pattern(self) -> bool: - """ - Returns True if this quantization method uses 'weight_scale_2' pattern - for per-tensor weight scales (e.g., FP4 variants), False otherwise. - - This method should be overridden by subclasses that use the - 'weight_scale_2' pattern instead of the standard 'weight_scale' pattern. - """ - return False - - @staticmethod - def _maybe_make_prepare_finalize( - moe: FusedMoEConfig, - quant_config: FusedMoEQuantConfig | None, - ) -> FusedMoEPrepareAndFinalize | None: - all2all_manager = get_ep_group().device_communicator.all2all_manager - assert all2all_manager is not None - - prepare_finalize: FusedMoEPrepareAndFinalize | None = None - - # TODO: could allow this now - assert not moe.use_flashinfer_cutlass_kernels, "Must be created in modelopt.py" - - if moe.use_pplx_kernels: - assert quant_config is not None - - hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes( - moe.max_num_tokens, - moe.hidden_dim, - moe.in_dtype, - quant_config.quant_dtype, - per_act_token_quant=quant_config.per_act_token_quant, - block_shape=quant_config.block_shape, - ) - - all_to_all_args = dict( - max_num_tokens=moe.max_num_tokens, - num_experts=moe.num_experts, - experts_per_token=moe.experts_per_token, # topk - rank=all2all_manager.rank, - world_size=all2all_manager.world_size, - # dp_size actually means tp_size, bug in pplx kernels - dp_size=all2all_manager.tp_group.world_size, - hidden_dim=moe.hidden_dim, - hidden_dim_bytes=hidden_dim_bytes, - hidden_dim_scale_bytes=hidden_scale_bytes, - ) - - num_dispatchers = ( - all2all_manager.world_size // all2all_manager.tp_group.world_size - ) - - # Intranode pplx a2a takes a group name while internode does not. - if not all2all_manager.internode: - all_to_all_args["group_name"] = all2all_manager.cpu_group.group_name - - handle = all2all_manager.get_handle(all_to_all_args) - - prepare_finalize = PplxPrepareAndFinalize( - handle, - max_num_tokens=moe.max_num_tokens, - num_local_experts=moe.num_local_experts, - num_dispatchers=num_dispatchers, - ) - elif moe.use_deepep_ht_kernels: - assert moe.dp_size == all2all_manager.dp_world_size - - all_to_all_args = dict() - handle = all2all_manager.get_handle(all_to_all_args) - prepare_finalize = DeepEPHTPrepareAndFinalize( - handle, - num_dispatchers=all2all_manager.world_size, - dp_size=all2all_manager.dp_world_size, - rank_expert_offset=all2all_manager.rank * moe.num_local_experts, - ) - - elif moe.use_deepep_ll_kernels: - assert quant_config is not None - all_to_all_args = dict( - max_num_tokens_per_dp_rank=moe.max_num_tokens, - token_hidden_size=moe.hidden_dim, - num_ep_ranks=all2all_manager.world_size, - num_global_experts=moe.num_experts, - num_local_experts=moe.num_experts // all2all_manager.world_size, - ) - handle = all2all_manager.get_handle(all_to_all_args) - - # Note: We may want to use FP8 dispatch just to reduce - # data movement. - use_fp8_dispatch = ( - quant_config.quant_dtype == current_platform.fp8_dtype() - and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE - ) - - prepare_finalize = DeepEPLLPrepareAndFinalize( - handle, - max_tokens_per_rank=moe.max_num_tokens, - num_dispatchers=all2all_manager.world_size, - use_fp8_dispatch=use_fp8_dispatch, - ) - - return prepare_finalize - - def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None: - if self.moe.moe_parallel_config.use_all2all_kernels: - return FusedMoEMethodBase._maybe_make_prepare_finalize( - self.moe, self.moe_quant_config - ) - else: - return None - - def maybe_init_modular_kernel( - self, layer: torch.nn.Module - ) -> FusedMoEModularKernel | None: - assert self.moe is not None - - # We must get the quant config here so that the layer is - # completely initialized, i.e. all weights loaded and post - # processed. - self.moe_quant_config = self.get_fused_moe_quant_config(layer) - - prepare_finalize = self.maybe_make_prepare_finalize() - - if prepare_finalize is not None: - logger.debug( - "%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self) - ) - experts = self.select_gemm_impl(prepare_finalize, layer) - return FusedMoEModularKernel( - prepare_finalize, - experts, - layer.shared_experts, - ) - else: - return None - - def select_gemm_impl( - self, - prepare_finalize: FusedMoEPrepareAndFinalize, - layer: torch.nn.Module, - ) -> FusedMoEPermuteExpertsUnpermute: - # based on the all2all implementation, select the appropriate - # gemm implementation - raise NotImplementedError( - f"{self.__class__.__name__} must select appropriate gemm " - "implementation based on the prepare_finalize" - ) - - @abstractmethod - def get_fused_moe_quant_config( - self, layer: torch.nn.Module - ) -> FusedMoEQuantConfig | None: - raise NotImplementedError - - @property - def topk_indices_dtype(self) -> torch.dtype | None: - return None - - @property - def supports_eplb(self) -> bool: - return False - - @property - def allow_inplace(self) -> bool: - return False - - @abstractmethod - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - raise NotImplementedError - - -@CustomOp.register("modular_fused_moe") -class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): - def __init__( - self, - old_quant_method: FusedMoEMethodBase, - fused_experts: FusedMoEModularKernel, - ): - super().__init__(old_quant_method.moe) - # Find better way to copy attributes? Should we even copy attributes? - # self.__dict__.update(old_quant_method.__dict__) - self.moe_quant_config = old_quant_method.moe_quant_config - self.fused_experts = fused_experts - self.disable_expert_map = getattr( - old_quant_method, - "disable_expert_map", - not fused_experts.supports_expert_map(), - ) - self.old_quant_method = old_quant_method - logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__) - - @property - def topk_indices_dtype(self) -> torch.dtype | None: - return self.fused_experts.prepare_finalize.topk_indices_dtype() - - @property - def supports_eplb(self) -> bool: - return self.old_quant_method.supports_eplb - - @property - def allow_inplace(self) -> bool: - return self.old_quant_method.allow_inplace - - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - raise NotImplementedError - - def get_fused_moe_quant_config( - self, layer: torch.nn.Module - ) -> FusedMoEQuantConfig | None: - return self.moe_quant_config - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - # Is getattr needed? - zero_expert_num = getattr(layer, "zero_expert_num", 0) - zero_expert_type = getattr(layer, "zero_expert_type", None) - - if enable_eplb: - if self.supports_eplb: - assert expert_load_view is not None - assert logical_to_physical_map is not None - assert logical_replica_count is not None - assert isinstance(layer, FusedMoE) - else: - raise NotImplementedError( - "EPLB is not supported for " - f"{self.old_quant_method.__class__.__name__}." - ) - - topk_weights, topk_ids, zero_expert_result = FusedMoE.select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - routed_scaling_factor=routed_scaling_factor, - e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype, - enable_eplb=enable_eplb, - expert_map=expert_map, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, - global_num_experts=global_num_experts, - zero_expert_num=zero_expert_num, - zero_expert_type=zero_expert_type, - ) - - result = self.fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=self.allow_inplace, - activation=activation, - global_num_experts=global_num_experts, - apply_router_weight_on_input=apply_router_weight_on_input, - expert_map=None if self.disable_expert_map else expert_map, - ) - - if zero_expert_num != 0 and zero_expert_type is not None: - assert not isinstance(result, tuple), ( - "Shared + zero experts are mutually exclusive not yet supported" - ) - return result, zero_expert_result - else: - return result - - -@CustomOp.register("unquantized_fused_moe") -class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): - """MoE method without quantization.""" - - def __init__(self, moe: FusedMoEConfig): - super().__init__(moe) - - self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() - if self.rocm_aiter_moe_enabled: - from .rocm_aiter_fused_moe import rocm_aiter_fused_experts - - self.rocm_aiter_fused_experts = rocm_aiter_fused_experts - else: - self.rocm_aiter_fused_experts = None # type: ignore - - # FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS - self.flashinfer_cutlass_moe_enabled = ( - has_flashinfer_cutlass_fused_moe() - and envs.VLLM_USE_FLASHINFER_MOE_FP16 - and self.moe.moe_parallel_config.use_ep - and self.moe.moe_parallel_config.dp_size == 1 - and current_platform.get_device_capability()[0] >= 9 - ) - if self.flashinfer_cutlass_moe_enabled: - logger.info_once( - "Enabling FlashInfer CUTLASS MoE for UnquantizedFusedMoEMethod" - ) - from functools import partial - - from .flashinfer_cutlass_moe import flashinfer_cutlass_moe - - self.flashinfer_cutlass_moe = partial( - flashinfer_cutlass_moe, - quant_config=FUSED_MOE_UNQUANTIZED_CONFIG, - tp_rank=self.moe.moe_parallel_config.tp_rank, - tp_size=self.moe.moe_parallel_config.tp_size, - ep_rank=self.moe.moe_parallel_config.ep_rank, - ep_size=self.moe.moe_parallel_config.ep_size, - ) - else: - if ( - self.moe.moe_parallel_config.use_ep - and self.moe.moe_parallel_config.dp_size == 1 - ): - logger.info_once( - "FlashInfer CUTLASS MoE is available for EP" - " but not enabled, consider setting" - " VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.", - scope="local", - ) - elif self.moe.moe_parallel_config.dp_size > 1: - logger.info_once( - "FlashInfer CUTLASS MoE is currently not available for DP.", - scope="local", - ) - self.flashinfer_cutlass_moe = None # type: ignore - - @property - def supports_eplb(self) -> bool: - return True - - @property - def allow_inplace(self) -> bool: - return True - - def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None: - if self.rocm_aiter_moe_enabled: - return None - else: - return super().maybe_make_prepare_finalize() - - def select_gemm_impl( - self, - prepare_finalize: FusedMoEPrepareAndFinalize, - layer: torch.nn.Module, - ) -> FusedMoEPermuteExpertsUnpermute: - assert self.moe_quant_config is not None - if ( - prepare_finalize.activation_format - == FusedMoEActivationFormat.BatchedExperts - ): - logger.debug("BatchedTritonExperts %s", self.moe) - return BatchedTritonExperts( - max_num_tokens=self.moe.max_num_tokens, - num_dispatchers=prepare_finalize.num_dispatchers(), - quant_config=self.moe_quant_config, - ) - else: - logger.debug("TritonExperts %s", self.moe) - return TritonExperts(self.moe_quant_config) - - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - if self.moe.is_act_and_mul: - w13_up_dim = 2 * intermediate_size_per_partition - else: - w13_up_dim = intermediate_size_per_partition - # Fused gate_up_proj (column parallel) - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts, - w13_up_dim, - hidden_size, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - if self.moe.has_bias: - w13_bias = torch.nn.Parameter( - torch.zeros(num_experts, w13_up_dim, dtype=params_dtype), - requires_grad=False, - ) - layer.register_parameter("w13_bias", w13_bias) - set_weight_attrs(w13_bias, extra_weight_attrs) - # down_proj (row parallel) - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - if self.moe.has_bias: - w2_bias = torch.nn.Parameter( - torch.zeros(num_experts, hidden_size, dtype=params_dtype), - requires_grad=False, - ) - layer.register_parameter("w2_bias", w2_bias) - set_weight_attrs(w2_bias, extra_weight_attrs) - - def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor: - # Pad the weight tensor. This is an optimization on ROCm platform, which - # can benefit from tensors located far enough from one another in memory - if ( - envs.VLLM_ROCM_MOE_PADDING - and current_platform.is_rocm() - and weight.stride(-1) == 1 - and (weight.stride(-2) * weight.element_size()) % 512 == 0 - ): - num_pad = 256 // weight.element_size() - weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] - torch.cuda.empty_cache() - - return weight - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - super().process_weights_after_loading(layer) - - # Padding the weight for better performance on ROCm - layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data) - layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data) - - if self.rocm_aiter_moe_enabled: - shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( - layer.w13_weight.data, layer.w2_weight.data - ) - - layer.w13_weight.data = shuffled_w13 - layer.w2_weight.data = shuffled_w2 - - if self.flashinfer_cutlass_moe_enabled: - # Swap halves to arrange as [w3; w1] (kernel expectation) - w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1) - w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1) - layer.w13_weight.data = w13_weight_swapped.contiguous() - - if current_platform.is_xpu(): - import intel_extension_for_pytorch as ipex - - ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts - layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( - layer.w13_weight, - layer.w2_weight, - use_prepack=True, - experts_start_id=ep_rank_start, - ) - elif current_platform.is_cpu(): - from vllm.model_executor.layers.fused_moe import cpu_fused_moe - - if current_platform.get_cpu_architecture() == CpuArchEnum.X86: - from vllm.model_executor.layers.utils import check_cpu_sgl_kernel - - dtype_w13 = layer.w13_weight.dtype - _, n_w13, k_w13 = layer.w13_weight.size() - dtype_w2 = layer.w2_weight.dtype - _, n_w2, k_w2 = layer.w2_weight.size() - if ( - envs.VLLM_CPU_SGL_KERNEL - and check_cpu_sgl_kernel(n_w13, k_w13, dtype_w13) - and check_cpu_sgl_kernel(n_w2, k_w2, dtype_w2) - ): - packed_w13_weight = torch.ops._C.convert_weight_packed( - layer.w13_weight - ) - assert packed_w13_weight.size() == layer.w13_weight.size() - layer.w13_weight.copy_(packed_w13_weight) - del packed_w13_weight - packed_w2_weight = torch.ops._C.convert_weight_packed( - layer.w2_weight - ) - assert packed_w2_weight.size() == layer.w2_weight.size() - layer.w2_weight.copy_(packed_w2_weight) - layer.cpu_fused_moe = cpu_fused_moe.SGLFusedMOE(layer) - else: - layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer) - else: - layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - if enable_eplb: - assert expert_load_view is not None - assert logical_to_physical_map is not None - assert logical_replica_count is not None - assert isinstance(layer, FusedMoE) - - return self.forward( - x=x, - layer=layer, - router_logits=router_logits, - top_k=top_k, - renormalize=renormalize, - use_grouped_topk=use_grouped_topk, - topk_group=topk_group, - num_expert_group=num_expert_group, - global_num_experts=global_num_experts, - expert_map=expert_map, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - routed_scaling_factor=routed_scaling_factor, - e_score_correction_bias=e_score_correction_bias, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - enable_eplb=enable_eplb, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, - ) - - def get_fused_moe_quant_config( - self, layer: torch.nn.Module - ) -> FusedMoEQuantConfig | None: - if self.moe.has_bias: - return biased_moe_quant_config( - layer.w13_bias, - layer.w2_bias, - ) - else: - return FUSED_MOE_UNQUANTIZED_CONFIG - - def forward_cuda( - self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - zero_expert_num = getattr(layer, "zero_expert_num", 0) - zero_expert_type = getattr(layer, "zero_expert_type", None) - - topk_weights, topk_ids, zero_expert_result = FusedMoE.select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - routed_scaling_factor=routed_scaling_factor, - e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype, - enable_eplb=enable_eplb, - expert_map=expert_map, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, - global_num_experts=global_num_experts, - zero_expert_num=zero_expert_num, - zero_expert_type=zero_expert_type, - num_fused_shared_experts=layer.num_fused_shared_experts, - ) - - if self.rocm_aiter_moe_enabled: - result = self.rocm_aiter_fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - expert_map=expert_map, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - elif self.flashinfer_cutlass_moe_enabled: - return self.flashinfer_cutlass_moe( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - else: - result = fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - quant_config=self.moe_quant_config, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, - ) - - if zero_expert_num != 0 and zero_expert_type is not None: - assert not isinstance(result, tuple), ( - "Shared + zero experts are mutually exclusive not yet supported" - ) - return result, zero_expert_result - else: - return result - - def forward_cpu( - self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - if ( - enable_eplb is not False - or expert_load_view is not None - or logical_to_physical_map is not None - or logical_replica_count is not None - ): - raise NotImplementedError("Expert load balancing is not supported for CPU.") - return layer.cpu_fused_moe( - layer, - x, - use_grouped_topk, - top_k, - router_logits, - renormalize, - topk_group, - num_expert_group, - global_num_experts, - expert_map, - custom_routing_function, - scoring_func, - routed_scaling_factor, - e_score_correction_bias, - apply_router_weight_on_input, - activation, - ) - - def forward_xpu( - self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - if ( - enable_eplb is not False - or expert_load_view is not None - or logical_to_physical_map is not None - or logical_replica_count is not None - ): - raise NotImplementedError("Expert load balancing is not supported for XPU.") - return layer.ipex_fusion( - x, - use_grouped_topk, - top_k, - router_logits, - renormalize, - topk_group, - num_expert_group, - custom_routing_function=custom_routing_function, - ) - - def forward_tpu( - self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert not use_grouped_topk - assert num_expert_group is None - assert topk_group is None - assert custom_routing_function is None - assert apply_router_weight_on_input is False - if scoring_func != "softmax": - raise NotImplementedError( - "Only softmax scoring function is supported for TPU." - ) - if e_score_correction_bias is not None: - raise NotImplementedError( - "Expert score correction bias is not supported for TPU." - ) - assert activation == "silu", f"{activation} is not supported for TPU." - assert routed_scaling_factor == 1.0, ( - f"routed_scaling_factor {routed_scaling_factor} is not supported for TPU." - ) - if ( - enable_eplb is not False - or expert_load_view is not None - or logical_to_physical_map is not None - or logical_replica_count is not None - ): - raise NotImplementedError("Expert load balancing is not supported for TPU.") - return fused_moe_pallas( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk=top_k, - gating_output=router_logits, - global_num_experts=global_num_experts, - expert_map=expert_map, - renormalize=renormalize, - ) - - if current_platform.is_tpu(): - forward_native = forward_tpu - elif current_platform.is_cpu(): - forward_native = forward_cpu - elif current_platform.is_xpu(): - forward_native = forward_xpu - else: - forward_native = forward_cuda - - def determine_expert_map( ep_size: int, ep_rank: int, @@ -1125,16 +233,13 @@ def maybe_roundup_hidden_size( Rounded up hidden_size if rounding up is required based on the configs. Original hidden size otherwise. """ + from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_roundup_layer_hidden_size, + ) - if moe_parallel_config.use_deepep_ht_kernels: - hidden_size = DeepEPHTPrepareAndFinalize.maybe_roundup_layer_hidden_size( - hidden_size, act_dtype - ) - - if moe_parallel_config.use_deepep_ll_kernels: - hidden_size = DeepEPLLPrepareAndFinalize.maybe_roundup_layer_hidden_size( - hidden_size - ) + hidden_size = maybe_roundup_layer_hidden_size( + hidden_size, act_dtype, moe_parallel_config + ) # we are padding globally so EP buffer allocation works if quant_config and quant_config.get_name() == "mxfp4": @@ -1430,7 +535,6 @@ class FusedMoE(CustomOp): is_lora_enabled=vllm_config.lora_config is not None, ) - self.moe_quant_config: FusedMoEQuantConfig | None = None self.quant_config = quant_config def _get_quant_method() -> FusedMoEMethodBase: @@ -1508,9 +612,15 @@ class FusedMoE(CustomOp): # This is called after all weight loading and post-processing, so it # should be safe to swap out the quant_method. def maybe_init_modular_kernel(self) -> None: - mk = self.quant_method.maybe_init_modular_kernel(self) - if mk is not None: - self.quant_method = FusedMoEModularMethod(self.quant_method, mk) + self.ensure_moe_quant_config_init() + prepare_finalize = self.quant_method.maybe_make_prepare_finalize() + if prepare_finalize is not None: + logger.debug( + "%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self) + ) + self.quant_method = FusedMoEModularMethod.make( + self, self.quant_method, prepare_finalize, self.shared_experts + ) @property def shared_experts(self) -> torch.nn.Module | None: @@ -2142,12 +1252,16 @@ class FusedMoE(CustomOp): def ensure_moe_quant_config_init(self): if self.quant_method.moe_quant_config is None: + # Note: the moe_quant_config can't be constructed until after + # weight loading post processing. self.quant_method.moe_quant_config = ( self.quant_method.get_fused_moe_quant_config(self) ) - if self.moe_quant_config is None: - self.moe_quant_config = self.quant_method.moe_quant_config + @property + def moe_quant_config(self) -> FusedMoEQuantConfig | None: + self.ensure_moe_quant_config_init() + return self.quant_method.moe_quant_config def ensure_dp_chunking_init(self): if not self.use_dp_chunking or self.batched_hidden_states is not None: diff --git a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py index 3d0c5636d6c0a..06112ca51b6d5 100644 --- a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py @@ -38,7 +38,7 @@ class SharedFusedMoE(FusedMoE): and not ( # TODO(wentao): find the root cause and remove this condition self.enable_eplb - or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1) + or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1) ) and self._shared_experts is not None ) diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py new file mode 100644 index 0000000000000..ce56887f1c26d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -0,0 +1,578 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Callable + +import torch +import torch.nn.functional as F + +import vllm.envs as envs +from vllm._aiter_ops import rocm_aiter_ops +from vllm.logger import init_logger +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.fused_moe.config import ( + FUSED_MOE_UNQUANTIZED_CONFIG, + FusedMoEConfig, + FusedMoEQuantConfig, + biased_moe_quant_config, +) +from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( + FusedMoEMethodBase, +) +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEActivationFormat, + FusedMoEPermuteExpertsUnpermute, + FusedMoEPrepareAndFinalize, +) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.platforms.interface import CpuArchEnum +from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe + +if current_platform.is_cuda_alike(): + from .fused_batched_moe import BatchedTritonExperts + from .fused_moe import TritonExperts, fused_experts +else: + fused_experts = None # type: ignore + +if current_platform.is_tpu(): + from .moe_pallas import fused_moe as fused_moe_pallas +else: + fused_moe_pallas = None # type: ignore + +logger = init_logger(__name__) + + +@CustomOp.register("unquantized_fused_moe") +class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): + """MoE method without quantization.""" + + def __init__(self, moe: FusedMoEConfig): + super().__init__(moe) + + self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() + if self.rocm_aiter_moe_enabled: + from .rocm_aiter_fused_moe import rocm_aiter_fused_experts + + self.rocm_aiter_fused_experts = rocm_aiter_fused_experts + else: + self.rocm_aiter_fused_experts = None # type: ignore + + # FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS + self.flashinfer_cutlass_moe_enabled = ( + has_flashinfer_cutlass_fused_moe() + and envs.VLLM_USE_FLASHINFER_MOE_FP16 + and self.moe.moe_parallel_config.use_ep + and self.moe.moe_parallel_config.dp_size == 1 + and current_platform.get_device_capability()[0] >= 9 + ) + if self.flashinfer_cutlass_moe_enabled: + logger.info_once( + "Enabling FlashInfer CUTLASS MoE for UnquantizedFusedMoEMethod" + ) + from functools import partial + + from .flashinfer_cutlass_moe import flashinfer_cutlass_moe + + self.flashinfer_cutlass_moe = partial( + flashinfer_cutlass_moe, + quant_config=FUSED_MOE_UNQUANTIZED_CONFIG, + tp_rank=self.moe.moe_parallel_config.tp_rank, + tp_size=self.moe.moe_parallel_config.tp_size, + ep_rank=self.moe.moe_parallel_config.ep_rank, + ep_size=self.moe.moe_parallel_config.ep_size, + ) + else: + if ( + self.moe.moe_parallel_config.use_ep + and self.moe.moe_parallel_config.dp_size == 1 + ): + logger.info_once( + "FlashInfer CUTLASS MoE is available for EP" + " but not enabled, consider setting" + " VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.", + scope="local", + ) + elif self.moe.moe_parallel_config.dp_size > 1: + logger.info_once( + "FlashInfer CUTLASS MoE is currently not available for DP.", + scope="local", + ) + self.flashinfer_cutlass_moe = None # type: ignore + + @property + def supports_eplb(self) -> bool: + return True + + @property + def allow_inplace(self) -> bool: + return True + + def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None: + if self.rocm_aiter_moe_enabled: + return None + else: + return super().maybe_make_prepare_finalize() + + def select_gemm_impl( + self, + prepare_finalize: FusedMoEPrepareAndFinalize, + layer: torch.nn.Module, + ) -> FusedMoEPermuteExpertsUnpermute: + assert self.moe_quant_config is not None + if ( + prepare_finalize.activation_format + == FusedMoEActivationFormat.BatchedExperts + ): + logger.debug("BatchedTritonExperts %s", self.moe) + return BatchedTritonExperts( + max_num_tokens=self.moe.max_num_tokens, + num_dispatchers=prepare_finalize.num_dispatchers(), + quant_config=self.moe_quant_config, + ) + else: + logger.debug("TritonExperts %s", self.moe) + return TritonExperts(self.moe_quant_config) + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + if self.moe.is_act_and_mul: + w13_up_dim = 2 * intermediate_size_per_partition + else: + w13_up_dim = intermediate_size_per_partition + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + w13_up_dim, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + if self.moe.has_bias: + w13_bias = torch.nn.Parameter( + torch.zeros(num_experts, w13_up_dim, dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) + # down_proj (row parallel) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + if self.moe.has_bias: + w2_bias = torch.nn.Parameter( + torch.zeros(num_experts, hidden_size, dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) + + def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor: + # Pad the weight tensor. This is an optimization on ROCm platform, which + # can benefit from tensors located far enough from one another in memory + if ( + envs.VLLM_ROCM_MOE_PADDING + and current_platform.is_rocm() + and weight.stride(-1) == 1 + and (weight.stride(-2) * weight.element_size()) % 512 == 0 + ): + num_pad = 256 // weight.element_size() + weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] + torch.cuda.empty_cache() + + return weight + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + super().process_weights_after_loading(layer) + + # Padding the weight for better performance on ROCm + layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data) + layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data) + + if self.rocm_aiter_moe_enabled: + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( + layer.w13_weight.data, layer.w2_weight.data + ) + + layer.w13_weight.data = shuffled_w13 + layer.w2_weight.data = shuffled_w2 + + if self.flashinfer_cutlass_moe_enabled: + # Swap halves to arrange as [w3; w1] (kernel expectation) + w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1) + w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1) + layer.w13_weight.data = w13_weight_swapped.contiguous() + + if current_platform.is_xpu(): + import intel_extension_for_pytorch as ipex + + ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts + layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( + layer.w13_weight, + layer.w2_weight, + use_prepack=True, + experts_start_id=ep_rank_start, + ) + elif current_platform.is_cpu(): + from vllm.model_executor.layers.fused_moe import cpu_fused_moe + + if current_platform.get_cpu_architecture() == CpuArchEnum.X86: + from vllm.model_executor.layers.utils import check_cpu_sgl_kernel + + dtype_w13 = layer.w13_weight.dtype + _, n_w13, k_w13 = layer.w13_weight.size() + dtype_w2 = layer.w2_weight.dtype + _, n_w2, k_w2 = layer.w2_weight.size() + if ( + envs.VLLM_CPU_SGL_KERNEL + and check_cpu_sgl_kernel(n_w13, k_w13, dtype_w13) + and check_cpu_sgl_kernel(n_w2, k_w2, dtype_w2) + ): + packed_w13_weight = torch.ops._C.convert_weight_packed( + layer.w13_weight + ) + assert packed_w13_weight.size() == layer.w13_weight.size() + layer.w13_weight.copy_(packed_w13_weight) + del packed_w13_weight + packed_w2_weight = torch.ops._C.convert_weight_packed( + layer.w2_weight + ) + assert packed_w2_weight.size() == layer.w2_weight.size() + layer.w2_weight.copy_(packed_w2_weight) + layer.cpu_fused_moe = cpu_fused_moe.SGLFusedMOE(layer) + else: + layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer) + else: + layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: int | None = None, + num_expert_group: int | None = None, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if enable_eplb: + assert expert_load_view is not None + assert logical_to_physical_map is not None + assert logical_replica_count is not None + + return self.forward( + x=x, + layer=layer, + router_logits=router_logits, + top_k=top_k, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + topk_group=topk_group, + num_expert_group=num_expert_group, + global_num_experts=global_num_experts, + expert_map=expert_map, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + enable_eplb=enable_eplb, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> FusedMoEQuantConfig | None: + if self.moe.has_bias: + return biased_moe_quant_config( + layer.w13_bias, + layer.w2_bias, + ) + else: + return FUSED_MOE_UNQUANTIZED_CONFIG + + def forward_cuda( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: int | None = None, + num_expert_group: int | None = None, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + zero_expert_num = getattr(layer, "zero_expert_num", 0) + zero_expert_type = getattr(layer, "zero_expert_type", None) + + topk_weights, topk_ids, zero_expert_result = layer.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, + enable_eplb=enable_eplb, + expert_map=expert_map, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + global_num_experts=global_num_experts, + zero_expert_num=zero_expert_num, + zero_expert_type=zero_expert_type, + num_fused_shared_experts=layer.num_fused_shared_experts, + ) + + if self.rocm_aiter_moe_enabled: + result = self.rocm_aiter_fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + expert_map=expert_map, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + elif self.flashinfer_cutlass_moe_enabled: + return self.flashinfer_cutlass_moe( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + else: + result = fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + quant_config=self.moe_quant_config, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + ) + + if zero_expert_num != 0 and zero_expert_type is not None: + assert not isinstance(result, tuple), ( + "Shared + zero experts are mutually exclusive not yet supported" + ) + return result, zero_expert_result + else: + return result + + def forward_cpu( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: int | None = None, + num_expert_group: int | None = None, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if ( + enable_eplb is not False + or expert_load_view is not None + or logical_to_physical_map is not None + or logical_replica_count is not None + ): + raise NotImplementedError("Expert load balancing is not supported for CPU.") + return layer.cpu_fused_moe( + layer, + x, + use_grouped_topk, + top_k, + router_logits, + renormalize, + topk_group, + num_expert_group, + global_num_experts, + expert_map, + custom_routing_function, + scoring_func, + routed_scaling_factor, + e_score_correction_bias, + apply_router_weight_on_input, + activation, + ) + + def forward_xpu( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: int | None = None, + num_expert_group: int | None = None, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if ( + enable_eplb is not False + or expert_load_view is not None + or logical_to_physical_map is not None + or logical_replica_count is not None + ): + raise NotImplementedError("Expert load balancing is not supported for XPU.") + return layer.ipex_fusion( + x, + use_grouped_topk, + top_k, + router_logits, + renormalize, + topk_group, + num_expert_group, + custom_routing_function=custom_routing_function, + ) + + def forward_tpu( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: int | None = None, + num_expert_group: int | None = None, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + assert not use_grouped_topk + assert num_expert_group is None + assert topk_group is None + assert custom_routing_function is None + assert apply_router_weight_on_input is False + if scoring_func != "softmax": + raise NotImplementedError( + "Only softmax scoring function is supported for TPU." + ) + if e_score_correction_bias is not None: + raise NotImplementedError( + "Expert score correction bias is not supported for TPU." + ) + assert activation == "silu", f"{activation} is not supported for TPU." + assert routed_scaling_factor == 1.0, ( + f"routed_scaling_factor {routed_scaling_factor} is not supported for TPU." + ) + if ( + enable_eplb is not False + or expert_load_view is not None + or logical_to_physical_map is not None + or logical_replica_count is not None + ): + raise NotImplementedError("Expert load balancing is not supported for TPU.") + return fused_moe_pallas( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk=top_k, + gating_output=router_logits, + global_num_experts=global_num_experts, + expert_map=expert_map, + renormalize=renormalize, + ) + + if current_platform.is_tpu(): + forward_native = forward_tpu + elif current_platform.is_cpu(): + forward_native = forward_cpu + elif current_platform.is_xpu(): + forward_native = forward_xpu + else: + forward_native = forward_cuda diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index e339f15510d79..4e51249f2d25b 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -741,15 +741,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex) ) - self.w13_weight_triton_tensor = w13_weight - self.w2_weight_triton_tensor = w2_weight - - # need to delete the original weights to save memory on single GPU - del layer.w13_weight - del layer.w2_weight - layer.w13_weight = None - layer.w2_weight = None - torch.cuda.empty_cache() + self.w13_weight = w13_weight + self.w2_weight = w2_weight + layer.w13_weight = w13_weight + layer.w2_weight = w2_weight else: raise ValueError(f"Unsupported backend: {self.mxfp4_backend}") @@ -824,18 +819,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): "EP batched experts format" ) else: - layer.w13_weight = ( - self.w13_weight_triton_tensor - if layer.w13_weight is None - else layer.w13_weight - ) - layer.w2_weight = ( - self.w2_weight_triton_tensor - if layer.w2_weight is None - else layer.w2_weight - ) - assert all([w is not None for w in [layer.w13_weight, layer.w2_weight]]) - assert self.moe_quant_config is not None if ( self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM @@ -1070,8 +1053,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): return triton_kernel_moe_forward( hidden_states=x, - w1=self.w13_weight_triton_tensor, - w2=self.w2_weight_triton_tensor, + w1=self.w13_weight, + w2=self.w2_weight, gating_output=router_logits, topk=top_k, renormalize=renormalize, From 533b018f725fb9c2421e2c4b5a48d62fa5f1d844 Mon Sep 17 00:00:00 2001 From: jvlunteren <161835099+jvlunteren@users.noreply.github.com> Date: Tue, 11 Nov 2025 15:41:43 +0100 Subject: [PATCH 15/98] [BugFix] Fix Failing Ruff Check (#28469) Signed-off-by: Jan van Lunteren --- tests/compile/test_fusions_e2e.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index f67063cdf42ea..e1560efb3f247 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -74,7 +74,7 @@ if current_platform.is_cuda(): ModelBackendTestCase( model_name="Qwen/Qwen3-30B-A3B", model_kwargs=dict(max_model_len=1024), - backend=_Backend.TRITON_ATTN, + backend=AttentionBackendEnum.TRITON_ATTN, attention_fusions=0, allreduce_fusions=97, ), From a90ad7d838b446cfc2dd7b4252086e13c3a8abbf Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Tue, 11 Nov 2025 15:03:22 +0000 Subject: [PATCH 16/98] Add @markmc to CODEOWNERS for Observability (#28457) Signed-off-by: Mark McLoughlin --- .github/CODEOWNERS | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 23def076cf880..f26c782bccf2c 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -61,6 +61,16 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson /vllm/model_executor/models/transformers @hmellor /tests/models/test_transformers.py @hmellor +# Observability +/vllm/config/observability.py @markmc +/vllm/v1/metrics @markmc +/tests/v1/metrics @markmc +/vllm/tracing.py @markmc +/tests/v1/tracing/test_tracing.py @markmc +/vllm/config/kv_events.py @markmc +/vllm/distributed/kv_events.py @markmc +/tests/distributed/test_events.py @markmc + # Docs /docs/mkdocs @hmellor /docs/**/*.yml @hmellor From b886068056a05857f796909d2f8573b36fc668a5 Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Tue, 11 Nov 2025 23:29:33 +0800 Subject: [PATCH 17/98] [BugFix] Fix RuntimeError in PixtralHFAttention on CPU/XPU (#28444) Signed-off-by: Lin, Fanli --- vllm/model_executor/models/pixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 0555717017cdc..dfe5f0c52a505 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -1109,7 +1109,7 @@ class PixtralHFAttention(nn.Module): ) out = out.transpose(1, 2) - out = out.view(batch, patches, self.n_heads * self.head_dim) + out = out.reshape(batch, patches, self.n_heads * self.head_dim) attn_output, _ = self.o_proj(out) return attn_output, None From 3143eb23fc4e017bc31d11a9756d5a788d6f7e33 Mon Sep 17 00:00:00 2001 From: usberkeley <150880684+usberkeley@users.noreply.github.com> Date: Wed, 12 Nov 2025 00:01:30 +0800 Subject: [PATCH 18/98] [BugFix] Add test_outputs.py to CI pipeline (#28466) Signed-off-by: Bradley Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- .buildkite/test-amd.yaml | 1 + .buildkite/test-pipeline.yaml | 1 + 2 files changed, 2 insertions(+) diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index bb5ef5d624630..5fd048c2ad0c6 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -348,6 +348,7 @@ steps: - pytest -v -s -m 'not cpu_test' v1/metrics - pytest -v -s v1/test_oracle.py - pytest -v -s v1/test_request.py + - pytest -v -s v1/test_outputs.py # Integration test for streaming correctness (requires special branch). - pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 83a7df3b093fc..25f711dd60b37 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -329,6 +329,7 @@ steps: - pytest -v -s -m 'not cpu_test' v1/metrics - pytest -v -s v1/test_oracle.py - pytest -v -s v1/test_request.py + - pytest -v -s v1/test_outputs.py # Integration test for streaming correctness (requires special branch). - pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine From 287bbbeb067cd9e16ea9b834b35b47258a8ad43f Mon Sep 17 00:00:00 2001 From: the-codeboy <71213855+the-codeboy@users.noreply.github.com> Date: Tue, 11 Nov 2025 17:45:49 +0100 Subject: [PATCH 19/98] [Doc] Fix typo in serving docs (#28474) Signed-off-by: the-codeboy <71213855+the-codeboy@users.noreply.github.com> --- docs/serving/openai_compatible_server.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/serving/openai_compatible_server.md b/docs/serving/openai_compatible_server.md index e331b3422ea64..821628e6e3174 100644 --- a/docs/serving/openai_compatible_server.md +++ b/docs/serving/openai_compatible_server.md @@ -77,11 +77,11 @@ In addition, we have the following custom APIs: In order for the language model to support chat protocol, vLLM requires the model to include a chat template in its tokenizer configuration. The chat template is a Jinja2 template that -specifies how are roles, messages, and other chat-specific tokens are encoded in the input. +specifies how roles, messages, and other chat-specific tokens are encoded in the input. An example chat template for `NousResearch/Meta-Llama-3-8B-Instruct` can be found [here](https://github.com/meta-llama/llama3?tab=readme-ov-file#instruction-tuned-models) -Some models do not provide a chat template even though they are instruction/chat fine-tuned. For those model, +Some models do not provide a chat template even though they are instruction/chat fine-tuned. For those models, you can manually specify their chat template in the `--chat-template` parameter with the file path to the chat template, or the template in string form. Without a chat template, the server will not be able to process chat and all chat requests will error. From f9a4087182ffcd9404779fcda876f820b3b26d5f Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 11 Nov 2025 09:46:04 -0700 Subject: [PATCH 20/98] Remove weight_scale.T special case for SM90 Block FP8 CUTLASS kernel (#28431) Signed-off-by: mgoin --- benchmarks/kernels/bench_block_fp8_gemm.py | 43 +++++++++++++------ .../scaled_mm_blockwise_sm90_fp8_dispatch.cuh | 3 +- .../schemes/compressed_tensors_w8a8_fp8.py | 2 +- .../model_executor/layers/quantization/fp8.py | 2 +- .../layers/quantization/utils/fp8_utils.py | 22 ++-------- 5 files changed, 36 insertions(+), 36 deletions(-) diff --git a/benchmarks/kernels/bench_block_fp8_gemm.py b/benchmarks/kernels/bench_block_fp8_gemm.py index f1e504499eaf6..11e3ac7f0c1fa 100644 --- a/benchmarks/kernels/bench_block_fp8_gemm.py +++ b/benchmarks/kernels/bench_block_fp8_gemm.py @@ -1,10 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os + +# Disable DeepGEMM for this benchmark to use CUTLASS +os.environ["VLLM_USE_DEEP_GEMM"] = "0" + import torch from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - apply_w8a8_block_fp8_linear, + W8A8BlockFp8LinearOp, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_BLOCK_FP8_SUPPORTED, @@ -39,13 +47,14 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass): fp8_info = torch.finfo(torch.float8_e4m3fn) fp8_max, fp8_min = fp8_info.max, fp8_info.min - # Create random FP8 tensors + # Create random input tensor (bfloat16, will be quantized by W8A8BlockFp8LinearOp) A_ref = (torch.rand(M, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max + # Create quantized weight tensor B_ref = (torch.rand(N, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max B = B_ref.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - # Create scales + # Create weight scales block_n, block_k = block_size[0], block_size[1] n_tiles = (N + block_n - 1) // block_n k_tiles = (K + block_k - 1) // block_k @@ -55,19 +64,25 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass): * factor_for_scale ) - # SM90 CUTLASS requires row-major format for scales - if use_cutlass and current_platform.is_device_capability(90): - Bs = Bs.T.contiguous() + # Create W8A8BlockFp8LinearOp instance + weight_group_shape = GroupShape(block_n, block_k) + act_quant_group_shape = GroupShape(1, block_k) # Per-token, per-group quantization + + linear_op = W8A8BlockFp8LinearOp( + weight_group_shape=weight_group_shape, + act_quant_group_shape=act_quant_group_shape, + cutlass_block_fp8_supported=use_cutlass, + use_aiter_and_is_supported=False, + ) def run(): - if use_cutlass: - return apply_w8a8_block_fp8_linear( - A_ref, B, block_size, Bs, cutlass_block_fp8_supported=True - ) - else: - return apply_w8a8_block_fp8_linear( - A_ref, B, block_size, Bs, cutlass_block_fp8_supported=False - ) + return linear_op.apply( + input=A_ref, + weight=B, + weight_scale=Bs, + input_scale=None, + bias=None, + ) return run diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh index 147eb8efc0778..c40d499662714 100644 --- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh +++ b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh @@ -48,7 +48,8 @@ struct cutlass_3x_gemm_fp8_blockwise { using ElementBlockScale = float; using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig< - ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>; + ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, + cute::GMMA::Major::MN, cute::GMMA::Major::K>; using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 6da136cbc8f69..ee99572f5f499 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -173,7 +173,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): layer.input_scale = None if self.strategy == QuantizationStrategy.BLOCK: - maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported) + maybe_post_process_fp8_weight_block(layer) def apply_weights( self, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 83d136600b77c..cb065eb68b66b 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -540,7 +540,7 @@ class Fp8LinearMethod(LinearMethodBase): return if self.block_quant: - maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported) + maybe_post_process_fp8_weight_block(layer) def apply( self, diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index c63196b893574..0c54cf4def005 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -55,17 +55,13 @@ def cutlass_scaled_mm( Bs: torch.Tensor, block_size: list[int], output_dtype: torch.dtype = torch.float16, - is_hopper: bool | None = None, ) -> torch.Tensor: - if is_hopper is None: - is_hopper = current_platform.is_device_capability(90) return ops.cutlass_scaled_mm( A, B.T, out_dtype=output_dtype, scale_a=As, - # SM90 block FP8 requires row-major scale_b, which we do ahead of time - scale_b=Bs if block_size is not None and is_hopper else Bs.T, + scale_b=Bs.T, ) @@ -130,7 +126,7 @@ def _padded_cutlass( padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale) output = cutlass_scaled_mm( - padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype, True + padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype ) return output[0 : qx.shape[0], ...] @@ -303,7 +299,6 @@ class W8A8BlockFp8LinearOp: weight_scale, list(self.weight_group_shape), input_2d.dtype, - False, ) def _run_aiter( @@ -1125,9 +1120,7 @@ def process_fp8_weight_block_strategy( return weight, weight_scale -def maybe_post_process_fp8_weight_block( - layer: torch.nn.Module, cutlass_block_fp8_supported: bool -): +def maybe_post_process_fp8_weight_block(layer: torch.nn.Module): assert layer.weight_block_size is not None from vllm.utils.deep_gemm import ( @@ -1146,15 +1139,6 @@ def maybe_post_process_fp8_weight_block( requant_weight_ue8m0_inplace( layer.weight.data, layer.weight_scale.data, block_sz ) - # SM90 Block FP8 CUTLASS requires row-major weight scales - elif ( - current_platform.is_device_capability(90) - and cutlass_block_fp8_supported - and not should_use_deepgemm - ): - layer.weight_scale = torch.nn.Parameter( - layer.weight_scale.data.T.contiguous(), requires_grad=False - ) def expert_weight_is_col_major(x: torch.Tensor) -> bool: From a7ef3eb0cd03e729c7a29914400e0ca928767999 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?= Date: Tue, 11 Nov 2025 17:57:43 +0100 Subject: [PATCH 21/98] [NIXL] Generalize block-first backend layouts (FlashInfer-like) (#28282) --- .../kv_connector/unit/test_nixl_connector.py | 17 ++++++- .../kv_connector/v1/nixl_connector.py | 47 +++++++++++++++---- 2 files changed, 52 insertions(+), 12 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 475cf2285e394..8e421717fea30 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1096,7 +1096,8 @@ def _run_abort_timeout_test(llm: LLM, timeout: int): llm.llm_engine.engine_core.shutdown() -def test_register_kv_caches(dist_init): +@pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "TRITON_ATTN"]) +def test_register_kv_caches(dist_init, attn_backend, monkeypatch): """ Test that register_kv_caches() properly calls nixl_wrapper methods with correct data. @@ -1108,10 +1109,22 @@ def test_register_kv_caches(dist_init): block layout info """ + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend) + vllm_config = create_vllm_config() + # Import the appropriate backend based on the parameter + if attn_backend == "FLASH_ATTN": + from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend + + backend_cls = FlashAttentionBackend + else: # TRITON_ATTN + from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend + + backend_cls = TritonAttentionBackend + # Create test kv cache tensors using proper backend shape - kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( + kv_cache_shape = backend_cls.get_kv_cache_shape( num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 ) shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 6c20eee1ecbf9..375ea79d0e817 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -21,6 +21,7 @@ import torch import zmq from vllm import envs +from vllm.attention import AttentionBackend from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.selector import get_attn_backend from vllm.config import VllmConfig @@ -669,6 +670,33 @@ class NixlConnectorWorker: remote_tp_size: dict[EngineId, int] is_mla: bool total_num_kv_heads: int + attn_backend: type[AttentionBackend] + + def __post_init__(self): + # Figure out whether the first dimension of the cache is K/V + # or num_blocks. This is used to register the memory regions correctly. + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 + ) + # Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D], + # we just mock num_blocks to 1 for the dimension check below. + self._is_kv_layout_blocks_first = ( + len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1 + ) + + attn_backend = AttentionBackendEnum[self.attn_backend.get_name()] + self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS + + @property + def is_kv_layout_blocks_first(self) -> bool: + return self._is_kv_layout_blocks_first + + @property + def split_k_and_v(self) -> bool: + # Whether to register regions for K and V separately (when present). + return not ( + self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first + ) def tp_ratio( self, @@ -876,9 +904,6 @@ class NixlConnectorWorker: use_mla=self.use_mla, ) self.backend_name = backend.get_name() - attn_backend = AttentionBackendEnum[self.backend_name] - self._use_flashinfer = attn_backend == AttentionBackendEnum.FLASHINFER - self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS self.kv_cache_layout = get_kv_cache_layout() self.host_buffer_kv_cache_layout = self.kv_cache_layout logger.debug("Detected attention backend %s", self.backend_name) @@ -896,7 +921,9 @@ class NixlConnectorWorker: remote_tp_size=self._tp_size, # shared state is_mla=self.use_mla, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), + attn_backend=backend, ) + self._use_pallas = self.kv_topo._use_pallas def _nixl_handshake( self, @@ -1076,7 +1103,7 @@ class NixlConnectorWorker: # (roughly 8KB vs 5KB). # Conversely for FlashInfer, K and V are registered in the same region # to better exploit the memory layout (ie num_blocks is the first dim). - split_k_and_v = not (self.use_mla or self._use_pallas or self._use_flashinfer) + split_k_and_v = self.kv_topo.split_k_and_v tensor_size_bytes = None # Enable different block lengths for different layers when MLA is used. self.block_len_per_layer = list[int]() @@ -1141,7 +1168,7 @@ class NixlConnectorWorker: self.device_kv_caches = kv_caches self.dst_num_blocks[self.engine_id] = self.num_blocks - if self._use_flashinfer: + if self.kv_topo.is_kv_layout_blocks_first: for i in range(len(self.slot_size_per_layer)): assert self.slot_size_per_layer[i] % 2 == 0 self.slot_size_per_layer[i] //= 2 @@ -1169,7 +1196,7 @@ class NixlConnectorWorker: # (addr, len, device id) blocks_data.append((addr, kv_block_len, self.device_id)) - if self._use_flashinfer: + if self.kv_topo.is_kv_layout_blocks_first: # Separate and interleave K/V regions to maintain the same # descs ordering. This is needed for selecting contiguous heads # when split across TP ranks. @@ -1331,7 +1358,7 @@ class NixlConnectorWorker: # (addr, len, device id) blocks_data.append((addr, kv_block_len, nixl_agent_meta.device_id)) - if self._use_flashinfer: + if self.kv_topo.is_kv_layout_blocks_first: # With FlashInfer index V separately to allow head splitting. for block_id in range(nixl_agent_meta.num_blocks): block_offset = block_id * nixl_agent_meta.block_lens[i] @@ -1414,7 +1441,7 @@ class NixlConnectorWorker: remote_block_size = remote_block_len // ( self.slot_size_per_layer[0] * tp_ratio ) - if self._use_flashinfer: + if self.kv_topo.is_kv_layout_blocks_first: # With flashinfer, KV are sent in the same message. remote_block_size //= 2 @@ -1494,7 +1521,7 @@ class NixlConnectorWorker: - cache.index_copy_(0, indices, permuted_blocks) # copy permuted kv back """ - split_k_and_v = not (self.use_mla or self._use_pallas or self._use_flashinfer) + split_k_and_v = self.kv_topo.split_k_and_v inv_order = [0, 2, 1, 3] sample_cache = list(self.device_kv_caches.values())[0][0] target_shape = list(sample_cache.shape) @@ -1874,7 +1901,7 @@ class NixlConnectorWorker: For FlashInfer, this is half the length of the whole block, as K and V share the same region. """ - if self._use_flashinfer: + if self.kv_topo.is_kv_layout_blocks_first: # For indexing only half (either just the K or V part). block_len = self.block_len_per_layer[layer_idx] // 2 else: From 68c09efc37e87032640cf8db571eaf486bd744ac Mon Sep 17 00:00:00 2001 From: zhrrr <43847754+izhuhaoran@users.noreply.github.com> Date: Wed, 12 Nov 2025 01:00:31 +0800 Subject: [PATCH 22/98] [Kernel][Perf] fuse QK Norm and RoPE into one cuda kernel for Qwen Model (#27165) Signed-off-by: zhuhaoran --- .buildkite/test-pipeline.yaml | 1 + CMakeLists.txt | 1 + csrc/fused_qknorm_rope_kernel.cu | 418 ++++++++++++++++++ csrc/ops.h | 6 + csrc/torch_bindings.cpp | 10 + csrc/type_convert.cuh | 60 ++- tests/compile/test_qk_norm_rope_fusion.py | 195 ++++++++ tests/kernels/core/test_fused_qk_norm_rope.py | 141 ++++++ vllm/_custom_ops.py | 29 ++ vllm/compilation/fix_functionalization.py | 17 + vllm/compilation/fusion.py | 4 + vllm/compilation/matcher_utils.py | 81 +++- vllm/compilation/pass_manager.py | 4 + vllm/compilation/qk_norm_rope_fusion.py | 238 ++++++++++ vllm/config/compilation.py | 13 + .../layers/rotary_embedding/base.py | 63 ++- 16 files changed, 1243 insertions(+), 38 deletions(-) create mode 100644 csrc/fused_qknorm_rope_kernel.cu create mode 100644 tests/compile/test_qk_norm_rope_fusion.py create mode 100644 tests/kernels/core/test_fused_qk_norm_rope.py create mode 100644 vllm/compilation/qk_norm_rope_fusion.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 25f711dd60b37..8d2a7bc5a8029 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -451,6 +451,7 @@ steps: - pytest -v -s compile/test_decorator.py - pytest -v -s compile/test_noop_elimination.py - pytest -v -s compile/test_aot_compile.py + - pytest -v -s compile/test_qk_norm_rope_fusion.py - label: PyTorch Fullgraph Smoke Test # 15min timeout_in_minutes: 30 diff --git a/CMakeLists.txt b/CMakeLists.txt index 0e9fa63b178ea..5cddf81a4b4aa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -265,6 +265,7 @@ set(VLLM_EXT_SRC "csrc/pos_encoding_kernels.cu" "csrc/activation_kernels.cu" "csrc/layernorm_kernels.cu" + "csrc/fused_qknorm_rope_kernel.cu" "csrc/layernorm_quant_kernels.cu" "csrc/sampler.cu" "csrc/cuda_view.cu" diff --git a/csrc/fused_qknorm_rope_kernel.cu b/csrc/fused_qknorm_rope_kernel.cu new file mode 100644 index 0000000000000..cbd23975a7739 --- /dev/null +++ b/csrc/fused_qknorm_rope_kernel.cu @@ -0,0 +1,418 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include +#include + +#include "cuda_compat.h" +#include "dispatch_utils.h" +#include "type_convert.cuh" + +#define CHECK_TYPE(x, st) \ + TORCH_CHECK(x.scalar_type() == st, #x " dtype is ", x.scalar_type(), \ + ", while ", st, " is expected") +#define CHECK_TH_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_TH_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +#define FINAL_MASK 0xffffffff + +// TODO: suport for AMD ROCM platform +#ifndef USE_ROCM +namespace tensorrt_llm::common { +template +struct packed_as; +// Specialization for packed_as used in this kernel. +template <> +struct packed_as { + using type = uint; +}; + +template <> +struct packed_as { + using type = uint2; +}; + +template <> +struct packed_as { + using type = uint4; +}; + +template +__inline__ __device__ T warpReduceSum(T val) { + #pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val += __shfl_xor_sync(FINAL_MASK, val, mask, 32); + return val; +} + +template +inline __device__ __host__ T divUp(T m, T n) { + return (m + n - 1) / n; +} + +} // namespace tensorrt_llm::common + +namespace tensorrt_llm::kernels { +// NOTE(zhuhaoran): This kernel is adapted from TensorRT-LLM implementation, +// with added support for passing the cos_sin_cache as an input. +// https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu + +// Perform per-head QK Norm and RoPE in a single kernel. +// scalar_t_in: data type of QKV and RMSNorm weights +// scalar_t_cache: data type of cos/sin cache +// head_dim: the dimension of each head +// interleave: interleave=!is_neox. +template +__global__ void fusedQKNormRopeKernel( + void* qkv_void, // Combined QKV tensor + int const num_heads_q, // Number of query heads + int const num_heads_k, // Number of key heads + int const num_heads_v, // Number of value heads + float const eps, // Epsilon for RMS normalization + void const* q_weight_void, // RMSNorm weights for query + void const* k_weight_void, // RMSNorm weights for key + void const* cos_sin_cache_void, // Pre-computed cos/sin cache + int64_t const* position_ids, // Position IDs for RoPE + int const num_tokens // Number of tokens +) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800 + if constexpr ((std::is_same_v) || + std::is_same_v) { + return; + } else { + #endif + + using Converter = vllm::_typeConvert; + static_assert(Converter::exists, + "Input QKV data type is not supported for this CUDA " + "architecture or toolkit version."); + using T_in = typename Converter::hip_type; + using T2_in = typename Converter::packed_hip_type; + + using CacheConverter = vllm::_typeConvert; + static_assert(CacheConverter::exists, + "Cache data type is not supported for this CUDA architecture " + "or toolkit version."); + using T_cache = typename CacheConverter::hip_type; + + T_in* qkv = reinterpret_cast(qkv_void); + T_in const* q_weight = reinterpret_cast(q_weight_void); + T_in const* k_weight = reinterpret_cast(k_weight_void); + T_cache const* cos_sin_cache = + reinterpret_cast(cos_sin_cache_void); + + int const warpsPerBlock = blockDim.x / 32; + int const warpId = threadIdx.x / 32; + int const laneId = threadIdx.x % 32; + + // Calculate global warp index to determine which head/token this warp + // processes + int const globalWarpIdx = blockIdx.x * warpsPerBlock + warpId; + + // Total number of attention heads (Q and K) + int const total_qk_heads = num_heads_q + num_heads_k; + + // Determine which token and head type (Q or K) this warp processes + int const tokenIdx = globalWarpIdx / total_qk_heads; + int const localHeadIdx = globalWarpIdx % total_qk_heads; + + // Skip if this warp is assigned beyond the number of tokens + if (tokenIdx >= num_tokens) return; + + bool const isQ = localHeadIdx < num_heads_q; + int const headIdx = isQ ? localHeadIdx : localHeadIdx - num_heads_q; + + int const num_heads = num_heads_q + num_heads_k + num_heads_v; + + static_assert(head_dim % (32 * 2) == 0, + "head_dim must be divisible by 64 (each warp processes one " + "head, and each thread gets even number of " + "elements)"); + constexpr int numElemsPerThread = head_dim / 32; + float elements[numElemsPerThread]; + constexpr int elemSizeBytes = numElemsPerThread * sizeof(__nv_bfloat16); + static_assert(elemSizeBytes % 4 == 0, + "numSizeBytes must be a multiple of 4"); + constexpr int vecSize = + elemSizeBytes / + 4; // Use packed_as to perform loading/saving. + using vec_T = typename tensorrt_llm::common::packed_as::type; + + int offsetWarp; // Offset for the warp + if (isQ) { + // Q segment: token offset + head offset within Q segment + offsetWarp = tokenIdx * num_heads * head_dim + headIdx * head_dim; + } else { + // K segment: token offset + entire Q segment + head offset within K + // segment + offsetWarp = tokenIdx * num_heads * head_dim + num_heads_q * head_dim + + headIdx * head_dim; + } + int offsetThread = offsetWarp + laneId * numElemsPerThread; + + // Sum of squares for RMSNorm + float sumOfSquares = 0.0f; + + // Load. + { + vec_T vec = *reinterpret_cast(&qkv[offsetThread]); + constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in); + #pragma unroll + for (int i = 0; i < num_packed_elems; i++) { + // Interpret the generic vector chunk as the specific packed type + T2_in packed_val = *(reinterpret_cast(&vec) + i); + // Convert to float2 for computation + float2 vals = Converter::convert(packed_val); + sumOfSquares += vals.x * vals.x; + sumOfSquares += vals.y * vals.y; + + elements[2 * i] = vals.x; + elements[2 * i + 1] = vals.y; + } + } + + // Reduce sum across warp using the utility function + sumOfSquares = tensorrt_llm::common::warpReduceSum(sumOfSquares); + + // Compute RMS normalization factor + float rms_rcp = rsqrtf(sumOfSquares / static_cast(head_dim) + eps); + + // Normalize elements + #pragma unroll + for (int i = 0; i < numElemsPerThread; i++) { + int dim = laneId * numElemsPerThread + i; + float weight = isQ ? Converter::convert(q_weight[dim]) + : Converter::convert(k_weight[dim]); + elements[i] *= rms_rcp * weight; + } + + // Apply RoPE to normalized elements + float elements2[numElemsPerThread]; // Additional buffer required for RoPE. + + int64_t pos_id = position_ids[tokenIdx]; + + // Calculate cache pointer for this position - similar to + // pos_encoding_kernels.cu + T_cache const* cache_ptr = cos_sin_cache + pos_id * head_dim; + int const embed_dim = head_dim / 2; + T_cache const* cos_ptr = cache_ptr; + T_cache const* sin_ptr = cache_ptr + embed_dim; + + if constexpr (interleave) { + // Perform interleaving. Use pre-computed cos/sin values. + #pragma unroll + for (int i = 0; i < numElemsPerThread / 2; ++i) { + int const idx0 = 2 * i; + int const idx1 = 2 * i + 1; + + float const val0 = elements[idx0]; + float const val1 = elements[idx1]; + + int const dim_idx = laneId * numElemsPerThread + idx0; + int const half_dim = dim_idx / 2; + float const cos_val = + CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim)); + float const sin_val = + CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim)); + + elements[idx0] = val0 * cos_val - val1 * sin_val; + elements[idx1] = val0 * sin_val + val1 * cos_val; + } + } else { + // Before data exchange with in warp, we need to sync. + __syncwarp(); + // Get the data from the other half of the warp. Use pre-computed cos/sin + // values. + #pragma unroll + for (int i = 0; i < numElemsPerThread; i++) { + elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], 16); + if (laneId < 16) { + elements2[i] = -elements2[i]; + } + + int dim_idx = laneId * numElemsPerThread + i; + dim_idx = (dim_idx * 2) % head_dim; + int half_dim = dim_idx / 2; + // Use pre-computed cos/sin from cache + float cos_val = CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim)); + float sin_val = CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim)); + + elements[i] = elements[i] * cos_val + elements2[i] * sin_val; + } + // __shfl_xor_sync does not provide memfence. Need to sync again. + __syncwarp(); + } + + // Store. + { + vec_T vec; + constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in); + #pragma unroll + for (int i = 0; i < num_packed_elems; i++) { + // Convert from float2 back to the specific packed type + T2_in packed_val = Converter::convert( + make_float2(elements[2 * i], elements[2 * i + 1])); + // Place it into the generic vector + *(reinterpret_cast(&vec) + i) = packed_val; + } + *reinterpret_cast(&qkv[offsetThread]) = vec; + } + + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800 + } + #endif +} + + // Borrowed from + // https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568 + #define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \ + if (interleave) { \ + const bool INTERLEAVE = true; \ + __VA_ARGS__ \ + } else { \ + const bool INTERLEAVE = false; \ + __VA_ARGS__ \ + } + +template +void launchFusedQKNormRope(void* qkv, int const num_tokens, + int const num_heads_q, int const num_heads_k, + int const num_heads_v, int const head_dim, + float const eps, void const* q_weight, + void const* k_weight, void const* cos_sin_cache, + bool const interleave, int64_t const* position_ids, + cudaStream_t stream) { + constexpr int blockSize = 256; + + int const warpsPerBlock = blockSize / 32; + int const totalQKHeads = num_heads_q + num_heads_k; + int const totalWarps = num_tokens * totalQKHeads; + + int const gridSize = common::divUp(totalWarps, warpsPerBlock); + dim3 gridDim(gridSize); + dim3 blockDim(blockSize); + + switch (head_dim) { + case 64: + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + fusedQKNormRopeKernel + <<>>( + qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight, + k_weight, cos_sin_cache, position_ids, num_tokens); + }); + break; + case 128: + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + fusedQKNormRopeKernel + <<>>( + qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight, + k_weight, cos_sin_cache, position_ids, num_tokens); + }); + break; + case 256: + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + fusedQKNormRopeKernel + <<>>( + qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight, + k_weight, cos_sin_cache, position_ids, num_tokens); + }); + break; + default: + TORCH_CHECK(false, + "Unsupported head dimension for fusedQKNormRope: ", head_dim); + } +} +} // namespace tensorrt_llm::kernels + +void fused_qk_norm_rope( + torch::Tensor& qkv, // Combined QKV tensor [num_tokens, + // (num_heads_q+num_heads_k+num_heads_v)*head_dim] + int64_t num_heads_q, // Number of query heads + int64_t num_heads_k, // Number of key heads + int64_t num_heads_v, // Number of value heads + int64_t head_dim, // Dimension per head + double eps, // Epsilon for RMS normalization + torch::Tensor& q_weight, // RMSNorm weights for query [head_dim] + torch::Tensor& k_weight, // RMSNorm weights for key [head_dim] + torch::Tensor& cos_sin_cache, // Cos/sin cache [max_position, head_dim] + bool is_neox, // Whether RoPE is applied in Neox style + torch::Tensor& position_ids // Position IDs for RoPE [num_tokens] +) { + // Input validation + CHECK_INPUT(qkv); + CHECK_INPUT(position_ids); + CHECK_INPUT(q_weight); + CHECK_INPUT(k_weight); + CHECK_INPUT(cos_sin_cache); + CHECK_TYPE(position_ids, torch::kInt64); + + TORCH_CHECK(qkv.dim() == 2, + "QKV tensor must be 2D: [num_tokens, " + "(num_heads_q+num_heads_k+num_heads_v)*head_dim]"); + TORCH_CHECK(position_ids.dim() == 1, "Position IDs must be 1D: [num_tokens]"); + TORCH_CHECK(q_weight.dim() == 1, "Query weights must be 1D: [head_dim]"); + TORCH_CHECK(k_weight.dim() == 1, "Key weights must be 1D: [head_dim]"); + TORCH_CHECK(cos_sin_cache.dim() == 2, + "Cos/sin cache must be 2D: [max_position, head_dim]"); + TORCH_CHECK(q_weight.size(0) == head_dim, + "Query weights size must match head dimension"); + TORCH_CHECK(k_weight.size(0) == head_dim, + "Key weights size must match head dimension"); + TORCH_CHECK(cos_sin_cache.size(1) == head_dim, + "Cos/sin cache dimension must match head_dim"); + TORCH_CHECK(qkv.scalar_type() == q_weight.scalar_type() && + qkv.scalar_type() == k_weight.scalar_type(), + "qkv, q_weight and k_weight must have the same dtype"); + + int64_t num_tokens = qkv.size(0); + TORCH_CHECK(position_ids.size(0) == num_tokens, + "Number of tokens in position_ids must match QKV"); + + int64_t total_heads = num_heads_q + num_heads_k + num_heads_v; + TORCH_CHECK( + qkv.size(1) == total_heads * head_dim, + "QKV tensor size must match total number of heads and head dimension"); + + auto stream = at::cuda::getCurrentCUDAStream(qkv.get_device()); + + VLLM_DISPATCH_HALF_TYPES(qkv.scalar_type(), "fused_qk_norm_rope_kernel", [&] { + using qkv_scalar_t = scalar_t; + VLLM_DISPATCH_FLOATING_TYPES( + cos_sin_cache.scalar_type(), "fused_qk_norm_rope_kernel", [&] { + using cache_scalar_t = scalar_t; + tensorrt_llm::kernels::launchFusedQKNormRope( + qkv.data_ptr(), static_cast(num_tokens), + static_cast(num_heads_q), static_cast(num_heads_k), + static_cast(num_heads_v), static_cast(head_dim), + static_cast(eps), q_weight.data_ptr(), k_weight.data_ptr(), + cos_sin_cache.data_ptr(), !is_neox, + reinterpret_cast(position_ids.data_ptr()), + stream); + }); + }); +} + +#endif // not USE_ROCM \ No newline at end of file diff --git a/csrc/ops.h b/csrc/ops.h index 3f5cb799b774c..f8bdc61aaa8ec 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -92,6 +92,12 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, double epsilon); +void fused_qk_norm_rope(torch::Tensor& qkv, int64_t num_heads_q, + int64_t num_heads_k, int64_t num_heads_v, + int64_t head_dim, double eps, torch::Tensor& q_weight, + torch::Tensor& k_weight, torch::Tensor& cos_sin_cache, + bool is_neox, torch::Tensor& position_ids); + void apply_repetition_penalties_(torch::Tensor& logits, const torch::Tensor& prompt_mask, const torch::Tensor& output_mask, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 9c0f524dcab11..d4a69cbe7971d 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -175,6 +175,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "float epsilon) -> ()"); ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); +#ifndef USE_ROCM + // Function for fused QK Norm and RoPE + ops.def( + "fused_qk_norm_rope(Tensor! qkv, int num_heads_q, " + "int num_heads_k, int num_heads_v, int head_dim, float eps, " + "Tensor q_weight, Tensor k_weight, Tensor cos_sin_cache, " + "bool is_neox, Tensor position_ids) -> ()"); + ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope); +#endif + // Apply repetition penalties to logits in-place ops.def( "apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, " diff --git a/csrc/type_convert.cuh b/csrc/type_convert.cuh index 21b9d0ae515df..6da06f1e66cf5 100644 --- a/csrc/type_convert.cuh +++ b/csrc/type_convert.cuh @@ -29,6 +29,22 @@ struct _typeConvert { static constexpr bool exists = false; }; +template <> +struct _typeConvert { + static constexpr bool exists = true; + using hip_type = float; + using packed_hip_type = float2; + using packed_hip_type4 = float4; // For 128-bit vectorization + + __device__ static __forceinline__ float convert(hip_type x) { return x; } + __device__ static __forceinline__ float2 convert(packed_hip_type x) { + return x; + } + __device__ static __forceinline__ float4 convert(packed_hip_type4 x) { + return x; + } +}; + #if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) // CUDA < 12.0 runs into issues with packed type conversion template <> @@ -37,14 +53,16 @@ struct _typeConvert { using hip_type = __half; using packed_hip_type = __half2; - __device__ static inline float convert(hip_type x) { return __half2float(x); } - __device__ static inline float2 convert(packed_hip_type x) { + __device__ static __forceinline__ float convert(hip_type x) { + return __half2float(x); + } + __device__ static __forceinline__ float2 convert(packed_hip_type x) { return __half22float2(x); } - __device__ static inline hip_type convert(float x) { + __device__ static __forceinline__ hip_type convert(float x) { return __float2half_rn(x); } - __device__ static inline packed_hip_type convert(float2 x) { + __device__ static __forceinline__ packed_hip_type convert(float2 x) { return __float22half2_rn(x); } }; @@ -58,16 +76,16 @@ struct _typeConvert { using hip_type = __nv_bfloat16; using packed_hip_type = __nv_bfloat162; - __device__ static inline float convert(hip_type x) { + __device__ static __forceinline__ float convert(hip_type x) { return __bfloat162float(x); } - __device__ static inline float2 convert(packed_hip_type x) { + __device__ static __forceinline__ float2 convert(packed_hip_type x) { return __bfloat1622float2(x); } - __device__ static inline hip_type convert(float x) { + __device__ static __forceinline__ hip_type convert(float x) { return __float2bfloat16(x); } - __device__ static inline packed_hip_type convert(float2 x) { + __device__ static __forceinline__ packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); } }; @@ -95,10 +113,15 @@ struct alignas(16) _f16Vec { if constexpr (width % 2 == 0) { #pragma unroll for (int i = 0; i < width; i += 2) { - T2 temp{data[i], data[i + 1]}; - temp += T2{other.data[i], other.data[i + 1]}; - data[i] = temp.x; - data[i + 1] = temp.y; + if constexpr (std::is_same_v) { + data[i] += other.data[i]; + data[i + 1] += other.data[i + 1]; + } else { + T2 temp{data[i], data[i + 1]}; + temp += T2{other.data[i], other.data[i + 1]}; + data[i] = temp.x; + data[i + 1] = temp.y; + } } } else { #pragma unroll @@ -111,10 +134,15 @@ struct alignas(16) _f16Vec { if constexpr (width % 2 == 0) { #pragma unroll for (int i = 0; i < width; i += 2) { - T2 temp{data[i], data[i + 1]}; - temp *= T2{other.data[i], other.data[i + 1]}; - data[i] = temp.x; - data[i + 1] = temp.y; + if constexpr (std::is_same_v) { + data[i] *= other.data[i]; + data[i + 1] *= other.data[i + 1]; + } else { + T2 temp{data[i], data[i + 1]}; + temp *= T2{other.data[i], other.data[i + 1]}; + data[i] = temp.x; + data[i + 1] = temp.y; + } } } else { #pragma unroll diff --git a/tests/compile/test_qk_norm_rope_fusion.py b/tests/compile/test_qk_norm_rope_fusion.py new file mode 100644 index 0000000000000..973123a3af920 --- /dev/null +++ b/tests/compile/test_qk_norm_rope_fusion.py @@ -0,0 +1,195 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from tests.compile.backend import TestBackend +from vllm.attention import Attention, AttentionType +from vllm.compilation.matcher_utils import FLASHINFER_ROTARY_OP, RMS_OP, ROTARY_OP +from vllm.compilation.noop_elimination import NoOpEliminationPass +from vllm.compilation.post_cleanup import PostCleanupPass +from vllm.compilation.qk_norm_rope_fusion import ( + FUSED_QK_ROPE_OP, + QKNormRoPEFusionPass, +) +from vllm.config import ( + CompilationConfig, + CompilationMode, + ModelConfig, + PassConfig, + VllmConfig, + set_current_vllm_config, +) +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +from vllm.platforms import current_platform + +RSQRT_OP = torch.ops.aten.rsqrt.default +INDEX_SELECT_OP = torch.ops.aten.index.Tensor + + +class QKNormRoPETestModel(torch.nn.Module): + def __init__( + self, + *, + num_heads: int, + num_kv_heads: int, + head_dim: int, + eps: float, + is_neox: bool, + vllm_config: VllmConfig, + dtype: torch.dtype, + prefix: str = "model.layers.0.self_attn.attn", + ) -> None: + super().__init__() + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.q_size = num_heads * head_dim + self.kv_size = num_kv_heads * head_dim + self.rotary_dim = head_dim + self.eps = eps + self.dtype = dtype + + # Register layer metadata for the fusion pass via Attention. + self.attn = Attention( + num_heads=self.num_heads, + head_size=self.head_dim, + scale=1.0 / self.head_dim**0.5, + num_kv_heads=self.num_kv_heads, + cache_config=vllm_config.cache_config, + prefix=prefix, + attn_type=AttentionType.DECODER, + ) + + self.q_norm = RMSNorm(self.head_dim, eps=self.eps) + self.k_norm = RMSNorm(self.head_dim, eps=self.eps) + self.rotary_emb = RotaryEmbedding( + self.head_dim, + rotary_dim=self.rotary_dim, + max_position_embeddings=4096, + base=10000, + is_neox_style=is_neox, + dtype=self.dtype, + ) + self.enable_rms_norm_custom_op = self.q_norm.enabled() + self.enable_rope_custom_op = self.rotary_emb.enabled() + + def forward(self, qkv: torch.Tensor, positions: torch.Tensor): + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) + q_by_head = self.q_norm(q_by_head) + q = q_by_head.view(q.shape) + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) + k_by_head = self.k_norm(k_by_head) + k = k_by_head.view(k.shape) + q, k = self.rotary_emb(positions, q, k) + return q, k, v + + def ops_in_model_before(self) -> list[torch._ops.OpOverload]: + ops = [] + if self.enable_rms_norm_custom_op: + ops.append(RMS_OP) + else: + ops.append(RSQRT_OP) + + if self.enable_rope_custom_op: + if self.rotary_emb.use_flashinfer: + ops.append(FLASHINFER_ROTARY_OP) + else: + ops.append(ROTARY_OP) + else: + ops.append(INDEX_SELECT_OP) + return ops + + def ops_in_model_after(self) -> list[torch._ops.OpOverload]: + return [FUSED_QK_ROPE_OP] + + +@pytest.mark.parametrize("eps", [1e-5, 1e-6]) +@pytest.mark.parametrize("is_neox", [True, False]) +@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) +@pytest.mark.parametrize("enable_rope_custom_op", [True]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.skipif( + not current_platform.is_cuda(), + reason="Only test on cuda platform", +) +def test_qk_norm_rope_fusion( + eps, is_neox, enable_rms_norm_custom_op, enable_rope_custom_op, dtype +): + if not hasattr(torch.ops._C, "fused_qk_norm_rope"): + pytest.skip("fused_qk_norm_rope custom op not available") + + torch.set_default_device("cuda") + torch.set_default_dtype(dtype) + torch.manual_seed(0) + + custom_ops: list[str] = [] + if enable_rms_norm_custom_op: + custom_ops.append("+rms_norm") + if enable_rope_custom_op: + custom_ops.append("+rotary_embedding") + + vllm_config = VllmConfig( + model_config=ModelConfig(dtype=dtype), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + custom_ops=custom_ops, + pass_config=PassConfig( + enable_qk_norm_rope_fusion=True, + enable_noop=True, + ), + ), + ) + + num_heads, num_kv_heads, head_dim = 16, 4, 128 + T = 5 + + with set_current_vllm_config(vllm_config): + model = QKNormRoPETestModel( + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + eps=eps, + is_neox=is_neox, + vllm_config=vllm_config, + dtype=dtype, + ) + + noop_pass = NoOpEliminationPass(vllm_config) + fusion_pass = QKNormRoPEFusionPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) + + backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) + backend_baseline = TestBackend(noop_pass, cleanup_pass) + + qkv = torch.randn(T, model.q_size + 2 * model.kv_size) + pos = torch.arange(T, dtype=torch.long, device=qkv.device) + qkv_unfused = qkv.clone() + pos_unfused = pos.clone() + + torch._dynamo.mark_dynamic(qkv, 0) + torch._dynamo.mark_dynamic(pos, 0) + model_fused = torch.compile(model, backend=backend) + q_fused, k_fused, v_fused = model_fused(qkv, pos) + + torch._dynamo.mark_dynamic(qkv_unfused, 0) + torch._dynamo.mark_dynamic(pos_unfused, 0) + model_unfused = torch.compile(model, backend=backend_baseline) + q_unfused, k_unfused, v_unfused = model_unfused(qkv_unfused, pos_unfused) + + if dtype == torch.float16: + ATOL, RTOL = (2e-3, 2e-3) + else: + ATOL, RTOL = (1e-2, 1e-2) + + torch.testing.assert_close(q_unfused, q_fused, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(k_unfused, k_fused, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(v_unfused, v_fused, atol=ATOL, rtol=RTOL) + + assert fusion_pass.matched_count == 1 + + backend.check_before_ops(model.ops_in_model_before()) + backend.check_after_ops(model.ops_in_model_after()) diff --git a/tests/kernels/core/test_fused_qk_norm_rope.py b/tests/kernels/core/test_fused_qk_norm_rope.py new file mode 100644 index 0000000000000..88bb7691ec3bc --- /dev/null +++ b/tests/kernels/core/test_fused_qk_norm_rope.py @@ -0,0 +1,141 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from tests.kernels.utils import opcheck +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +from vllm.platforms import current_platform + +DTYPES = [torch.bfloat16, torch.float16] +IS_NEOX = [True, False] +EPS_VALUES = [1e-5, 1e-6] +SEEDS = [13] +CUDA_DEVICES = ["cuda:0"] + + +def _apply_qk_norm_rope( + qkv: torch.Tensor, + positions: torch.Tensor, + q_norm: RMSNorm, + k_norm: RMSNorm, + rope: RotaryEmbedding, + num_heads_q: int, + num_heads_kv: int, + head_dim: int, +) -> torch.Tensor: + q_size = num_heads_q * head_dim + kv_size = num_heads_kv * head_dim + + q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1) + + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim) + q_by_head = q_norm.forward_native(q_by_head) + q = q_by_head.view(q.shape) + + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim) + k_by_head = k_norm.forward_native(k_by_head) + k = k_by_head.view(k.shape) + + q, k = rope.forward_native(positions, q, k) + return torch.cat([q, k, v], dim=-1) + + +@pytest.mark.skipif( + not current_platform.is_cuda(), + reason="fused_qk_norm_rope custom op requires cuda platform", +) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("is_neox", IS_NEOX) +@pytest.mark.parametrize("eps", EPS_VALUES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_fused_qk_norm_rope_matches_reference( + device: str, + dtype: torch.dtype, + is_neox: bool, + eps: float, + seed: int, +): + torch.set_default_device(device) + current_platform.seed_everything(seed) + num_heads, num_kv_heads, head_dim = 16, 4, 128 + num_tokens = 4 + + total_dim = (num_heads + 2 * num_kv_heads) * head_dim + qkv_base = torch.randn(num_tokens, total_dim, dtype=dtype, device=device) + qkv_fused = qkv_base.clone() + positions = torch.arange(num_tokens, dtype=torch.long, device=device) + + q_norm = RMSNorm(head_dim, eps=eps).to(device=device, dtype=dtype) + k_norm = RMSNorm(head_dim, eps=eps).to(device=device, dtype=dtype) + q_norm.weight.data.normal_(mean=1.0, std=0.1) + k_norm.weight.data.normal_(mean=1.0, std=0.1) + q_weight = q_norm.weight.data + k_weight = k_norm.weight.data + + rope = RotaryEmbedding( + head_size=head_dim, + rotary_dim=head_dim, + max_position_embeddings=4096, + base=10000.0, + is_neox_style=is_neox, + dtype=dtype, + ).to(device) + + ref_result = _apply_qk_norm_rope( + qkv=qkv_base, + positions=positions, + q_norm=q_norm, + k_norm=k_norm, + rope=rope, + num_heads_q=num_heads, + num_heads_kv=num_kv_heads, + head_dim=head_dim, + ) + + opcheck( + torch.ops._C.fused_qk_norm_rope, + ( + qkv_fused.clone(), + num_heads, + num_kv_heads, + num_kv_heads, + head_dim, + eps, + q_weight, + k_weight, + rope.cos_sin_cache, + is_neox, + positions.view(-1), + ), + ) + + torch.ops._C.fused_qk_norm_rope( + qkv_fused, + num_heads, + num_kv_heads, + num_kv_heads, + head_dim, + eps, + q_weight, + k_weight, + rope.cos_sin_cache, + is_neox, + positions.view(-1), + ) + + if dtype == torch.float16: + ATOL, RTOL = (2e-3, 2e-3) + else: + ATOL, RTOL = (1e-2, 1e-2) + + torch.testing.assert_close( + qkv_fused, + ref_result, + atol=ATOL, + rtol=RTOL, + ) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 36aab503dee70..136a3193efb5e 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -329,6 +329,7 @@ def rms_norm( out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float ) -> None: # TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input + # If removed, also need to remove contiguous in MatcherRMSNorm input_contiguous = input.contiguous() torch.ops._C.rms_norm(out, input_contiguous, weight, epsilon) @@ -339,6 +340,34 @@ def fused_add_rms_norm( torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) +def fused_qk_norm_rope( + qkv: torch.Tensor, + num_heads_q: int, + num_heads_k: int, + num_heads_v: int, + head_dim: int, + eps: float, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + cos_sin_cache: torch.Tensor, + is_neox: bool, + position_ids: torch.Tensor, +) -> None: + torch.ops._C.fused_qk_norm_rope( + qkv, + num_heads_q, + num_heads_k, + num_heads_v, + head_dim, + eps, + q_weight, + k_weight, + cos_sin_cache, + is_neox, + position_ids, + ) + + def apply_repetition_penalties_torch( logits: torch.Tensor, prompt_mask: torch.Tensor, diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index 29462d9ff0e50..126ad35e527ae 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -132,6 +132,23 @@ class FixFunctionalizationPass(VllmInductorPass): "input_global_scale", ), ) + # Defunctionalize fused_qk_norm_rope to remove higher-order wrapper. + elif at_target == torch.ops._C.fused_qk_norm_rope.default: + mutated_args = {1: "qkv"} + args = ( + "qkv", + "num_heads_q", + "num_heads_k", + "num_heads_v", + "head_dim", + "eps", + "q_weight", + "k_weight", + "cos_sin_cache", + "is_neox", + "position_ids", + ) + self.defunctionalize(graph, node, mutated_args=mutated_args, args=args) else: continue # skip the count diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 8f0ad2d69fbec..1d6e297b495eb 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -44,6 +44,10 @@ def empty_i32(*args, **kwargs): return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda") +def empty_i64(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda") + + RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index 383fe6033a6df..38eb4e5301a18 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -18,10 +18,13 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTensorSym, kNvfp4Quant, ) +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.platforms import current_platform RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default +ROTARY_OP = torch.ops._C.rotary_embedding.default +FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default QUANT_OPS: dict[QuantKey, OpOverload] = { kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 @@ -58,6 +61,9 @@ class MatcherCustomOp(ABC): def empty(self, *args, **kws): return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws) + def empty_int64(self, *args, **kws): + return torch.empty(*args, dtype=torch.int64, device=self.device, **kws) + def empty_f32(self, *args, **kws): return torch.empty(*args, dtype=torch.float32, device=self.device, **kws) @@ -66,6 +72,77 @@ class MatcherCustomOp(ABC): raise NotImplementedError +class MatcherRotaryEmbedding(MatcherCustomOp): + def __init__( + self, + is_neox: bool, + head_size: int, + num_heads: int, + num_kv_heads: int, + use_flashinfer: bool = False, + enabled: bool | None = None, + ) -> None: + if enabled is None: + enabled = RotaryEmbedding.enabled() + + super().__init__(enabled) + self.is_neox = is_neox + self.head_size = head_size + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.q_size = self.num_heads * self.head_size + self.kv_size = self.num_kv_heads * self.head_size + self.rotary_dim = head_size + if use_flashinfer: + self.rotary_op = FLASHINFER_ROTARY_OP + else: + self.rotary_op = ROTARY_OP + + def inputs(self) -> list[torch.Tensor]: + positions = self.empty_int64(5) + query = self.empty(5, self.q_size) + key = self.empty(5, self.kv_size) + cos_sin_cache = self.empty(4096, self.rotary_dim) + return [positions, query, key, cos_sin_cache] + + def forward_custom( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor | None, + cos_sin_cache: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + result = auto_functionalized( + self.rotary_op, + positions=positions, + query=query, + key=key, + head_size=self.head_size, + cos_sin_cache=cos_sin_cache, + is_neox=self.is_neox, + ) + query_out = result[1] + key_out = result[2] if len(result) > 2 else None + return query_out, key_out + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor | None, + cos_sin_cache: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + return RotaryEmbedding.forward_static( + positions, + query, + key, + self.head_size, + self.rotary_dim, + cos_sin_cache, + self.is_neox, + ) + + class MatcherRMSNorm(MatcherCustomOp): def __init__(self, epsilon: float, enabled: bool | None = None): if enabled is None: @@ -85,10 +162,12 @@ class MatcherRMSNorm(MatcherCustomOp): weight: torch.Tensor, ) -> torch.Tensor: result = torch.empty_like(input) + # TODO: support non-contiguous input for RMSNorm and remove this + input_contiguous = input.contiguous() _, result = auto_functionalized( RMS_OP, result=result, - input=input, + input=input_contiguous, weight=weight, epsilon=self.epsilon, ) diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index dfda2adf1d3b0..0c2210d72ce07 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -17,6 +17,7 @@ if current_platform.is_cuda_alike(): from .activation_quant_fusion import ActivationQuantFusionPass from .fusion import RMSNormQuantFusionPass from .fusion_attn import AttnFusionPass + from .qk_norm_rope_fusion import QKNormRoPEFusionPass if current_platform.is_cuda(): from .collective_fusion import AllReduceFusionPass, AsyncTPPass @@ -109,6 +110,9 @@ class PostGradPassManager(CustomGraphPass): if self.pass_config.enable_attn_fusion: self.passes += [AttnFusionPass(config)] + if self.pass_config.enable_qk_norm_rope_fusion: + self.passes += [QKNormRoPEFusionPass(config)] + # needs a functional graph self.post_cleanup = PostCleanupPass(config) self.fix_functionalization = FixFunctionalizationPass(config) diff --git a/vllm/compilation/qk_norm_rope_fusion.py b/vllm/compilation/qk_norm_rope_fusion.py new file mode 100644 index 0000000000000..e3c399e079063 --- /dev/null +++ b/vllm/compilation/qk_norm_rope_fusion.py @@ -0,0 +1,238 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Callable + +import torch +import torch._inductor.pattern_matcher as pm +from torch import fx +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._inductor.pattern_matcher import PatternMatcherPass + +from vllm.attention import Attention +from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.logger import init_logger +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding + +from .fusion import empty_bf16, empty_fp32, empty_i64 +from .inductor_pass import enable_fake_mode +from .matcher_utils import MatcherRMSNorm, MatcherRotaryEmbedding +from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass + +logger = init_logger(__name__) + +FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default + + +class QkNormRopePattern: + """ + Match the unfused sequence in attention blocks and replace with the fused op. + + Unfused (conceptually): + q, k, v = split(qkv, [qsz, kvsz, kvsz], -1) + qh = reshape(q, [-1, num_heads, head_dim]) + kh = reshape(k, [-1, num_kv_heads, head_dim]) + qn = rms_norm(qh, q_weight, eps) + kn = rms_norm(kh, k_weight, eps) + qf = reshape(qn, [-1, num_heads * head_dim]) + kf = reshape(kn, [-1, num_kv_heads * head_dim]) + qf, kf = rotary_embedding(positions, qf, kf, head_dim, cos_sin_cache, is_neox) + return qf, kf, v + + Fused replacement: + fused_qk_norm_rope(qkv, num_heads, num_kv_heads, num_kv_heads, head_dim, + eps, q_weight, k_weight, cos_sin_cache, is_neox, + positions.view(-1)) + return split(qkv, [qsz, kvsz, kvsz], -1) + """ + + def __init__( + self, + head_dim: int, + num_heads: int, + num_kv_heads: int, + eps: float, + is_neox: bool, + rope_flashinfer: bool = False, + ) -> None: + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.eps = eps + self.rmsnorm_matcher = MatcherRMSNorm(eps) + self.is_neox = is_neox + self.rope_flashinfer = rope_flashinfer + self.rope_matcher = MatcherRotaryEmbedding( + is_neox=is_neox, + head_size=self.head_dim, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + use_flashinfer=self.rope_flashinfer, + ) + + def get_inputs(self): + # Sample inputs to help pattern tracing + T = 5 + qkv = empty_bf16(T, self.q_size + 2 * self.kv_size) + positions = empty_i64(T) + q_weight = empty_bf16(1, self.head_dim) + k_weight = empty_bf16(1, self.head_dim) + if self.rope_flashinfer: + cos_sin_cache = empty_fp32(4096, self.head_dim) + else: + cos_sin_cache = empty_bf16(4096, self.head_dim) + return [ + qkv, + positions, + q_weight, + k_weight, + cos_sin_cache, + ] + + @staticmethod + def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]): + def wrapped(*args, **kwargs): + gm = trace_fn(*args, **kwargs) + for process_fx in process_fx_fns: + process_fx(gm) + + return gm + + return wrapped + + @staticmethod + def fx_view_to_reshape(gm: torch.fx.GraphModule): + from torch._inductor.fx_passes.post_grad import view_to_reshape + + view_to_reshape(gm) + + def register(self, pm_pass: PatternMatcherPass): + def pattern( + qkv: torch.Tensor, + positions: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + cos_sin_cache: torch.Tensor, + ): + # split qkv -> q,k,v + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # Q path: view -> RMS -> view back to q.shape + q_by_head = q.view( + *q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim + ) + q_normed_by_head = self.rmsnorm_matcher(q_by_head, q_weight) + q_flat = q_normed_by_head.view(q.shape) + + # K path: view -> RMS -> view back to k.shape + k_by_head = k.view( + *k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim + ) + k_normed_by_head = self.rmsnorm_matcher(k_by_head, k_weight) + k_flat = k_normed_by_head.view(k.shape) + + # RoPE: apply to flattened q/k + q_rope, k_rope = self.rope_matcher(positions, q_flat, k_flat, cos_sin_cache) + return q_rope, k_rope, v + + def replacement( + qkv: torch.Tensor, + positions: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + cos_sin_cache: torch.Tensor, + ): + # Run fused qk_norm_rope op + result = auto_functionalized( + FUSED_QK_ROPE_OP, + qkv=qkv, + num_heads_q=self.num_heads, + num_heads_k=self.num_kv_heads, + num_heads_v=self.num_kv_heads, + head_dim=self.head_dim, + eps=self.eps, + q_weight=q_weight, + k_weight=k_weight, + cos_sin_cache=cos_sin_cache, + is_neox=self.is_neox, + position_ids=positions.view(-1), + ) + result_qkv = result[1] + + # Split back to q,k,v and return + return result_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # NOTE: use fx_view_to_reshape to unify view/reshape to simplify + # pattern and increase matching opportunities + pm.register_replacement( + pattern, + replacement, + self.get_inputs(), + QkNormRopePattern.wrap_trace_fn( + pm.fwd_only, + QkNormRopePattern.fx_view_to_reshape, + ), + pm_pass, + ) + + +class QKNormRoPEFusionPass(VllmPatternMatcherPass): + """Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists.""" + + @enable_fake_mode + def __init__(self, config: VllmConfig): + super().__init__(config) + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="qk_norm_rope_fusion_pass" + ) + + dtype = config.model_config.dtype + if dtype not in (torch.bfloat16, torch.float16): + logger.warning_once( + "QK Norm+RoPE fusion not enabled: unsupported dtype %s", dtype + ) + return + + # use one attn layer to get meta (such as head_dim) for QkNormRopePattern + attn_layers: dict[str, Attention] = get_layers_from_vllm_config( + config, Attention + ) + if len(attn_layers) == 0: + logger.warning_once( + "QK Norm+RoPE fusion enabled, but no Attention layers were discovered." + ) + return + layer = next(iter(attn_layers.values())) + + for epsilon in [1e-5, 1e-6]: + for neox in [True, False]: + if RotaryEmbedding.enabled(): + for rope_flashinfer in [False, True]: + QkNormRopePattern( + head_dim=layer.head_size, + num_heads=layer.num_heads, + num_kv_heads=layer.num_kv_heads, + eps=epsilon, + is_neox=neox, + rope_flashinfer=rope_flashinfer, + ).register(self.patterns) + else: + QkNormRopePattern( + head_dim=layer.head_size, + num_heads=layer.num_heads, + num_kv_heads=layer.num_kv_heads, + eps=epsilon, + is_neox=neox, + ).register(self.patterns) + + self.dump_patterns(config, self.patterns) + + @VllmInductorPass.time_and_log + def __call__(self, graph: fx.Graph) -> None: + self.matched_count = self.patterns.apply(graph) + logger.debug("Fused QK Norm+RoPE on %s sites", self.matched_count) + + def uuid(self): + return VllmInductorPass.hash_source(self, QkNormRopePattern) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 92cf16f259fe7..9c9557df4e738 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -129,6 +129,8 @@ class PassConfig: 8: 1, # 1MB }, }, where key is the device capability""" + enable_qk_norm_rope_fusion: bool = False + """Whether to enable the fused Q/K RMSNorm + RoPE pass.""" # TODO(luka) better pass enabling system. @@ -182,6 +184,12 @@ class PassConfig: "Fusion enabled but reshape elimination disabled. " "Allreduce + rms norm + quant (fp8) fusion might not work" ) + if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda(): + logger.warning_once( + "QK Norm + RoPE fusion enabled but the current platform is not " + "CUDA. The fusion will be disabled." + ) + self.enable_qk_norm_rope_fusion = False @config @@ -640,6 +648,11 @@ class CompilationConfig: if isinstance(self.pass_config, dict): self.pass_config = PassConfig(**self.pass_config) + if self.pass_config.enable_qk_norm_rope_fusion: + # TODO(zhuhaoran): support rope native forward match and remove this. + # Linked issue: https://github.com/vllm-project/vllm/issues/28042 + self.custom_ops.append("+rotary_embedding") + if ( is_torch_equal_or_newer("2.9.0.dev") and "combo_kernels" not in self.inductor_compile_config diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index 2ef54e75df44e..ce4f40680b0a3 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -98,6 +98,39 @@ class RotaryEmbedding(RotaryEmbeddingBase): head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype ) + @staticmethod + def forward_static( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor | None, + head_size: int, + rotary_dim: int, + cos_sin_cache: torch.Tensor, + is_neox_style: bool, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """A PyTorch-native implementation of forward().""" + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, head_size) + query_rot = query[..., :rotary_dim] + query_pass = query[..., rotary_dim:] + query_rot = apply_rotary_emb_torch(query_rot, cos, sin, is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + # key may be None in some cases, e.g. cross-layer KV sharing + if key is not None: + key_shape = key.shape + key = key.view(num_tokens, -1, head_size) + key_rot = key[..., :rotary_dim] + key_pass = key[..., rotary_dim:] + key_rot = apply_rotary_emb_torch(key_rot, cos, sin, is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + def forward_native( self, positions: torch.Tensor, @@ -105,27 +138,15 @@ class RotaryEmbedding(RotaryEmbeddingBase): key: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """A PyTorch-native implementation of forward().""" - positions = positions.flatten() - num_tokens = positions.shape[0] - cos_sin = self.cos_sin_cache.index_select(0, positions) - cos, sin = cos_sin.chunk(2, dim=-1) - - query_shape = query.shape - query = query.view(num_tokens, -1, self.head_size) - query_rot = query[..., : self.rotary_dim] - query_pass = query[..., self.rotary_dim :] - query_rot = apply_rotary_emb_torch(query_rot, cos, sin, self.is_neox_style) - query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) - - # key may be None in some cases, e.g. cross-layer KV sharing - if key is not None: - key_shape = key.shape - key = key.view(num_tokens, -1, self.head_size) - key_rot = key[..., : self.rotary_dim] - key_pass = key[..., self.rotary_dim :] - key_rot = apply_rotary_emb_torch(key_rot, cos, sin, self.is_neox_style) - key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) - return query, key + return self.forward_static( + positions, + query, + key, + self.head_size, + self.rotary_dim, + self.cos_sin_cache, + self.is_neox_style, + ) def forward_cuda( self, From 05576df85c5274ee3045d90b0779d4adeecc09b9 Mon Sep 17 00:00:00 2001 From: xuebwang-amd Date: Wed, 12 Nov 2025 01:05:22 +0800 Subject: [PATCH 23/98] [ROCm][Quantization] extend AMD Quark to support mixed-precision quantized model (#24239) Signed-off-by: xuebwang-amd Co-authored-by: fxmarty-amd Co-authored-by: Cyrus Leung --- docs/features/quantization/quark.md | 34 ++++++++- tests/quantization/test_mixed_precision.py | 69 +++++++++++++++++++ .../layers/quantization/quark/quark.py | 32 +++++++-- 3 files changed, 127 insertions(+), 8 deletions(-) create mode 100755 tests/quantization/test_mixed_precision.py diff --git a/docs/features/quantization/quark.md b/docs/features/quantization/quark.md index 385e3bbb8712f..be0702f4c9e16 100644 --- a/docs/features/quantization/quark.md +++ b/docs/features/quantization/quark.md @@ -281,4 +281,36 @@ python quantize_quark.py --model_dir Qwen/Qwen1.5-MoE-A2.7B-Chat \ --group_size 32 ``` -The current integration supports [all combination of FP4, FP6_E3M2, FP6_E2M3](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py) used for either weights or activations. Eventually, some target hardware support mixed precision GEMM, as AMD Instinct MI350/MI355, for example using FP6 for activations and FP4 for weights. +The current integration supports [all combination of FP4, FP6_E3M2, FP6_E2M3](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py) used for either weights or activations. + +## Using Quark Quantized layerwise Auto Mixed Precision (AMP) Models + +vLLM also supports loading layerwise mixed precision model quantized using AMD Quark. Currently, mixed scheme of {MXFP4, FP8} is supported, where FP8 here denotes for FP8 per-tensor scheme. More mixed precision schemes are planned to be supported in a near future, including + +- Unquantized Linear and/or MoE layer(s) as an option for each layer, i.e., mixed of {MXFP4, FP8, BF16/FP16} +- MXFP6 quantization extension, i.e., {MXFP4, MXFP6, FP8, BF16/FP16} + +Although one can maximize serving throughput using the lowest precision supported on a given device (e.g. MXFP4 for AMD Instinct MI355, FP8 for AMD Instinct MI300), these aggressive schemes can be detrimental to accuracy recovering from quantization on target tasks. Mixed precision allows to strike a balance between maximizing accuracy and throughput. + +There are two steps to generate and deploy a mixed precision model quantized with AMD Quark, as shown below. + +### 1. Quantize a model using mixed precision in AMD Quark + +Firstly, the layerwise mixed-precision configuration for a given LLM model is searched and then quantized using AMD Quark. We will provide a detailed tutorial with Quark APIs later. + +As examples, we provide some ready-to-use quantized mixed precision model to show the usage in vLLM and the accuracy benifits. They are: + +- amd/Llama-2-70b-chat-hf-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8 +- amd/Mixtral-8x7B-Instruct-v0.1-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8 +- amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8 + +### 2. inference the quantized mixed precision model in vLLM + +Models quantized with AMD Quark using mixed precision can natively be reload in vLLM, and e.g. evaluated using lm-evaluation-harness as follow: + +```bash +lm_eval --model vllm \ + --model_args pretrained=amd/Llama-2-70b-chat-hf-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8,tensor_parallel_size=4,dtype=auto,gpu_memory_utilization=0.8,trust_remote_code=False \ + --tasks mmlu \ + --batch_size auto +``` diff --git a/tests/quantization/test_mixed_precision.py b/tests/quantization/test_mixed_precision.py new file mode 100755 index 0000000000000..51526470b4233 --- /dev/null +++ b/tests/quantization/test_mixed_precision.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Test quark-quantized {MXFP4, FP8} mixed precision models. + +Run `pytest tests/quantization/test_mixed_precision.py`. + +""" + +import importlib +import importlib.metadata +from dataclasses import dataclass + +import lm_eval +import pytest +from packaging import version + +QUARK_MXFP4_AVAILABLE = importlib.util.find_spec("quark") is not None and version.parse( + importlib.metadata.version("amd-quark") +) >= version.parse("0.8.99") + + +@dataclass +class ModelCase: + model_id: str + tp: int + + +@dataclass +class EvaluationConfig: + model_name: str + + def get_model_args(self) -> str: + return ( + f"pretrained={self.model_name}," + "tensor_parallel_size=4,dtype=auto,gpu_memory_utilization=0.8,trust_remote_code=False" + ) + + +TEST_CONFIGS = { + # Mixed-precision (AMP) model + # - Demonstrates end-to-end pipeline functionality + "amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8": {"arc_challenge": 0.52, "mmlu": 0.72}, + # Non-mixed-precision (PTQ) model + # - Reference for pipeline compatibility verification -> No conflicts or breakings + "amd/Llama-2-70b-chat-hf-FP8-MLPerf-fp8_attn_quark_format": { + "arc_challenge": 0.53, + "mmlu": 0.61, + }, +} + + +@pytest.mark.parametrize("model_name, accuracy_numbers", TEST_CONFIGS.items()) +@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") +def test_mixed_precision_model_accuracies(model_name: str, accuracy_numbers: dict): + results = lm_eval.simple_evaluate( + model="vllm", + model_args=EvaluationConfig(model_name).get_model_args(), + tasks=list(accuracy_numbers.keys()), + batch_size=8, + ) + + rtol = 0.05 + + for task, expect_accuracy in accuracy_numbers.items(): + measured_accuracy = results["results"][task]["acc,none"] + assert ( + measured_accuracy - rtol < expect_accuracy + and measured_accuracy + rtol > expect_accuracy + ), f"Expected: {expect_accuracy} | Measured: {measured_accuracy}" diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index d5459594b7983..095a66ef10f9a 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -114,7 +114,14 @@ class QuarkConfig(QuantizationConfig): layer_quant_names = list(layer_quant_config.keys()) layer_quant_set = set(layer_quant_names) - if not kv_cache_set.issubset(layer_quant_set): + if not ( + kv_cache_set.issubset(layer_quant_set) + or any( + fnmatch.fnmatchcase(layer_quant, pat) + for layer_quant in list(layer_quant_set) + for pat in list(kv_cache_set) + ) + ): raise ValueError( "The Quark quantized model has the " "kv_cache_group parameter setting, " @@ -124,10 +131,15 @@ class QuarkConfig(QuantizationConfig): ) q_configs = [ - cast(dict[str, Any], layer_quant_config.get(name)) - for name in kv_cache_group + quant_cfg + for name, quant_cfg in layer_quant_config.items() + if any(fnmatch.fnmatchcase(name, pattern) for pattern in kv_cache_group) ] - if not all(deep_compare(q_config, q_configs[0]) for q_config in q_configs): + + if not all( + deep_compare(q_config["output_tensors"], q_configs[0]["output_tensors"]) + for q_config in q_configs + ): raise ValueError( "The quantization method used for kv_cache should " "be the same, but the quantization method for the " @@ -312,9 +324,15 @@ class QuarkConfig(QuantizationConfig): layer_quant_config = cast( dict[str, Any], self.quant_config.get("layer_quant_config") ) - for name_pattern in layer_quant_config: - if fnmatch.fnmatch(layer_name, name_pattern): - return layer_quant_config[name_pattern] + + def _matches_pattern(layer_name, pattern): + if "*" not in pattern: + return layer_name in pattern + return fnmatch.fnmatch(layer_name, pattern) + + for name_pattern, config in layer_quant_config.items(): + if _matches_pattern(layer_name, name_pattern): + return config layer_type = cast(str, type(module)) layer_type_quant_config = cast( From 5a1271d83a65be5ed8dc3e4c990ed42074197db3 Mon Sep 17 00:00:00 2001 From: xuebwang-amd Date: Wed, 12 Nov 2025 01:06:00 +0800 Subject: [PATCH 24/98] [Quantization] fix attention quantization of gpt_oss model (#27334) Signed-off-by: xuebwang-amd --- .../test_gpt_oss_attn_quantization.py | 80 +++++++++++++++++++ .../layers/quantization/mxfp4.py | 15 +++- vllm/model_executor/models/gpt_oss.py | 10 ++- 3 files changed, 101 insertions(+), 4 deletions(-) create mode 100644 tests/models/quantization/test_gpt_oss_attn_quantization.py diff --git a/tests/models/quantization/test_gpt_oss_attn_quantization.py b/tests/models/quantization/test_gpt_oss_attn_quantization.py new file mode 100644 index 0000000000000..780165ea2ba7a --- /dev/null +++ b/tests/models/quantization/test_gpt_oss_attn_quantization.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Test attention quantization of gpt-oss model. +The qkv_proj and o_proj in self_attention can be either quantized or excluded. + +Run `pytest tests/models/quantization/test_gpt_oss_attn_quantization.py`. + +""" + +import importlib +import importlib.metadata +from dataclasses import dataclass + +import huggingface_hub +import lm_eval +import pytest +from packaging import version + +MODEL_NAMES = ["amd/gpt-oss-20b-customized-attention-quantization"] + +QUARK_MXFP4_AVAILABLE = importlib.util.find_spec("quark") is not None and version.parse( + importlib.metadata.version("amd-quark") +) >= version.parse("0.8.99") + + +def has_huggingface_access(repo): + try: + huggingface_hub.list_repo_refs(repo) + return True + except huggingface_hub.errors.RepositoryNotFoundError: + return False + + +HF_HUB_AMD_ORG_ACCESS = all( + [has_huggingface_access(model_name) for model_name in MODEL_NAMES] +) + + +@dataclass +class ModelCase: + model_id: str + tp: int + + +@dataclass +class EvaluationConfig: + model_name: str + + def get_model_args(self) -> str: + return ( + f"pretrained={self.model_name}," + "tensor_parallel_size=4,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=False" + ) + + +EXPECTED_ACCURACIES = {"arc_challenge": 0.20} + + +@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") +@pytest.mark.skipif( + not HF_HUB_AMD_ORG_ACCESS, + reason="Read access to huggingface.co/amd is required for this test.", +) +@pytest.mark.parametrize("model_name", MODEL_NAMES) +@pytest.mark.parametrize("task_name, expected_accuracy", EXPECTED_ACCURACIES.items()) +def test_gpt_oss_attention_quantization( + model_name: str, task_name: str, expected_accuracy: float +): + measured_accuracy = lm_eval.simple_evaluate( + model="vllm", + model_args=EvaluationConfig(model_name).get_model_args(), + tasks=task_name, + batch_size="auto", + )["results"][task_name]["acc,none"] + + rtol = 0.05 + assert ( + measured_accuracy - rtol < expected_accuracy + and measured_accuracy + rtol > expected_accuracy + ), f"Expected: {expected_accuracy} | Measured: {measured_accuracy}" diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 4e51249f2d25b..8d7297a0a1b3b 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -190,14 +190,25 @@ class Mxfp4Config(QuantizationConfig): fused_mapping=self.packed_modules_mapping, ): return UnquantizedLinearMethod() - raise NotImplementedError("Mxfp4 linear layer is not implemented") + # TODO: Add support for MXFP4 Linear Method. + # MXFP4 LinearMethod is available in AMD-Quark, refer to that implementation + # if you are interested in enabling MXFP4 here. + logger.warning_once( + "MXFP4 linear layer is not implemented - falling back to " + "UnquantizedLinearMethod." + ) + return UnquantizedLinearMethod() elif isinstance(layer, FusedMoE): if current_platform.is_xpu(): return IpexMxfp4MoEMethod(layer.moe_config) else: return Mxfp4MoEMethod(layer.moe_config) elif isinstance(layer, Attention): - raise NotImplementedError("Mxfp4 attention layer is not implemented") + # TODO: Add support for MXFP4 Attention. + logger.warning_once( + "MXFP4 attention layer is not implemented. " + "Skipping quantization for this layer." + ) return None diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 04038ae74882d..291ac833f26ad 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -198,6 +198,7 @@ class TransformerBlock(torch.nn.Module): def __init__( self, vllm_config: VllmConfig, + quant_config: QuantizationConfig, prefix: str = "", ): super().__init__() @@ -207,7 +208,10 @@ class TransformerBlock(torch.nn.Module): self.layer_idx = extract_layer_index(prefix) self.attn = OAIAttention( - config, prefix=f"{prefix}.attn", cache_config=cache_config + config, + prefix=f"{prefix}.attn", + quant_config=quant_config, + cache_config=cache_config, ) self.mlp = MLPBlock(vllm_config, self.layer_idx, prefix=f"{prefix}.mlp") self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5) @@ -243,6 +247,7 @@ class GptOssModel(nn.Module): ): super().__init__() self.config = vllm_config.model_config.hf_config + self.quant_config = vllm_config.quant_config self.parallel_config = vllm_config.parallel_config self.config.hidden_size = self.config.hidden_size self.embedding = VocabParallelEmbedding( @@ -254,6 +259,7 @@ class GptOssModel(nn.Module): lambda prefix: TransformerBlock( vllm_config, prefix=prefix, + quant_config=self.quant_config, ), prefix=f"{prefix}.layers", ) @@ -645,7 +651,7 @@ class GptOssModel(nn.Module): class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA): - packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]} + packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} hf_to_vllm_mapper = WeightsMapper( orig_to_new_substr={ From e55342491968a56d39dc8e03e6cf39d12fef5dcd Mon Sep 17 00:00:00 2001 From: Zhewen Li Date: Tue, 11 Nov 2025 09:09:47 -0800 Subject: [PATCH 25/98] [CI/Build] Refactor Attention backend for test_prefix_prefill from xformers to SDPA (#28424) Signed-off-by: zhewenli Signed-off-by: Roger Wang Co-authored-by: Roger Wang --- .../kernels/attention/test_prefix_prefill.py | 312 +++++++++++------- 1 file changed, 195 insertions(+), 117 deletions(-) diff --git a/tests/kernels/attention/test_prefix_prefill.py b/tests/kernels/attention/test_prefix_prefill.py index 65972d02f2f66..78cdbbbf7379d 100644 --- a/tests/kernels/attention/test_prefix_prefill.py +++ b/tests/kernels/attention/test_prefix_prefill.py @@ -8,10 +8,8 @@ from collections.abc import Callable import pytest import torch -from xformers import ops as xops -from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask +import torch.nn.functional as F -from tests.kernels.utils import make_alibi_bias from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.platforms import current_platform @@ -28,6 +26,74 @@ KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"] OPS = [chunked_prefill_paged_decode, context_attention_fwd] +def create_causal_attention_mask_for_sdpa( + query_lens: list[int], + seq_lens: list[int], + sliding_window: int = 0, + device: torch.device = None, + dtype: torch.dtype = None, +) -> torch.Tensor: + total_queries = sum(query_lens) + total_keys = sum(seq_lens) + + # Create a mask filled with -inf + mask = torch.full( + (total_queries, total_keys), float("-inf"), device=device, dtype=dtype + ) + + query_start = 0 + key_start = 0 + + for query_len, seq_len in zip(query_lens, seq_lens): + query_end = query_start + query_len + key_end = key_start + seq_len + q_indices = torch.arange(query_len, device=device) + k_indices = torch.arange(seq_len, device=device) + q_pos_in_seq = seq_len - query_len + q_indices + + valid_mask = k_indices[None, :] <= q_pos_in_seq[:, None] + + if sliding_window > 0: + valid_mask &= k_indices[None, :] >= ( + q_pos_in_seq[:, None] - sliding_window + 1 + ) + + mask[query_start:query_end, key_start:key_end][valid_mask] = 0.0 + + query_start = query_end + key_start = key_end + + return mask + + +def create_alibi_causal_mask( + query_len: int, + seq_len: int, + alibi_slopes: torch.Tensor, + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + query_pos = torch.arange( + seq_len - query_len, seq_len, device=device, dtype=torch.float32 + ) + key_pos = torch.arange(seq_len, device=device, dtype=torch.float32) + + rel_pos = key_pos[None, :] - query_pos[:, None] + + # Apply ALiBi slopes: [num_heads, query_len, seq_len] + alibi_bias = alibi_slopes[:, None, None] * rel_pos[None, :, :] + alibi_bias = alibi_bias.to(dtype) + + # Apply causal mask: prevent attending to future positions + # causal_mask[i, j] = True if key_pos[j] <= query_pos[i] + causal_mask = key_pos[None, :] <= query_pos[:, None] + alibi_bias = alibi_bias.masked_fill(~causal_mask[None, :, :], float("-inf")) + + # Add batch dimension: [1, num_heads, query_len, seq_len] + # SDPA expects batch dimension even for single sequences + return alibi_bias.unsqueeze(0) + + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -52,6 +118,13 @@ def test_contexted_kv_attention( "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89" ) + if ( + current_platform.is_rocm() + and op is chunked_prefill_paged_decode + and kv_cache_dtype == "fp8_e5m2" + ): + pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache") + current_platform.seed_everything(0) torch.set_default_device(device) @@ -96,16 +169,16 @@ def test_contexted_kv_attention( ) k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) - values = torch.arange(0, cache_size, dtype=torch.long) + values = torch.arange(0, cache_size, dtype=torch.int32) values = values[torch.randperm(cache_size)] block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request) - b_seq_len = torch.tensor(seq_lens, dtype=torch.long) - b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) - b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0) + b_seq_len = torch.tensor(seq_lens, dtype=torch.int32) + b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32) + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0) max_input_len = MAX_SEQ_LEN # copy kv to cache b_seq_start_loc = torch.cumsum( - torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0 + torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0 ) for i in range(BS): for j in range(query_lens[i]): @@ -189,56 +262,57 @@ def test_contexted_kv_attention( scale = float(1.0 / (head_size**0.5)) - attn_op = xops.fmha.cutlass.FwOp() - - if num_kv_heads != num_heads: - # As of Nov 2023, xformers only supports MHA. For MQA/GQA, - # project the key and value tensors to the desired number of - # heads. - # - # see also: vllm/model_executor/layers/attention.py - query = query.view( - query.shape[0], num_kv_heads, num_queries_per_kv, query.shape[-1] - ) - key = key[:, :, None, :].expand( - key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1] - ) - value = value[:, :, None, :].expand( - value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1] - ) - query = query.unsqueeze(0) - key = key.unsqueeze(0) - value = value.unsqueeze(0) - - attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens( - query_lens, seq_lens + # Reshape for SDPA: (seq_len, num_heads, head_size) -> + # (1, num_heads, seq_len, head_size) + query_sdpa = query.view(num_tokens, num_kv_heads, num_queries_per_kv, head_size) + query_sdpa = query_sdpa.permute(1, 2, 0, 3).reshape( + 1, num_heads, num_tokens, head_size ) - if sliding_window > 0: - attn_bias = attn_bias.make_local_attention_from_bottomright(sliding_window) - output_ref = xops.memory_efficient_attention_forward( - query, - key, - value, - attn_bias=attn_bias, - p=0.0, + + # Expand key and value for GQA/MQA to match query heads + key_sdpa = key[:, :, None, :].expand( + key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1] + ) + key_sdpa = key_sdpa.permute(1, 2, 0, 3).reshape( + 1, num_heads, sum(seq_lens), head_size + ) + + value_sdpa = value[:, :, None, :].expand( + value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1] + ) + value_sdpa = value_sdpa.permute(1, 2, 0, 3).reshape( + 1, num_heads, sum(seq_lens), head_size + ) + + attn_mask = create_causal_attention_mask_for_sdpa( + query_lens, seq_lens, sliding_window, device=device, dtype=dtype + ) + + output_ref = F.scaled_dot_product_attention( + query_sdpa, + key_sdpa, + value_sdpa, + attn_mask=attn_mask, + dropout_p=0.0, scale=scale, - op=attn_op, ) torch.cuda.synchronize() start_time = time.time() - output_ref = xops.memory_efficient_attention_forward( - query, - key, - value, - attn_bias=attn_bias, - p=0.0, + output_ref = F.scaled_dot_product_attention( + query_sdpa, + key_sdpa, + value_sdpa, + attn_mask=attn_mask, + dropout_p=0.0, scale=scale, - op=attn_op, ) torch.cuda.synchronize() end_time = time.time() - print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms") - output_ref = output_ref.reshape(output.shape) + print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms") + + # Reshape output back to (num_tokens, num_heads, head_size) + output_ref = output_ref.view(num_heads, num_tokens, head_size) + output_ref = output_ref.permute(1, 0, 2).contiguous() atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4 torch.testing.assert_close(output, output_ref, atol=atol, rtol=0) @@ -265,6 +339,13 @@ def test_contexted_kv_attention_alibi( "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89" ) + if ( + current_platform.is_rocm() + and op is chunked_prefill_paged_decode + and kv_cache_dtype == "fp8_e5m2" + ): + pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache") + current_platform.seed_everything(0) torch.set_default_device(device) @@ -331,16 +412,16 @@ def test_contexted_kv_attention_alibi( ) k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) - values = torch.arange(0, cache_size, dtype=torch.long) + values = torch.arange(0, cache_size, dtype=torch.int32) values = values[torch.randperm(cache_size)] block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request) - b_seq_len = torch.tensor(seq_lens, dtype=torch.long) - b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) - b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0) + b_seq_len = torch.tensor(seq_lens, dtype=torch.int32) + b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32) + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0) max_input_len = MAX_SEQ_LEN # copy kv to cache b_seq_start_loc = torch.cumsum( - torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0 + torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0 ) for i in range(BS): for j in range(query_lens[i]): @@ -423,78 +504,75 @@ def test_contexted_kv_attention_alibi( print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms") scale = float(1.0 / (head_size**0.5)) - # NOTE(DefTruth): In order to reuse _make_alibi_bias function, - # we have to pad query tensor before MQA/GQA expanding. - if query.shape[0] != key.shape[0]: - query_pad = torch.empty(sum(seq_lens), num_heads, head_size, dtype=dtype) - query_pad.uniform_(-1e-3, 1e-3) - seq_start = 0 - query_start = 0 - for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): - seq_end = seq_start + seq_len - query_end = query_start + query_len - query_pad[seq_start:seq_end, ...] = torch.cat( - [ - torch.zeros(seq_len - query_len, num_heads, head_size, dtype=dtype), - query[query_start:query_end, ...], - ], - dim=0, - ) - seq_start += seq_len - query_start += query_len - query = query_pad + # Prepare query, key, value for SDPA + # Expand key and value for GQA/MQA to match query heads + key_expanded = key[:, :, None, :].expand( + key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1] + ) + value_expanded = value[:, :, None, :].expand( + value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1] + ) - if num_kv_heads != num_heads: - # As of Nov 2023, xformers only supports MHA. For MQA/GQA, - # project the key and value tensors to the desired number of - # heads. - # - # see also: vllm/model_executor/layers/attention.py - key = key[:, :, None, :].expand( - key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1] - ) - value = value[:, :, None, :].expand( - value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1] - ) - # [seq, num_kv_heads, num_queries_per_kv, dk]=> - # [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the - # codebase. We save some time reshaping alibi matrix at runtime. - key = key.reshape(key.shape[0], -1, key.shape[-1]) - value = value.reshape(value.shape[0], -1, value.shape[-1]) - query = query.unsqueeze(0) - key = key.unsqueeze(0) - value = value.unsqueeze(0) - - attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens) output_ref = torch.empty_like(output) - seq_start = 0 - query_start = 0 + + torch.cuda.synchronize() start_time = time.time() - # Attention with alibi slopes. - # FIXME(DefTruth): Because xformers does not support dynamic sequence - # lengths with custom attention bias, we process each prompt one by - # one. This is inefficient, especially when we have many short prompts. - # modified from: vllm/v1/attention/backends/xformers.py#L343 + + query_start = 0 + key_start = 0 for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): - seq_end = seq_start + seq_len query_end = query_start + query_len - out = xops.memory_efficient_attention_forward( - query[:, seq_start:seq_end], - key[:, seq_start:seq_end], - value[:, seq_start:seq_end], - attn_bias=attn_bias[i], - p=0.0, + key_end = key_start + seq_len + + # Get query, key, value for this sequence + q = query[query_start:query_end] # [query_len, num_heads, head_size] + k = key_expanded[ + key_start:key_end + ] # [seq_len, num_kv_heads, num_queries_per_kv, head_size] + v = value_expanded[ + key_start:key_end + ] # [seq_len, num_kv_heads, num_queries_per_kv, head_size] + + # Reshape for SDPA: (batch=1, num_heads, seq_len, head_size) + q_sdpa = q.view(query_len, num_kv_heads, num_queries_per_kv, head_size) + q_sdpa = ( + q_sdpa.permute(1, 2, 0, 3) + .reshape(1, num_heads, query_len, head_size) + .contiguous() + ) + + k_sdpa = ( + k.permute(1, 2, 0, 3).reshape(1, num_heads, seq_len, head_size).contiguous() + ) + v_sdpa = ( + v.permute(1, 2, 0, 3).reshape(1, num_heads, seq_len, head_size).contiguous() + ) + + # Create ALiBi causal mask for this sequence using utility function + alibi_mask = create_alibi_causal_mask( + query_len, seq_len, alibi_slopes, device, dtype + ) + + # Compute attention + out = F.scaled_dot_product_attention( + q_sdpa, + k_sdpa, + v_sdpa, + attn_mask=alibi_mask, + dropout_p=0.0, scale=scale, ) - out = out.view_as(query[:, seq_start:seq_end]).view( - seq_len, num_heads, head_size - ) - output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len :, ...]) - seq_start += seq_len - query_start += query_len + + # Reshape output back to [query_len, num_heads, head_size] + out = out.view(num_heads, query_len, head_size).permute(1, 0, 2) + output_ref[query_start:query_end].copy_(out) + + query_start = query_end + key_start = key_end + torch.cuda.synchronize() end_time = time.time() - print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms") + print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms") atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6 torch.testing.assert_close(output, output_ref, atol=atol, rtol=0) From 684f2545851ee0ee49be9a80545ed497324f1a96 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 11 Nov 2025 11:13:51 -0600 Subject: [PATCH 26/98] Prefer FlashAttention MLA as default over FlashMLA (#27363) Signed-off-by: Matthew Bonanni --- vllm/platforms/cuda.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 43daf5e75b665..22c6dde754d01 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -55,15 +55,15 @@ def _get_backend_priorities( return [ AttentionBackendEnum.CUTLASS_MLA, AttentionBackendEnum.FLASHINFER_MLA, - AttentionBackendEnum.FLASHMLA, AttentionBackendEnum.FLASH_ATTN_MLA, + AttentionBackendEnum.FLASHMLA, AttentionBackendEnum.TRITON_MLA, AttentionBackendEnum.FLASHMLA_SPARSE, ] else: return [ - AttentionBackendEnum.FLASHMLA, AttentionBackendEnum.FLASH_ATTN_MLA, + AttentionBackendEnum.FLASHMLA, AttentionBackendEnum.FLASHINFER_MLA, AttentionBackendEnum.TRITON_MLA, AttentionBackendEnum.FLASHMLA_SPARSE, From 6c3c0f8235cacce28982687e362b80d953ea7617 Mon Sep 17 00:00:00 2001 From: Xin Yang <105740670+xyang16@users.noreply.github.com> Date: Tue, 11 Nov 2025 10:02:23 -0800 Subject: [PATCH 27/98] [Kernel] Optimize rms_norm kernel (#27931) Signed-off-by: Xin Yang --- csrc/dispatch_utils.h | 29 ++++++++++++++++++++++ csrc/layernorm_kernels.cu | 39 +++++++++++++++++++++--------- csrc/layernorm_quant_kernels.cu | 43 ++++++++++++++++++++++----------- 3 files changed, 86 insertions(+), 25 deletions(-) diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index 995374a50b037..9ae0ed975edde 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -88,3 +88,32 @@ #define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH( \ TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_VEC_SIZE(VEC_SIZE, ...) \ + switch (VEC_SIZE) { \ + case 16: { \ + constexpr int vec_size = 16; \ + __VA_ARGS__(); \ + break; \ + } \ + case 8: { \ + constexpr int vec_size = 8; \ + __VA_ARGS__(); \ + break; \ + } \ + case 4: { \ + constexpr int vec_size = 4; \ + __VA_ARGS__(); \ + break; \ + } \ + case 2: { \ + constexpr int vec_size = 2; \ + __VA_ARGS__(); \ + break; \ + } \ + default: { \ + constexpr int vec_size = 1; \ + __VA_ARGS__(); \ + break; \ + } \ + } diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 8cfcf9f41283a..48771e4b3aff9 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -10,7 +10,7 @@ namespace vllm { // TODO(woosuk): Further optimize this kernel. -template +template __global__ void rms_norm_kernel( scalar_t* __restrict__ out, // [..., hidden_size] const scalar_t* __restrict__ input, // [..., hidden_size] @@ -21,7 +21,6 @@ __global__ void rms_norm_kernel( float variance = 0.0f; const scalar_t* input_row = input + blockIdx.x * input_stride; - constexpr int VEC_SIZE = 8; auto vec_op = [&variance](const vec_n_t& vec) { #pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { @@ -45,10 +44,20 @@ __global__ void rms_norm_kernel( } __syncthreads(); - for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float x = (float)input[blockIdx.x * input_stride + idx]; - out[blockIdx.x * hidden_size + idx] = - ((scalar_t)(x * s_variance)) * weight[idx]; + scalar_t* out_row = out + blockIdx.x * hidden_size; + auto* v_in = reinterpret_cast*>(input_row); + auto* v_w = reinterpret_cast*>(weight); + auto* v_out = reinterpret_cast*>(out_row); + for (int i = threadIdx.x; i < hidden_size / VEC_SIZE; i += blockDim.x) { + vec_n_t dst; + vec_n_t src1 = v_in[i]; + vec_n_t src2 = v_w[i]; +#pragma unroll + for (int j = 0; j < VEC_SIZE; j++) { + float x = static_cast(src1.val[j]); + dst.val[j] = ((scalar_t)(x * s_variance)) * src2.val[j]; + } + v_out[i] = dst; } } @@ -168,16 +177,24 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size] int num_tokens = input_view.numel() / hidden_size; int64_t input_stride = input_view.stride(-2); + // For large num_tokens, use smaller blocks to increase SM concurrency. + const int max_block_size = (num_tokens < 256) ? 1024 : 256; dim3 grid(num_tokens); - dim3 block(std::min(hidden_size, 1024)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input_view)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input_view.scalar_type(), "rms_norm_kernel", [&] { - vllm::rms_norm_kernel<<>>( - out.data_ptr(), input_view.data_ptr(), - input_stride, weight.data_ptr(), epsilon, num_tokens, - hidden_size); + const int calculated_vec_size = + std::gcd(16 / sizeof(scalar_t), hidden_size); + const int block_size = + std::min(hidden_size / calculated_vec_size, max_block_size); + dim3 block(block_size); + VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] { + vllm::rms_norm_kernel<<>>( + out.data_ptr(), input_view.data_ptr(), + input_stride, weight.data_ptr(), epsilon, num_tokens, + hidden_size); + }); }); } diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu index 0f7f034ee180b..0880b8d50a795 100644 --- a/csrc/layernorm_quant_kernels.cu +++ b/csrc/layernorm_quant_kernels.cu @@ -18,7 +18,7 @@ namespace vllm { // TODO(woosuk): Further optimize this kernel. -template +template __global__ void rms_norm_static_fp8_quant_kernel( fp8_type* __restrict__ out, // [..., hidden_size] const scalar_t* __restrict__ input, // [..., hidden_size] @@ -31,7 +31,6 @@ __global__ void rms_norm_static_fp8_quant_kernel( const scalar_t* input_row = input + blockIdx.x * input_stride; - constexpr int VEC_SIZE = 8; auto vec_op = [&variance](const vec_n_t& vec) { #pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { @@ -58,11 +57,18 @@ __global__ void rms_norm_static_fp8_quant_kernel( // invert scale to avoid division float const scale_inv = 1.0f / *scale; - for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float x = (float)input[blockIdx.x * input_stride + idx]; - float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx]; - out[blockIdx.x * hidden_size + idx] = - scaled_fp8_conversion(out_norm, scale_inv); + auto* v_in = reinterpret_cast*>(input_row); + auto* v_w = reinterpret_cast*>(weight); + for (int idx = threadIdx.x; idx < hidden_size / VEC_SIZE; idx += blockDim.x) { + vec_n_t src1 = v_in[idx]; + vec_n_t src2 = v_w[idx]; +#pragma unroll + for (int j = 0; j < VEC_SIZE; j++) { + float x = static_cast(src1.val[j]); + float const out_norm = ((scalar_t)(x * s_variance)) * src2.val[j]; + out[blockIdx.x * hidden_size + idx * VEC_SIZE + j] = + scaled_fp8_conversion(out_norm, scale_inv); + } } } @@ -188,20 +194,29 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size] int input_stride = input.stride(-2); int num_tokens = input.numel() / hidden_size; + // For large num_tokens, use smaller blocks to increase SM concurrency. + const int max_block_size = (num_tokens < 256) ? 1024 : 256; dim3 grid(num_tokens); - dim3 block(std::min(hidden_size, 1024)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "rms_norm_kernel_scalar_type", [&] { VLLM_DISPATCH_FP8_TYPES( out.scalar_type(), "rms_norm_kernel_fp8_type", [&] { - vllm::rms_norm_static_fp8_quant_kernel - <<>>( - out.data_ptr(), input.data_ptr(), - input_stride, weight.data_ptr(), - scale.data_ptr(), epsilon, num_tokens, - hidden_size); + const int calculated_vec_size = + std::gcd(16 / sizeof(scalar_t), hidden_size); + const int block_size = + std::min(hidden_size / calculated_vec_size, max_block_size); + dim3 block(block_size); + VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] { + vllm::rms_norm_static_fp8_quant_kernel + <<>>( + out.data_ptr(), input.data_ptr(), + input_stride, weight.data_ptr(), + scale.data_ptr(), epsilon, num_tokens, + hidden_size); + }); }); }); } From d5edcb86781ea56f1eb0c9c5d7482a7cae00ec17 Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Wed, 12 Nov 2025 02:18:02 +0800 Subject: [PATCH 28/98] [BugFix] Fix Siglip2Attention on XPU (#28448) Signed-off-by: Lin, Fanli --- vllm/model_executor/models/siglip2navit.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py index c20bcd975ca30..29dd164ad37fd 100644 --- a/vllm/model_executor/models/siglip2navit.py +++ b/vllm/model_executor/models/siglip2navit.py @@ -25,6 +25,7 @@ from vllm.model_executor.layers.linear import ( ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.platforms import current_platform from .vision import get_vit_attn_backend @@ -188,7 +189,7 @@ def apply_rotary_pos_emb( ) -> tuple[torch.Tensor, torch.Tensor]: cos = cos.chunk(2, dim=-1)[0].contiguous() sin = sin.chunk(2, dim=-1)[0].contiguous() - if is_flash_attn_backend: + if is_flash_attn_backend and not current_platform.is_xpu(): from flash_attn.layers.rotary import apply_rotary_emb apply_rotary_emb_func = apply_rotary_emb @@ -306,7 +307,13 @@ class Siglip2Attention(nn.Module): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if self.is_flash_attn_backend: attn_output = self.flash_attn_varlen_func( - queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen + queries, + keys, + values, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, ).reshape(seq_length, -1) elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. From 76e4dcf225e4de115bdc20b00a78d49bec767c09 Mon Sep 17 00:00:00 2001 From: Lukas Geiger Date: Tue, 11 Nov 2025 18:26:04 +0000 Subject: [PATCH 29/98] [Misc] Remove unused attention prefix prefill ops functions (#26971) Signed-off-by: Lukas Geiger --- vllm/attention/ops/prefix_prefill.py | 210 ------------------ .../compressed_tensors_moe.py | 3 - 2 files changed, 213 deletions(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index addf1d9dea73e..f101d5c4a9278 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -335,216 +335,6 @@ def _fwd_kernel( return -@triton.jit -def _fwd_kernel_flash_attn_v2( - Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - B_Start_Loc, - B_Seqlen, - B_Ctxlen, - block_size, - x, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: int, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // num_queries_per_kv - - cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - - q = tl.load( - Q + off_q, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0, - ) - - # # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - for start_n in range(0, cur_batch_ctx_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - bn = tl.load( - B_Loc - + cur_batch * stride_b_loc_b - + ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0, - ).to(tl.int64) - off_k = ( - bn[None, :] * stride_k_cache_bs - + cur_kv_head * stride_k_cache_h - + (offs_d[:, None] // x) * stride_k_cache_d - + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl - + (offs_d[:, None] % x) * stride_k_cache_x - ) - off_v = ( - bn[:, None] * stride_v_cache_bs - + cur_kv_head * stride_v_cache_h - + offs_d[None, :] * stride_v_cache_d - + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl - ) - k = tl.load( - K_cache + off_k, - mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, - other=0.0, - ) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk = tl.where( - (start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf") - ) - qk *= sm_scale - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - V_cache + off_v, - mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, - other=0.0, - ) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - off_k = ( - offs_n[None, :] * stride_kbs - + cur_kv_head * stride_kh - + offs_d[:, None] * stride_kd - ) - off_v = ( - offs_n[:, None] * stride_vbs - + cur_kv_head * stride_vh - + offs_d[None, :] * stride_vd - ) - k_ptrs = K + off_k - v_ptrs = V + off_v - - block_mask = tl.where(block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load( - k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0, - ) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0, - ) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - # acc /= l_i[:, None] - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs - + cur_head * stride_oh - + offs_d[None, :] * stride_od - ) - out_ptrs = Out + off_o - tl.store( - out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len - ) - return - - @triton.jit def _fwd_kernel_alibi( Q, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 59567f2ca13c7..6257a410e9432 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -98,9 +98,6 @@ __all__ = [ class CompressedTensorsMoEMethod(FusedMoEMethodBase): - def __init_(self, moe: FusedMoEConfig): - super().__init__(moe) - @staticmethod def get_moe_method( quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 From 4228be7959e98e57d88501bd97aca7ef34ff562e Mon Sep 17 00:00:00 2001 From: Jialin Ouyang Date: Tue, 11 Nov 2025 10:28:47 -0800 Subject: [PATCH 30/98] [Perf] Use np.ndarray instead of list[list[int]] to reduce GC overhead (#28245) Signed-off-by: Jialin Ouyang --- tests/v1/engine/utils.py | 7 ++++--- vllm/v1/engine/logprobs.py | 7 ++++++- vllm/v1/outputs.py | 13 +++++++------ 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/tests/v1/engine/utils.py b/tests/v1/engine/utils.py index 23684a2c55cef..3541ef89bfc14 100644 --- a/tests/v1/engine/utils.py +++ b/tests/v1/engine/utils.py @@ -5,6 +5,7 @@ import random from dataclasses import dataclass from typing import TypeAlias +import numpy as np import torch from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -369,9 +370,9 @@ class MockEngineCore: self.generated_logprobs_raw[req_idx][token_idx] ) logprobs = LogprobsLists( - [logprobs_token_ids_], - [logprobs_], - [sampled_token_ranks_], + np.array([logprobs_token_ids_]), + np.array([logprobs_]), + np.array([sampled_token_ranks_]), ) else: logprobs = None diff --git a/vllm/v1/engine/logprobs.py b/vllm/v1/engine/logprobs.py index 4c5955d7ee2e5..b618d23472651 100644 --- a/vllm/v1/engine/logprobs.py +++ b/vllm/v1/engine/logprobs.py @@ -74,7 +74,12 @@ class LogprobsProcessor: token_ids_lst, logprobs_lst, ranks_lst, _ = logprobs_lists - for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst, token_ids_lst): + for rank_np, logprobs_np, token_ids_np in zip( + ranks_lst, logprobs_lst, token_ids_lst + ): + rank = rank_np.tolist() + logprobs = logprobs_np.tolist() + token_ids = token_ids_np.tolist() # Detokenize (non-incrementally). decoded_tokens = ( NONES diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index b5cba96e1026f..5f65e4ee0d1f3 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -5,6 +5,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import TYPE_CHECKING, NamedTuple +import numpy as np import torch if TYPE_CHECKING: @@ -15,11 +16,11 @@ else: class LogprobsLists(NamedTuple): # [num_reqs x num_generated_tokens, max_num_logprobs + 1] - logprob_token_ids: list[list[int]] + logprob_token_ids: np.ndarray # [num_reqs x num_generated_tokens, max_num_logprobs + 1] - logprobs: list[list[float]] + logprobs: np.ndarray # [num_reqs x num_generated_tokens] - sampled_token_ranks: list[int] + sampled_token_ranks: np.ndarray # [num_reqs] # Used for slicing the logprobs in cases like speculative # decoding where the number of generated tokens may be @@ -60,9 +61,9 @@ class LogprobsTensors(NamedTuple): def tolists(self, cu_num_generated_tokens: list[int] | None = None): return LogprobsLists( - self.logprob_token_ids.tolist(), - self.logprobs.tolist(), - self.selected_token_ranks.tolist(), + self.logprob_token_ids.cpu().numpy(), + self.logprobs.cpu().numpy(), + self.selected_token_ranks.cpu().numpy(), cu_num_generated_tokens, ) From de120bc94f2e51633824093c626423ec8e7cb3a9 Mon Sep 17 00:00:00 2001 From: Canlin Guo <961750412@qq.com> Date: Wed, 12 Nov 2025 02:57:12 +0800 Subject: [PATCH 31/98] [V0 deprecation] Clean up num_prefill_tokens logic for V0 (#28203) Signed-off-by: gcanlin --- vllm/forward_context.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index ef37cf862c9fe..44bc2a4cda311 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -5,7 +5,7 @@ import time from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, NamedTuple, Union +from typing import TYPE_CHECKING, Any, NamedTuple import torch @@ -185,18 +185,13 @@ class ForwardContext: # copy from vllm_config.compilation_config.static_forward_context no_compile_layers: dict[str, Any] """ - Type AttentionMetadata for v0, Type Dict[str, AttentionMetadata] for v1, map from layer_name of each attention layer to its attention metadata Type List[Dict[str, AttentionMetadata]] for DBO. List of size two, one for each microbatch. Set dynamically for each forward pass """ - attn_metadata: Union[ - "AttentionMetadata", - dict[str, "AttentionMetadata"], - list[dict[str, "AttentionMetadata"]], - ] + attn_metadata: dict[str, "AttentionMetadata"] | list[dict[str, "AttentionMetadata"]] # TODO: remove after making all virtual_engines share the same kv cache virtual_engine: int # set dynamically for each forward pass # set dynamically for each forward pass @@ -324,14 +319,7 @@ def set_forward_context( finally: global last_logging_time, batchsize_logging_interval if need_to_track_batchsize: - if hasattr(attn_metadata, "num_prefill_tokens"): - # for v0 attention backends - batchsize = ( - attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens - ) - else: - # for v1 attention backends - batchsize = num_tokens + batchsize = num_tokens # we use synchronous scheduling right now, # adding a sync point here should not affect # scheduling of the next batch From 8c32c6e4b485f1cae1a1dc8a3f9895cf63f3e7af Mon Sep 17 00:00:00 2001 From: Jie Luo <65482183+Livinfly@users.noreply.github.com> Date: Wed, 12 Nov 2025 02:59:16 +0800 Subject: [PATCH 32/98] [Misc] fix typo in DCP comment (#28389) Signed-off-by: Livinfly --- vllm/v1/attention/backends/mla/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index b4cb5c200da38..19bd102cb1e30 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -2000,7 +2000,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): decode_q, kv_cache, attn_metadata, layer ) - # recorect dcp attn_out with lse. + # correct dcp attn_out with lse. if self.dcp_world_size > 1: attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group()) From 9d1c47470430ba31c02946aa1fd01aadf6e18b91 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 12 Nov 2025 03:06:21 +0800 Subject: [PATCH 33/98] [LoRA][1/N]Remove LoRA extra vocab (#28382) Signed-off-by: Jee Jee Li --- vllm/model_executor/models/apertus.py | 30 +++------------- vllm/model_executor/models/arcee.py | 10 ++---- vllm/model_executor/models/arctic.py | 6 ++-- vllm/model_executor/models/aria.py | 8 ++--- vllm/model_executor/models/baichuan.py | 4 +-- vllm/model_executor/models/bailing_moe.py | 2 -- vllm/model_executor/models/bamba.py | 30 ++++------------ vllm/model_executor/models/chameleon.py | 8 ++--- vllm/model_executor/models/chatglm.py | 3 +- vllm/model_executor/models/commandr.py | 19 ++++------- vllm/model_executor/models/dbrx.py | 9 ++--- vllm/model_executor/models/exaone.py | 27 +++------------ vllm/model_executor/models/exaone4.py | 26 +++----------- vllm/model_executor/models/falcon_h1.py | 31 ++++------------- vllm/model_executor/models/gemma.py | 2 -- vllm/model_executor/models/gemma2.py | 3 +- vllm/model_executor/models/gemma3.py | 3 +- vllm/model_executor/models/gemma3n.py | 3 +- vllm/model_executor/models/glm4.py | 2 -- vllm/model_executor/models/gpt_bigcode.py | 20 +++-------- vllm/model_executor/models/granitemoe.py | 27 +++------------ .../model_executor/models/granitemoehybrid.py | 27 +++------------ .../model_executor/models/granitemoeshared.py | 28 +++------------ vllm/model_executor/models/grok1.py | 26 ++++---------- vllm/model_executor/models/hunyuan_v1.py | 21 ++++-------- vllm/model_executor/models/internlm2.py | 2 -- vllm/model_executor/models/jamba.py | 30 ++++------------ vllm/model_executor/models/kimi_vl.py | 10 ++---- vllm/model_executor/models/lfm2.py | 31 +++-------------- vllm/model_executor/models/lfm2_moe.py | 32 ++++------------- vllm/model_executor/models/llama_eagle3.py | 3 -- vllm/model_executor/models/longcat_flash.py | 3 +- vllm/model_executor/models/mamba.py | 29 ++++------------ vllm/model_executor/models/mamba2.py | 28 +++------------ vllm/model_executor/models/medusa.py | 12 ++----- vllm/model_executor/models/mimo.py | 2 -- vllm/model_executor/models/minicpm.py | 30 ++++------------ vllm/model_executor/models/minicpm_eagle.py | 29 ++++------------ vllm/model_executor/models/minimax_text_01.py | 11 ++---- vllm/model_executor/models/mlp_speculator.py | 1 - vllm/model_executor/models/molmo.py | 3 +- vllm/model_executor/models/nemotron.py | 30 ++++------------ vllm/model_executor/models/nemotron_h.py | 30 ++++------------ vllm/model_executor/models/nemotron_nas.py | 31 ++++------------- vllm/model_executor/models/olmo.py | 4 +-- vllm/model_executor/models/olmo2.py | 2 -- vllm/model_executor/models/ouro.py | 2 -- vllm/model_executor/models/phi.py | 3 +- vllm/model_executor/models/phi3v.py | 1 - vllm/model_executor/models/phi4mm.py | 14 ++------ vllm/model_executor/models/phimoe.py | 34 ++++--------------- vllm/model_executor/models/plamo2.py | 11 ++---- vllm/model_executor/models/qwen2.py | 2 -- vllm/model_executor/models/qwen2_rm.py | 2 -- vllm/model_executor/models/qwen3.py | 2 -- vllm/model_executor/models/qwen3_next.py | 30 ++++------------ vllm/model_executor/models/qwen3_next_mtp.py | 23 ++++--------- vllm/model_executor/models/qwen3_vl.py | 2 -- vllm/model_executor/models/seed_oss.py | 2 -- vllm/model_executor/models/solar.py | 30 ++++------------ vllm/model_executor/models/starcoder2.py | 12 ++----- vllm/model_executor/models/step3_text.py | 16 ++------- .../models/transformers/causal.py | 3 +- vllm/model_executor/models/whisper.py | 6 ++-- vllm/model_executor/models/zamba2.py | 28 +++------------ 65 files changed, 197 insertions(+), 754 deletions(-) diff --git a/vllm/model_executor/models/apertus.py b/vllm/model_executor/models/apertus.py index 72e5ddcf1abeb..233b8c79f2992 100644 --- a/vllm/model_executor/models/apertus.py +++ b/vllm/model_executor/models/apertus.py @@ -49,7 +49,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -346,24 +345,18 @@ class ApertusModel(nn.Module): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size + if get_pp_group().is_first_rank or ( config.tie_word_embeddings and get_pp_group().is_last_rank ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, quant_config=quant_config, ) else: @@ -518,9 +511,7 @@ class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.model = self._init_model( vllm_config=vllm_config, @@ -529,20 +520,9 @@ class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ) if get_pp_group().is_last_rank: - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=( - DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size - ), quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -551,7 +531,7 @@ class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP): logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale + config.vocab_size, scale=logit_scale ) else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/arcee.py b/vllm/model_executor/models/arcee.py index 08bf1a6aad75b..f33970aff279c 100644 --- a/vllm/model_executor/models/arcee.py +++ b/vllm/model_executor/models/arcee.py @@ -23,7 +23,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -200,7 +199,6 @@ class ArceeModel(nn.Module): self.quant_config = quant_config self.config = config self.vocab_size = config.vocab_size - self.org_vocab_size = config.vocab_size # Word embeddings (parallelized if using pipeline parallel) if get_pp_group().is_first_rank or ( @@ -209,7 +207,6 @@ class ArceeModel(nn.Module): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, quant_config=quant_config, ) else: @@ -383,13 +380,10 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): if get_pp_group().is_last_rank: # Determine vocabulary size (including any LoRA extra tokens # for padded LM head) - self.unpadded_vocab_size = config.vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, quant_config=vllm_config.quant_config, bias=getattr(config, "lm_head_bias", False), prefix=f"{prefix}.lm_head", @@ -399,7 +393,7 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale + config.vocab_size, scale=logit_scale ) else: # Placeholder for lm_head on non-last ranks diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index bb505219ea17c..ae3b96c83509d 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -490,10 +490,8 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant): self.lm_head.weight = self.model.embed_tokens.weight self.num_experts = config.num_local_experts self.num_experts_per_tok = config.num_experts_per_tok - self.unpadded_vocab_size = config.vocab_size - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 222a425790543..fe37487d6ed88 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -547,18 +547,14 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): self.pad_token_id = ( self.config.pad_token_id if self.config.pad_token_id is not None else -1 ) - self.unpadded_vocab_size = config.text_config.vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + self.vocab_size, config.text_config.hidden_size, - org_num_embeddings=self.language_model.org_vocab_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, self.vocab_size, logit_scale - ) + self.logits_processor = LogitsProcessor(self.vocab_size, scale=logit_scale) def _parse_and_validate_image_input( self, **kwargs: object diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 39990b9fd6837..dac012eb9f829 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -402,9 +402,9 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config + self.config = config - self.lora_config = lora_config + self.tp_size = get_tensor_model_parallel_world_size() self.quant_config = quant_config self.model = BaiChuanModel( diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py index 1549c653482f6..641bdb69c366c 100644 --- a/vllm/model_executor/models/bailing_moe.py +++ b/vllm/model_executor/models/bailing_moe.py @@ -581,10 +581,8 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): config = vllm_config.model_config.hf_config.get_text_config() vllm_config.model_config.hf_config = config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.quant_config = quant_config self.max_position_embeddings = config.max_position_embeddings self.model = BailingMoeModel( diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index bc7dbb618f65c..4a2b3da1c194d 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -30,7 +30,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -284,21 +283,14 @@ class BambaModel(nn.Module): model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) def get_layer(prefix: str): @@ -478,7 +470,7 @@ class BambaForCausalLM( config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config self.quant_config = vllm_config.quant_config @@ -488,24 +480,14 @@ class BambaForCausalLM( self.model = BambaModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 54ff6991fa702..64f73e938bf6c 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -963,9 +963,9 @@ class ChameleonForConditionalGeneration( self.model = ChameleonModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -973,9 +973,7 @@ class ChameleonForConditionalGeneration( self.lm_head.weight = self.model.embed_tokens.weight logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale - ) + self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index bcbe82b78c3b1..ccf7c93001664 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -433,10 +433,9 @@ class ChatGLMBaseModel(nn.Module): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config + multimodal_config = vllm_config.model_config.multimodal_config self.config = config - self.lora_config = lora_config self.multimodal_config = multimodal_config self.quant_config = quant_config diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 75459601f76b0..6ae1dc3560827 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -288,17 +288,12 @@ class CohereModel(nn.Module): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.quant_config = quant_config self.config = config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size ) @@ -424,17 +419,15 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config + self.config = config # currently all existing command R models have `tie_word_embeddings` # enabled assert config.tie_word_embeddings - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.quant_config = quant_config self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, scale=config.logit_scale + config.vocab_size, scale=config.logit_scale ) self.model = CohereModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 22095d05848ce..70999501f4c69 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -25,7 +25,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -441,21 +440,17 @@ class DbrxForCausalLM(nn.Module, SupportsPP): if config.tie_word_embeddings: raise ValueError("tie_word_embeddings is not supported for Dbrx models.") self.quant_config = quant_config - self.unpadded_vocab_size = config.vocab_size + self.transformer = DbrxModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer") ) self.lm_head = ParallelLMHead( config.vocab_size, config.d_model, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors ) diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 84fb52d138545..b9c7a520caffb 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -48,7 +48,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -323,16 +322,11 @@ class ExaoneModel(nn.Module): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab + + self.vocab_size = config.vocab_size self.wte = config.vocab_size if get_pp_group().is_first_rank or ( config.tie_word_embeddings and get_pp_group().is_last_rank @@ -340,7 +334,6 @@ class ExaoneModel(nn.Module): self.wte = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, quant_config=quant_config, ) else: @@ -489,10 +482,9 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config + self.quant_config = quant_config self.transformer = ExaoneModel( @@ -500,18 +492,9 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): prefix=maybe_prefix(prefix, "model"), ) if get_pp_group().is_last_rank: - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -520,7 +503,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale + config.vocab_size, scale=logit_scale ) else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/exaone4.py b/vllm/model_executor/models/exaone4.py index d5e4d9a1486f7..6a5c888c095ae 100644 --- a/vllm/model_executor/models/exaone4.py +++ b/vllm/model_executor/models/exaone4.py @@ -44,7 +44,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -311,23 +310,17 @@ class Exaone4Model(nn.Module): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab + + self.vocab_size = config.vocab_size if get_pp_group().is_first_rank or ( config.tie_word_embeddings and get_pp_group().is_last_rank ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, quant_config=quant_config, ) else: @@ -476,10 +469,8 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.quant_config = quant_config self.model = Exaone4Model( @@ -487,18 +478,9 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): prefix=maybe_prefix(prefix, "model"), ) if get_pp_group().is_last_rank: - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -507,7 +489,7 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale + config.vocab_size, scale=logit_scale ) else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index ac5846cfd8695..38838be29093e 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -30,7 +30,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -424,21 +423,15 @@ class FalconH1Model(nn.Module): model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size + if get_pp_group().is_first_rank: self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) self.embedding_multiplier = config.embedding_multiplier else: @@ -572,7 +565,7 @@ class FalconH1ForCausalLM( config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config self.quant_config = vllm_config.quant_config @@ -584,21 +577,11 @@ class FalconH1ForCausalLM( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) self.tie_word_embeddings = config.tie_word_embeddings - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=( - DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size - ), prefix=maybe_prefix(prefix, "lm_head"), ) self.lm_head_multiplier = config.lm_head_multiplier @@ -607,7 +590,7 @@ class FalconH1ForCausalLM( # Used to track and store by the Mamba cache between steps. self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, + config.vocab_size, config.vocab_size, scale=config.lm_head_multiplier, ) diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 46b111f4d9396..caeee7c2e1ecc 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -382,12 +382,10 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config # currently all existing Gemma models have `tie_word_embeddings` enabled assert config.tie_word_embeddings - self.lora_config = lora_config self.quant_config = quant_config self.model = GemmaModel( diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 1938efd4895e5..efd01535fc3ef 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -393,8 +393,7 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - del lora_config # Unused. + super().__init__() self.config = config # currently all existing Gemma models have `tie_word_embeddings` enabled diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index 80ec40f478c6d..213f9f562f8a0 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -524,8 +524,7 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - del lora_config # Unused. + super().__init__() self.config = config # currently all existing Gemma models have `tie_word_embeddings` enabled diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py index 547884f393eb0..22d51ab762692 100644 --- a/vllm/model_executor/models/gemma3n.py +++ b/vllm/model_executor/models/gemma3n.py @@ -1114,8 +1114,7 @@ class Gemma3nForCausalLM(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config - lora_config = vllm_config.lora_config - del lora_config # Unused. + super().__init__() self.config = config self.cache_config = vllm_config.cache_config diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py index d7fd2b109d24f..4172f16737c18 100644 --- a/vllm/model_executor/models/glm4.py +++ b/vllm/model_executor/models/glm4.py @@ -248,10 +248,8 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.quant_config = quant_config self.model = Glm4Model( diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index f2c8e2aeb8225..99cdaabb98dfe 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -207,18 +207,13 @@ class GPTBigCodeModel(nn.Module): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config assert not config.add_cross_attention self.embed_dim = config.hidden_size - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab + + self.vocab_size = config.vocab_size self.wte = VocabParallelEmbedding( self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size ) @@ -290,10 +285,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.quant_config = quant_config self.transformer = GPTBigCodeModel( @@ -305,15 +298,10 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.lm_head = ParallelLMHead( self.transformer.vocab_size, self.transformer.embed_dim, - org_num_embeddings=self.config.vocab_size, prefix=maybe_prefix(prefix, "lm_head"), ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors ) diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index e683f30805f37..c5b36c362ee32 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -50,7 +50,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -296,22 +295,15 @@ class GraniteMoeModel(nn.Module): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config # Required by MixtralModel - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) self.embedding_multiplier = config.embedding_multiplier @@ -518,26 +510,16 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.model = GraniteMoeModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -545,7 +527,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, scale=1 / self.config.logits_scaling, ) diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index bac64eec8c558..3a98abed76fdf 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -25,7 +25,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -334,22 +333,15 @@ class GraniteMoeHybridModel(nn.Module): model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) self.embedding_multiplier = config.embedding_multiplier @@ -658,7 +650,7 @@ class GraniteMoeHybridForCausalLM( config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config self.quant_config = vllm_config.quant_config self.config = config @@ -666,26 +658,17 @@ class GraniteMoeHybridForCausalLM( self.model = GraniteMoeHybridModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, quant_config=self.quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, + config.vocab_size, config.vocab_size, scale=1 / self.config.logits_scaling, ) diff --git a/vllm/model_executor/models/granitemoeshared.py b/vllm/model_executor/models/granitemoeshared.py index e222109f2a949..e08e9f73ec879 100644 --- a/vllm/model_executor/models/granitemoeshared.py +++ b/vllm/model_executor/models/granitemoeshared.py @@ -25,7 +25,6 @@ from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -159,23 +158,16 @@ class GraniteMoeSharedModel(nn.Module): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config # Required by MixtralModel self.padding_idx = config.pad_token_id - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, quant_config=quant_config, ) self.embedding_multiplier = config.embedding_multiplier @@ -281,26 +273,16 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.model = GraniteMoeSharedModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -308,7 +290,7 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, + config.vocab_size, config.vocab_size, scale=1 / self.config.logits_scaling, ) diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py index d77a0bc2993a0..0770e03b5356e 100644 --- a/vllm/model_executor/models/grok1.py +++ b/vllm/model_executor/models/grok1.py @@ -45,7 +45,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -305,18 +304,13 @@ class Grok1Model(nn.Module): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config self.padding_idx = config.pad_token_id - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size + self.embedding_multiplier_scale = getattr( config, "embedding_multiplier_scale", DEFAULT_EMBEDDING_MULTIPLIER_SCALE ) @@ -324,7 +318,6 @@ class Grok1Model(nn.Module): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, quant_config=quant_config, ) @@ -499,25 +492,18 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config + self.quant_config = quant_config self.model = Grok1Model( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -529,7 +515,7 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): config, "output_multiplier_scale", DEFAULT_OUTPUT_MULTIPLIER_SCALE ) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, self.output_multiplier_scale + config.vocab_size, scale=self.output_multiplier_scale ) self.make_empty_intermediate_tensors = ( diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py index 8fa9776bd0186..a05a00932c13b 100644 --- a/vllm/model_executor/models/hunyuan_v1.py +++ b/vllm/model_executor/models/hunyuan_v1.py @@ -57,7 +57,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -606,7 +605,7 @@ class HunYuanModel(nn.Module): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config + eplb_config = vllm_config.parallel_config.eplb_config enable_eplb = vllm_config.parallel_config.enable_eplb self.num_redundant_experts = eplb_config.num_redundant_experts @@ -614,20 +613,15 @@ class HunYuanModel(nn.Module): self.config = config self.quant_config = quant_config self.padding_idx = config.pad_token_id - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size + if get_pp_group().is_first_rank or ( config.tie_word_embeddings and get_pp_group().is_last_rank ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, quant_config=quant_config, ) else: @@ -937,12 +931,9 @@ class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP): self.model = HunYuanModel(vllm_config=vllm_config, prefix="model") if get_pp_group().is_last_rank: - self.unpadded_vocab_size = config.vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -951,7 +942,7 @@ class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP): logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale + config.vocab_size, scale=logit_scale ) else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index c5bbd5497a146..d856f5c79e33d 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -330,11 +330,9 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config - self.lora_config = lora_config self.model = model_type( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 0cb993901fd38..70f52e3106f81 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -30,7 +30,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -307,21 +306,14 @@ class JambaModel(nn.Module): model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)} @@ -492,7 +484,7 @@ class JambaForCausalLM( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config - lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config super().__init__() @@ -503,24 +495,14 @@ class JambaForCausalLM( self.model = JambaModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py index b79bdf8595ca9..fa04f60b9c140 100644 --- a/vllm/model_executor/models/kimi_vl.py +++ b/vllm/model_executor/models/kimi_vl.py @@ -60,7 +60,6 @@ from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, ) from vllm.model_executor.model_loader.weight_utils import ( @@ -347,13 +346,10 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): vllm_config=sub_vllm_config, prefix=maybe_prefix(prefix, "language_model"), ) - self.unpadded_vocab_size = config.text_config.vocab_size if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.text_config.hidden_size, - org_num_embeddings=self.config.text_config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, prefix=maybe_prefix(prefix, "lm_head"), ) else: @@ -362,9 +358,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): self.language_model.make_empty_intermediate_tensors ) logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale - ) + self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale) self.media_placeholder: int = self.config.media_placeholder_token_id def _parse_and_validate_image_input( diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py index 5684b9a891257..21d71887178e7 100644 --- a/vllm/model_executor/models/lfm2.py +++ b/vllm/model_executor/models/lfm2.py @@ -28,7 +28,6 @@ from vllm.model_executor.layers.mamba.short_conv import ShortConv from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -316,16 +315,10 @@ class Lfm2Model(nn.Module): model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size @@ -483,7 +476,7 @@ class Lfm2ForCausalLM( config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config cache_config = vllm_config.cache_config - lora_config = vllm_config.lora_config + assert not cache_config.enable_prefix_caching, ( "Lfm2 currently does not support prefix caching" ) @@ -495,21 +488,9 @@ class Lfm2ForCausalLM( ) if get_pp_group().is_last_rank: - self.unpadded_vocab_size = self.config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=( - DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size - ), quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -517,9 +498,7 @@ class Lfm2ForCausalLM( else: self.lm_head = PPMissingLayer() - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors diff --git a/vllm/model_executor/models/lfm2_moe.py b/vllm/model_executor/models/lfm2_moe.py index 02a490e9c7fd9..b191164671050 100644 --- a/vllm/model_executor/models/lfm2_moe.py +++ b/vllm/model_executor/models/lfm2_moe.py @@ -33,7 +33,6 @@ from vllm.model_executor.layers.mamba.short_conv import ShortConv from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -423,20 +422,15 @@ class Lfm2MoeModel(nn.Module): model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config + parallel_config = vllm_config.parallel_config enable_eplb = parallel_config.enable_eplb eplb_config = parallel_config.eplb_config self.num_redundant_experts = eplb_config.num_redundant_experts self.config = config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size @@ -662,7 +656,7 @@ class Lfm2MoeForCausalLM( config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config cache_config = vllm_config.cache_config - lora_config = vllm_config.lora_config + assert not cache_config.enable_prefix_caching, ( "Lfm2Moe currently does not support prefix caching" ) @@ -674,21 +668,9 @@ class Lfm2MoeForCausalLM( ) if get_pp_group().is_last_rank: - self.unpadded_vocab_size = self.config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=( - DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size - ), quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -696,9 +678,7 @@ class Lfm2MoeForCausalLM( else: self.lm_head = PPMissingLayer() - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index da4bbda186b17..b8b9cc76d08d2 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -15,7 +15,6 @@ from vllm.model_executor.layers.linear import QKVParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -252,8 +251,6 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): self.lm_head = ParallelLMHead( self.config.draft_vocab_size, self.config.hidden_size, - org_num_embeddings=self.config.draft_vocab_size, - padding_size=(DEFAULT_VOCAB_PADDING_SIZE), prefix=maybe_prefix(prefix, "lm_head"), ) self.logits_processor = LogitsProcessor( diff --git a/vllm/model_executor/models/longcat_flash.py b/vllm/model_executor/models/longcat_flash.py index 5671347c00a23..b848ae6e822f1 100644 --- a/vllm/model_executor/models/longcat_flash.py +++ b/vllm/model_executor/models/longcat_flash.py @@ -554,7 +554,6 @@ class LongcatFlashForCausalLM(nn.Module, SupportsLoRA, SupportsPP): super().__init__() config = FlashConfig(**vllm_config.model_config.hf_config.__dict__) quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config config.intermediate_size = ( @@ -562,7 +561,7 @@ class LongcatFlashForCausalLM(nn.Module, SupportsLoRA, SupportsPP): if hasattr(config, "ffn_hidden_size") else config.intermediate_size ) - self.lora_config = lora_config + self.quant_config = quant_config self.model = FlashModel( diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index f684203f6d35e..02abe693e071d 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -21,7 +21,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -110,18 +109,12 @@ class MambaModel(nn.Module): is_lora_enabled = bool(lora_config) self.config = config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size self.embeddings = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) self.start_layer, self.end_layer, self.layers = make_layers( @@ -199,7 +192,7 @@ class MambaForCausalLM( ): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config - lora_config = vllm_config.lora_config + self.scheduler_config = vllm_config.scheduler_config super().__init__() @@ -209,27 +202,17 @@ class MambaForCausalLM( self.backbone = MambaModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "backbone") ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + if config.tie_word_embeddings: self.lm_head = self.backbone.embeddings else: self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.backbone.make_empty_intermediate_tensors diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index 8ba8af66635b3..d19480b064e05 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -20,7 +20,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -107,18 +106,12 @@ class Mamba2Model(nn.Module): assert not is_lora_enabled self.config = config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size self.embeddings = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) self.start_layer, self.end_layer, self.layers = make_layers( @@ -238,7 +231,7 @@ class Mamba2ForCausalLM( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config - lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config super().__init__() @@ -249,27 +242,16 @@ class Mamba2ForCausalLM( self.backbone = Mamba2Model( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "backbone") ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.backbone.make_empty_intermediate_tensors diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py index 7e1d2bf14bb5c..fd7fc2c73f16e 100644 --- a/vllm/model_executor/models/medusa.py +++ b/vllm/model_executor/models/medusa.py @@ -9,7 +9,6 @@ import torch.nn as nn from vllm.config import VllmConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -70,14 +69,11 @@ class Medusa(nn.Module): ) self.orig_vocab_size = config.vocab_size self.truncated_vocab_size = config.truncated_vocab_size - self.unpadded_vocab_size = self.truncated_vocab_size if getattr(config, "original_lm_head", False): self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + self.truncated_vocab_size, config.hidden_size, - org_num_embeddings=self.truncated_vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, prefix=maybe_prefix(prefix, "lm_head"), ) self.lm_heads = [self.lm_head for _ in range(self.config.num_heads)] @@ -85,10 +81,8 @@ class Medusa(nn.Module): self.lm_heads = nn.ModuleList( [ ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=self.truncated_vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, prefix=maybe_prefix(prefix, f"lm_heads.{i}"), ) for i in range(self.config.num_heads) @@ -97,7 +91,7 @@ class Medusa(nn.Module): logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, self.truncated_vocab_size, logit_scale + config.vocab_size, self.truncated_vocab_size, logit_scale ) # Token map is a idx to token mapping to reduce the vocab size for diff --git a/vllm/model_executor/models/mimo.py b/vllm/model_executor/models/mimo.py index 726752a77e0dc..666ac90c44293 100644 --- a/vllm/model_executor/models/mimo.py +++ b/vllm/model_executor/models/mimo.py @@ -151,10 +151,8 @@ class MiMoForCausalLM(Qwen2ForCausalLM, nn.Module): nn.Module.__init__(self) config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.quant_config = quant_config diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 85d3542317a1d..d9f0b477180e4 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -55,7 +55,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -405,22 +404,16 @@ class MiniCPMModel(nn.Module): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.cache_config = cache_config self.quant_config = quant_config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) self.num_experts = getattr(self.config, "num_experts", 0) self._init_layers(prefix, config, cache_config, quant_config) @@ -588,13 +581,13 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config + parallel_config = vllm_config.parallel_config self.prefix = prefix self.vllm_config = vllm_config self.config = config - self.lora_config = lora_config + self.cache_config = cache_config self.quant_config = quant_config @@ -602,18 +595,9 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - unpadded_vocab_size = config.vocab_size - if lora_config: - unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -621,7 +605,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) self.scale_width = self.config.hidden_size / self.config.dim_model_base - self.logits_processor = LogitsProcessor(unpadded_vocab_size, config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) diff --git a/vllm/model_executor/models/minicpm_eagle.py b/vllm/model_executor/models/minicpm_eagle.py index 463af9bbe1399..6efc61e25ea1b 100644 --- a/vllm/model_executor/models/minicpm_eagle.py +++ b/vllm/model_executor/models/minicpm_eagle.py @@ -37,7 +37,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -151,18 +150,13 @@ class EagleMiniCPMModel(nn.Module): config = vllm_config.speculative_config.draft_model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.cache_config = cache_config self.quant_config = quant_config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size + self.fc = torch.nn.Linear( self.config.hidden_size * 2, self.config.hidden_size, bias=False ) @@ -171,7 +165,6 @@ class EagleMiniCPMModel(nn.Module): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) self.num_experts = getattr(self.config, "num_experts", 0) self._init_layers(prefix, config, cache_config, quant_config, start_layer) @@ -321,12 +314,11 @@ class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): config = vllm_config.speculative_config.draft_model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.prefix = prefix self.vllm_config = vllm_config self.config = config - self.lora_config = lora_config + self.cache_config = cache_config self.quant_config = quant_config @@ -340,18 +332,9 @@ class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): start_layer=target_layer_num, ) - unpadded_vocab_size = config.vocab_size - if lora_config: - unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -359,7 +342,7 @@ class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) self.scale_width = self.config.hidden_size / self.config.dim_model_base - self.logits_processor = LogitsProcessor(unpadded_vocab_size, config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index e262012dcd526..1409a309f3aeb 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -41,7 +41,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -669,16 +668,14 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config - lora_config = vllm_config.lora_config + self.config = config - self.lora_config = lora_config if not hasattr(config, "sliding_window"): config.sliding_window = None self.CONCAT_FFN = True - self.unpadded_vocab_size = self.config.vocab_size if hasattr(vllm_config.model_config, "max_model_len"): self.config.max_model_len = vllm_config.model_config.max_model_len self.model = MiniMaxText01Model( @@ -686,15 +683,13 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): ) if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, self.config.hidden_size, - org_num_embeddings=self.config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, prefix=maybe_prefix(prefix, "lm_head"), ) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, self.config.vocab_size + config.vocab_size, self.config.vocab_size ) else: diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py index 4901ac74fb28b..48604d8e51031 100644 --- a/vllm/model_executor/models/mlp_speculator.py +++ b/vllm/model_executor/models/mlp_speculator.py @@ -123,7 +123,6 @@ class MLPSpeculator(nn.Module): VocabParallelEmbedding( config.vocab_size, self.inner_dim, - org_num_embeddings=config.vocab_size, ) for _ in range(self.max_speculative_tokens) ] diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index dce94d181c4cd..7a9e3d81b73a1 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -1404,10 +1404,9 @@ class MolmoForCausalLM( config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config - lora_config = vllm_config.lora_config + self.config = config self.multimodal_config = multimodal_config - self.lora_config = lora_config vision_config = VisionBackboneConfig() self.vision_backbone = MolmoVisionBackbone(config, vision_config, quant_config) diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index 845798b18d1b3..17e8e7f28258d 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -45,7 +45,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -319,24 +318,18 @@ class NemotronModel(nn.Module): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size + if get_pp_group().is_first_rank or ( config.tie_word_embeddings and get_pp_group().is_last_rank ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) else: self.embed_tokens = PPMissingLayer() @@ -467,29 +460,20 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config + assert isinstance(config, NemotronConfig) self.config = config - self.lora_config = lora_config + self.quant_config = quant_config self.model = NemotronModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) if get_pp_group().is_last_rank: - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -498,7 +482,7 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale + config.vocab_size, scale=logit_scale ) else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index fb58d01be7ba1..8ef3eee173eb2 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -50,7 +50,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -513,21 +512,14 @@ class NemotronHModel(nn.Module): cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config parallel_config = vllm_config.parallel_config - lora_config = vllm_config.lora_config self.config = config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) self.has_moe = "E" in config.hybrid_override_pattern @@ -768,7 +760,7 @@ class NemotronHForCausalLM( config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config self.quant_config = vllm_config.quant_config @@ -779,24 +771,14 @@ class NemotronHForCausalLM( self.model = NemotronHModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors diff --git a/vllm/model_executor/models/nemotron_nas.py b/vllm/model_executor/models/nemotron_nas.py index 17e009612df43..acd0d0c982348 100644 --- a/vllm/model_executor/models/nemotron_nas.py +++ b/vllm/model_executor/models/nemotron_nas.py @@ -41,7 +41,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -250,25 +249,19 @@ class DeciModel(nn.Module): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config self.padding_idx = config.pad_token_id - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size + if get_pp_group().is_first_rank or ( config.tie_word_embeddings and get_pp_group().is_last_rank ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, quant_config=quant_config, ) else: @@ -437,29 +430,17 @@ class DeciLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, HasNoOps): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config + self.config = config - self.lora_config = lora_config self.model = self._init_model( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) if get_pp_group().is_last_rank: - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=( - DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size - ), quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -468,7 +449,7 @@ class DeciLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, HasNoOps): logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale + config.vocab_size, scale=logit_scale ) else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 390a91d3425ce..cb47f76a27ff5 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -368,11 +368,9 @@ class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA): if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: - self.unpadded_vocab_size = config.vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index 7e39f6dff25e7..2aa01adebc9f1 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -408,11 +408,9 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: - self.unpadded_vocab_size = config.vocab_size self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, quant_config=vllm_config.quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) diff --git a/vllm/model_executor/models/ouro.py b/vllm/model_executor/models/ouro.py index b8dad909c5470..cc7947df50aea 100644 --- a/vllm/model_executor/models/ouro.py +++ b/vllm/model_executor/models/ouro.py @@ -462,10 +462,8 @@ class OuroForCausalLM(nn.Module, SupportsLoRA): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.quant_config = quant_config self.model = OuroModel( diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 34db124b6447c..e76fb1904727c 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -323,11 +323,10 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config + self.config = config # lm_head use bias, cannot share word embeddings assert not config.tie_word_embeddings - self.lora_config = lora_config self.quant_config = quant_config diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index b86fe67fb4768..a7b28bd18cc7a 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -591,7 +591,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant) self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, quant_config=self.quant_config, prefix=maybe_prefix(prefix, "model.embed_tokens"), ) diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index acad72b058fcd..c2a3be16b6107 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -21,7 +21,6 @@ from vllm.distributed import get_pp_group from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, ) from vllm.model_executor.models.llama import LlamaModel @@ -1023,12 +1022,10 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): multimodal_config = vllm_config.model_config.multimodal_config assert multimodal_config, "multimodal_config is required" quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.multimodal_config = multimodal_config self.quant_config = quant_config - self.lora_config = lora_config # Tensor/Pipeline parallel not supported for now. assert get_pp_group().world_size == 1, "pipeline parallel is not supported" @@ -1055,23 +1052,16 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale - ) + self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale) def _parse_and_validate_audio_input( self, **kwargs: object diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index c7436cedeb229..97e5537877908 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -45,7 +45,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -458,22 +457,15 @@ class PhiMoEModel(nn.Module): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + self.vocab_size = config.vocab_size + self.config = config self.quant_config = quant_config self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, @@ -634,35 +626,23 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - lora_config = vllm_config.lora_config + self.config = config - self.lora_config = lora_config + self.quant_config = vllm_config.quant_config self.model = PhiMoEModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=( - DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size - ), quant_config=None, bias=True, prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 6427ccfccc134..ece1c5ec23cff 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -46,7 +46,6 @@ from vllm.model_executor.layers.mamba.ops.ssd_combined import ( from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -751,12 +750,10 @@ class Plamo2Model(torch.nn.Module): self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, prefix=f"{prefix}.embed_tokens", ) self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( @@ -827,20 +824,16 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid): vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) self.vocab_size = self.config.vocab_size - self.unpadded_vocab_size = self.config.vocab_size - num_embeddings = ((self.vocab_size + 15) // 16) * 16 self.lm_head = ParallelLMHead( - num_embeddings, + self.vocab_size, self.config.hidden_size, - org_num_embeddings=self.config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, prefix=f"{prefix}.lm_head", ) if self.config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, self.config.vocab_size + config.vocab_size, self.config.vocab_size ) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index b26546647ce76..cdf32c6c51373 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -477,10 +477,8 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.quant_config = quant_config self.model = Qwen2Model( diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index e2ba0e262cf79..c5582218b852a 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -43,10 +43,8 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.quant_config = quant_config self.model = Qwen2Model( diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index 563d3cc23d726..f689ff79d7617 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -272,10 +272,8 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.quant_config = quant_config self.model = Qwen3Model( diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index ddb8693c16e23..9cd342caacb06 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -59,7 +59,6 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -967,22 +966,17 @@ class Qwen3NextModel(nn.Module): config: Qwen3NextConfig = vllm_config.model_config.hf_config parallel_config = vllm_config.parallel_config - lora_config = vllm_config.lora_config + eplb_config = parallel_config.eplb_config self.num_redundant_experts = eplb_config.num_redundant_experts self.config = config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab + + self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) def get_layer(prefix: str): @@ -1196,7 +1190,7 @@ class Qwen3NextForCausalLM( self.vllm_config = vllm_config self.model_config = vllm_config.model_config cache_config = vllm_config.cache_config - lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config assert not cache_config.enable_prefix_caching, ( "Qwen3Next currently does not support prefix caching" @@ -1209,23 +1203,13 @@ class Qwen3NextForCausalLM( self.model = Qwen3NextModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) diff --git a/vllm/model_executor/models/qwen3_next_mtp.py b/vllm/model_executor/models/qwen3_next_mtp.py index 271b76adcff7e..9a552db029ee9 100644 --- a/vllm/model_executor/models/qwen3_next_mtp.py +++ b/vllm/model_executor/models/qwen3_next_mtp.py @@ -15,7 +15,6 @@ from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -48,17 +47,12 @@ class Qwen3NextMultiTokenPredictor(nn.Module): model_config = vllm_config.model_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config + config: Qwen3NextConfig = model_config.hf_config self.config = config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size self.mtp_start_layer_idx = config.num_hidden_layers self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1) @@ -66,7 +60,6 @@ class Qwen3NextMultiTokenPredictor(nn.Module): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) self.fc = ColumnParallelLinear( @@ -252,17 +245,13 @@ class Qwen3NextMTP(nn.Module, SupportsPP, QwenNextMixtureOfExperts): self.model = Qwen3NextMultiTokenPredictor( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "mtp") ) - self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 97d4667d82e99..d880e6015e5d6 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1136,10 +1136,8 @@ class Qwen3LLMForCausalLM(Qwen3ForCausalLM): super(Qwen3ForCausalLM, self).__init__() config = vllm_config.model_config.hf_config.text_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.quant_config = quant_config self.model = Qwen3LLMModel(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/seed_oss.py b/vllm/model_executor/models/seed_oss.py index 641160295afb3..04da19a440a16 100644 --- a/vllm/model_executor/models/seed_oss.py +++ b/vllm/model_executor/models/seed_oss.py @@ -440,10 +440,8 @@ class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.quant_config = quant_config self.model = SeedOssModel( diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index f0dfce7bc7b64..5b8bf150edf6d 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -46,7 +46,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -277,24 +276,18 @@ class SolarModel(nn.Module): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size + if get_pp_group().is_first_rank or ( config.tie_word_embeddings and get_pp_group().is_last_rank ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) else: self.embed_tokens = PPMissingLayer() @@ -455,9 +448,9 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config + self.config = config - self.lora_config = lora_config + self.quant_config = quant_config self.model = SolarModel( @@ -465,18 +458,9 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): prefix=maybe_prefix(prefix, "model"), ) if get_pp_group().is_last_rank: - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -485,7 +469,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale + config.vocab_size, scale=logit_scale ) else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index d147237808c2a..4cdc90b1f5cb9 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -42,7 +42,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -319,22 +318,17 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP): vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) self.vocab_size = config.vocab_size - self.unpadded_vocab_size = config.vocab_size + if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: - self.unpadded_vocab_size = config.vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, quant_config=quant_config, prefix=f"{prefix}.lm_head", ) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index a2a1bfd30d8d8..381b3f4932e55 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -31,7 +31,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -400,28 +399,19 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): ): super().__init__() config = vllm_config.model_config.hf_config - lora_config = vllm_config.lora_config + self.config = config self.vllm_config = vllm_config self.model = Step3TextModel(vllm_config=vllm_config, prefix=prefix) if get_pp_group().is_last_rank: - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - if not lora_config - else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/transformers/causal.py b/vllm/model_executor/models/transformers/causal.py index 7f7b15a5675a3..42fd11117c737 100644 --- a/vllm/model_executor/models/transformers/causal.py +++ b/vllm/model_executor/models/transformers/causal.py @@ -42,7 +42,6 @@ class CausalMixin(VllmModelForTextGeneration): self.skip_prefixes.append("lm_head.") if self.pp_group.is_last_rank: - self.unpadded_vocab_size = self.text_config.vocab_size self.lm_head = ParallelLMHead( self.text_config.vocab_size, self.text_config.hidden_size, @@ -56,7 +55,7 @@ class CausalMixin(VllmModelForTextGeneration): logit_scale = getattr(self.text_config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, self.text_config.vocab_size, logit_scale + self.text_config.vocab_size, scale=logit_scale ) else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index ccfe1871ef075..502783b1fd932 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -890,7 +890,7 @@ class WhisperForConditionalGeneration( self.dtype = vllm_config.model_config.dtype self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix) - self.unpadded_vocab_size = config.vocab_size + self.proj_out = ParallelLMHead( config.vocab_size, config.d_model, @@ -899,9 +899,7 @@ class WhisperForConditionalGeneration( ) self.proj_out = self.proj_out.tie_weights(self.model.decoder.embed_tokens) logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale - ) + self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale) def forward( self, diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index bc1351600a2f4..bf3107525bc53 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -38,7 +38,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -692,19 +691,13 @@ class Zamba2Model(nn.Module): assert not is_lora_enabled self.config = config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size # Initialize token embeddings self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) # Map hybrid layer indices to block indices @@ -911,7 +904,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC (not supported by Mamba) """ config = vllm_config.model_config.hf_config - lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config super().__init__() @@ -919,9 +912,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC self.vllm_config = vllm_config self.scheduler_config = scheduler_config self.model_config = vllm_config.model_config - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size # Initialize core model self.model = Zamba2Model( @@ -930,23 +920,15 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC # Initialize language modeling head self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) # Tie weights with input embeddings if using same dimensions self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) # Initialize logits processing and sampling - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: """Convert input token IDs to embeddings. From df4d3a44a83681feea723cc4c4ebe9085d29d58d Mon Sep 17 00:00:00 2001 From: Kyuyeun Kim <62023335+kyuyeunk@users.noreply.github.com> Date: Tue, 11 Nov 2025 11:16:47 -0800 Subject: [PATCH 34/98] [TPU] Rename path to tpu platform (#28452) Signed-off-by: Kyuyeun Kim --- vllm/platforms/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index badf72de4a90f..a45ca988200d2 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -38,7 +38,7 @@ def tpu_platform_plugin() -> str | None: # Check for Pathways TPU proxy if envs.VLLM_TPU_USING_PATHWAYS: logger.debug("Confirmed TPU platform is available via Pathways proxy.") - return "tpu_inference.platforms.tpu_jax.TpuPlatform" + return "tpu_inference.platforms.tpu_platform.TpuPlatform" # Check for libtpu installation try: From d4902ba56d9b265698fb53f2d956117454945371 Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Wed, 12 Nov 2025 06:28:07 +0800 Subject: [PATCH 35/98] [Misc] Cleanup Executor interface (#28441) Signed-off-by: wangxiyuan --- vllm/v1/executor/abstract.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 1e913876b7635..db8303fcec501 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -294,12 +294,6 @@ class Executor(ABC): """Reset the multi-modal cache in each worker.""" self.collective_rpc("reset_mm_cache") - def start_profile(self) -> None: - self.collective_rpc("start_profile") - - def stop_profile(self) -> None: - self.collective_rpc("stop_profile") - def sleep(self, level: int = 1): if self.is_sleeping: logger.warning("Executor is already sleeping.") From 28534b92b9f002e56d4e31d02ca59a070cdad468 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 11 Nov 2025 17:53:59 -0500 Subject: [PATCH 36/98] Add Zurich vLLM Meetup (#28488) Signed-off-by: mgoin --- README.md | 1 + docs/community/meetups.md | 1 + 2 files changed, 2 insertions(+) diff --git a/README.md b/README.md index b5e230e4b9b07..033e1035d8916 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundatio *Latest News* 🔥 +- [2025/11] We hosted [the first vLLM Europe Meetup in Zurich](https://luma.com/0gls27kb) focused on quantization, distributed inference, and reinforcement learning at scale with speakers from Mistral, IBM, and Red Hat. Please find the meetup slides [here](https://docs.google.com/presentation/d/1UC9PTLCHYXQpOmJDSFg6Sljra3iVXzc09DeEI7dnxMc/edit?usp=sharing) and recording [here](https://www.youtube.com/watch?v=6m6ZE6yVEDI) - [2025/11] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/xSrYXjNgr1HbCP4ExYNG1w) focusing on distributed inference and diverse accelerator support with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1nQJ8ZkLSjKxvu36sSHaceVXtttbLvvu-?usp=drive_link). - [2025/10] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/__xb4OyOsImz-9eAVrdlcg) focused on hands-on vLLM inference optimization! Please find the meetup slides [here](https://drive.google.com/drive/folders/1KqwjsFJLfEsC8wlDugnrR61zsWHt94Q6). - [2025/09] We hosted [vLLM Toronto Meetup](https://luma.com/e80e0ymm) focused on tackling inference at scale and speculative decoding with speakers from NVIDIA and Red Hat! Please find the meetup slides [here](https://docs.google.com/presentation/d/1IYJYmJcu9fLpID5N5RbW_vO0XLo0CGOR14IXOjB61V8/edit?usp=sharing). diff --git a/docs/community/meetups.md b/docs/community/meetups.md index 7ddd45799789c..3fca4659e284a 100644 --- a/docs/community/meetups.md +++ b/docs/community/meetups.md @@ -2,6 +2,7 @@ We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below: +- [vLLM Zurich Meetup](https://luma.com/0gls27kb), November 6th 2025. [[Slides]](https://docs.google.com/presentation/d/1UC9PTLCHYXQpOmJDSFg6Sljra3iVXzc09DeEI7dnxMc/edit?usp=sharing) [[Recording]](https://www.youtube.com/watch?v=6m6ZE6yVEDI) - [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/xSrYXjNgr1HbCP4ExYNG1w), November 1st 2025. [[Slides]](https://drive.google.com/drive/folders/1nQJ8ZkLSjKxvu36sSHaceVXtttbLvvu-?usp=drive_link) - [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/__xb4OyOsImz-9eAVrdlcg), October 25th 2025. [[Slides]](https://drive.google.com/drive/folders/1KqwjsFJLfEsC8wlDugnrR61zsWHt94Q6) - [vLLM Toronto Meetup](https://luma.com/e80e0ymm), September 25th 2025. [[Slides]](https://docs.google.com/presentation/d/1IYJYmJcu9fLpID5N5RbW_vO0XLo0CGOR14IXOjB61V8/edit?usp=sharing) From e5f599d4d1cfd34a5216cf0733d152ea42073f28 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 11 Nov 2025 18:16:12 -0500 Subject: [PATCH 37/98] [Bugfix] Disable shared expert overlap if Marlin MoE is used (#28410) Signed-off-by: mgoin --- vllm/model_executor/layers/fused_moe/layer.py | 4 ++++ .../layers/fused_moe/shared_fused_moe.py | 10 +++++----- vllm/model_executor/layers/quantization/awq_marlin.py | 1 + .../compressed_tensors/compressed_tensors_moe.py | 1 + vllm/model_executor/layers/quantization/gptq_marlin.py | 1 + vllm/model_executor/layers/quantization/mxfp4.py | 1 + 6 files changed, 13 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index e198322ba7a89..615da58eeda28 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -678,6 +678,10 @@ class FusedMoE(CustomOp): and self.moe_config.use_flashinfer_cutlass_kernels ) + @property + def use_marlin_kernels(self): + return getattr(self.quant_method, "use_marlin", False) + @property def use_dp_chunking(self) -> bool: return ( diff --git a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py index 06112ca51b6d5..6ec8b33ed9309 100644 --- a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py @@ -28,17 +28,17 @@ class SharedFusedMoE(FusedMoE): super().__init__(**kwargs) self._shared_experts = shared_experts - # Disable shared expert overlap if we are using eplb, because of - # correctness issues, or if using flashinfer with DP, since there - # is nothing to be gained in this case. Disabling the overlap - # optimization also prevents the shared experts from being hidden - # from torch.compile. + # Disable shared expert overlap if: + # - we are using eplb, because of correctness issues + # - we are using flashinfer with DP, since there nothint to gain + # - we are using marlin kjernels self.use_overlapped = ( use_overlapped and not ( # TODO(wentao): find the root cause and remove this condition self.enable_eplb or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1) + or self.use_marlin_kernels ) and self._shared_experts is not None ) diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 3e1f87b59a34d..3f6ea68072b40 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -424,6 +424,7 @@ class AWQMoEMethod(FusedMoEMethodBase): if self.quant_config.weight_bits != 4: raise ValueError("AWQMoEMethod only supports 4bit now.") self.quant_type = scalar_types.uint4 + self.use_marlin = True def create_weights( self, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 6257a410e9432..f1050c15f79e7 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -1342,6 +1342,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): f"{WNA16_SUPPORTED_BITS}", ) self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits] + self.use_marlin = True def create_weights( self, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 42a569e7770c0..68a122fd46c6b 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -482,6 +482,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): self.quant_type = scalar_types.uint8b128 else: raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.") + self.use_marlin = True def create_weights( self, diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 8d7297a0a1b3b..7940b359a150c 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -216,6 +216,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): def __init__(self, moe: FusedMoEConfig): super().__init__(moe) self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled) + self.use_marlin = self.mxfp4_backend == Mxfp4Backend.MARLIN self.max_capture_size = ( get_current_vllm_config().compilation_config.max_cudagraph_capture_size ) From 412e153df557bbae541363ac4abde879a6d84488 Mon Sep 17 00:00:00 2001 From: Max Hu Date: Tue, 11 Nov 2025 18:32:20 -0500 Subject: [PATCH 38/98] [Feature] Allow configuring FlashInfer workspace size (#28269) Signed-off-by: Max Hu Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- vllm/envs.py | 6 ++++++ vllm/v1/attention/backends/flashinfer.py | 6 +++--- vllm/v1/attention/backends/mla/common.py | 16 +++++++--------- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 52a9671bc46e2..5274c8ba1b24e 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -159,6 +159,7 @@ if TYPE_CHECKING: VLLM_USE_FLASHINFER_MOE_FP8: bool = False VLLM_USE_FLASHINFER_MOE_FP4: bool = False VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency"] = "latency" + VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024 VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False @@ -1237,6 +1238,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_FLASHINFER_MOE_BACKEND": env_with_choices( "VLLM_FLASHINFER_MOE_BACKEND", "latency", ["throughput", "latency"] ), + # Control the workspace buffer size for the FlashInfer backend. + "VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE": lambda: int( + os.getenv("VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE", str(394 * 1024 * 1024)) + ), # Control the maximum number of tokens per expert supported by the # NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for # the blockscale tensor of activations NVFP4 Quantization. @@ -1583,6 +1588,7 @@ def compute_hash() -> str: "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", + "VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE", "VLLM_USE_CUDNN_PREFILL", "VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL", "VLLM_USE_TRTLLM_ATTENTION", diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 07a0ab41a9e05..18bbc3cc3c12b 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -16,6 +16,7 @@ from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache from flashinfer.prefill import trtllm_batch_context_with_kv_cache from flashinfer.utils import FP4Tensor +from vllm import envs from vllm.attention.backends.abstract import ( AttentionBackend, AttentionImpl, @@ -55,7 +56,6 @@ from vllm.v1.attention.backends.utils import ( ) from vllm.v1.kv_cache_interface import AttentionSpec -FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024 FP8_DTYPE = current_platform.fp8_dtype() @@ -70,7 +70,7 @@ def _get_trtllm_gen_workspace_buffer(): global trtllm_gen_workspace_buffer if trtllm_gen_workspace_buffer is None: trtllm_gen_workspace_buffer = torch.zeros( - FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device="cuda" + envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device="cuda" ) return trtllm_gen_workspace_buffer @@ -414,7 +414,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): def _get_workspace_buffer(self): if self._workspace_buffer is None: - buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE + buffer_size = envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE if vllm_is_batch_invariant(): buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT self._workspace_buffer = torch.zeros( diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 19bd102cb1e30..467c01cd9d069 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -196,8 +196,8 @@ from typing import ClassVar, Generic, TypeVar import torch from tqdm import tqdm -import vllm.envs as envs from vllm import _custom_ops as ops +from vllm import envs from vllm._aiter_ops import rocm_aiter_ops from vllm.attention.backends.abstract import ( AttentionBackend, @@ -453,12 +453,6 @@ def use_trtllm_ragged_deepseek_prefill() -> bool: ) -# Currently 394MB, this can be tuned based on GEMM sizes used. -# Chosen to be the same as sglang: -# https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37 -FLASHINFER_WORKSPACE_BUFFER_SIZE = 394 * 1024 * 1024 - - class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): """ NOTE: Please read the comment at the top of the file before trying to @@ -590,7 +584,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): if self._use_fi_prefill: self._workspace_buffer = torch.empty( - FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device + envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=device, ) self._fi_prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None @@ -602,7 +598,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): if self._use_trtllm_ragged_prefill: self._workspace_buffer = torch.empty( - FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device + envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=device, ) if self._use_cudnn_prefill: From d23539549a6db54ab152ce4e566c31f6891ddab5 Mon Sep 17 00:00:00 2001 From: Adrian Abeyta Date: Tue, 11 Nov 2025 18:34:58 -0600 Subject: [PATCH 39/98] Use FLASHINFER MLA backend when testing fp8_kv_scale_compile (#28491) Signed-off-by: adabeyta --- tests/compile/test_full_graph.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 71f90f6d8d3ee..b4e5e56ac9fe6 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -10,6 +10,7 @@ import torch from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig from vllm.platforms import current_platform from vllm.utils.torch_utils import is_torch_equal_or_newer @@ -184,13 +185,24 @@ def test_custom_compile_config( [CompilationMode.NONE, CompilationMode.VLLM_COMPILE], ) @pytest.mark.parametrize( - "model", + "model, backend", [ - "Qwen/Qwen2-0.5B", # Standard attention model - "deepseek-ai/DeepSeek-V2-Lite", # MLA (Multi-head Latent Attention) model + ("Qwen/Qwen2-0.5B", None), # Standard attention model + ( + "deepseek-ai/DeepSeek-V2-Lite", + AttentionBackendEnum.FLASHINFER_MLA, + ), # MLA (Multi-head Latent Attention) model ], ) -def test_fp8_kv_scale_compile(compilation_mode: int, model: str): +def test_fp8_kv_scale_compile( + monkeypatch: pytest.MonkeyPatch, + compilation_mode: int, + model: str, + backend: AttentionBackendEnum | None, +): + if backend: + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) + model_kwargs = { "quantization": "fp8", "kv_cache_dtype": "fp8_e4m3", From 1788aa1efb1f3cd8bf521885244aed3b89bed8a1 Mon Sep 17 00:00:00 2001 From: Ilya Markov Date: Wed, 12 Nov 2025 01:41:54 +0100 Subject: [PATCH 40/98] [BugFix] Graceful handling of torch symm mem errors. (#27671) Signed-off-by: ilmarkov Co-authored-by: Michael Goin --- .../device_communicators/symm_mem.py | 22 +++++++++++++------ vllm/envs.py | 4 ++-- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/vllm/distributed/device_communicators/symm_mem.py b/vllm/distributed/device_communicators/symm_mem.py index 74d6fb40c83b7..eb1f173b11925 100644 --- a/vllm/distributed/device_communicators/symm_mem.py +++ b/vllm/distributed/device_communicators/symm_mem.py @@ -88,13 +88,21 @@ class SymmMemCommunicator: self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][ self.world_size ] - - self.buffer = torch_symm_mem.empty( - self.max_size // self.dtype.itemsize, - device=self.device, - dtype=self.dtype, - ) - handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name) + try: + self.buffer = torch_symm_mem.empty( + self.max_size // self.dtype.itemsize, + device=self.device, + dtype=self.dtype, + ) + handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name) + except RuntimeError as e: + logger.warning_once( + "SymmMemCommunicator: symmetric memory initialization failed: %s " + "Communicator is not available. To suppress this warning set " + "VLLM_ALLREDUCE_USE_SYMM_MEM=0", + str(e), + ) + return if handle.multicast_ptr == 0: logger.warning( "SymmMemCommunicator: symmetric memory " diff --git a/vllm/envs.py b/vllm/envs.py index 5274c8ba1b24e..46725efac70ef 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -201,7 +201,7 @@ if TYPE_CHECKING: VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False - VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False + VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True VLLM_TUNED_CONFIG_FOLDER: str | None = None VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS: set[str] = set() VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False @@ -1389,7 +1389,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ), # Whether to use pytorch symmetric memory for allreduce "VLLM_ALLREDUCE_USE_SYMM_MEM": lambda: bool( - int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0")) + int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")) ), # Allows vllm to find tuned config under customized folder "VLLM_TUNED_CONFIG_FOLDER": lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None), From 48c879369f83ab1ab281a4bfe97f9a54790715d1 Mon Sep 17 00:00:00 2001 From: Yanan Cao Date: Tue, 11 Nov 2025 16:46:18 -0800 Subject: [PATCH 41/98] [Frontend] Change CompilationMode to a proper Enum (#28165) Signed-off-by: Yanan Cao --- tests/compile/test_basic_correctness.py | 6 ++- tests/utils_/test_argparse_utils.py | 60 +++++++++++++++++++++++++ vllm/compilation/wrapper.py | 4 +- vllm/config/compilation.py | 51 ++++++++++++++------- vllm/config/vllm.py | 5 +-- vllm/entrypoints/llm.py | 5 ++- 6 files changed, 108 insertions(+), 23 deletions(-) diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 132a838b8d44c..3f6898607f6b9 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -127,7 +127,9 @@ def test_compile_correctness( CompilationMode.VLLM_COMPILE, ]: for mode in [CompilationMode.NONE, comp_mode]: - all_args.append(final_args + [f"-O.mode={mode}", "-O.backend=inductor"]) + all_args.append( + final_args + [f"-O.mode={mode.name}", "-O.backend=inductor"] + ) # inductor will change the output, so we only compare if the output # is close, not exactly the same. @@ -146,7 +148,7 @@ def test_compile_correctness( CompilationMode.DYNAMO_TRACE_ONCE, CompilationMode.VLLM_COMPILE, ]: - all_args.append(final_args + [f"-O.mode={mode}", "-O.backend=eager"]) + all_args.append(final_args + [f"-O.mode={mode.name}", "-O.backend=eager"]) all_envs.append({}) all_envs.append({}) diff --git a/tests/utils_/test_argparse_utils.py b/tests/utils_/test_argparse_utils.py index 51684edcc8a30..3310753d2b6d6 100644 --- a/tests/utils_/test_argparse_utils.py +++ b/tests/utils_/test_argparse_utils.py @@ -8,6 +8,7 @@ import os import pytest import yaml from transformers import AutoTokenizer +from pydantic import ValidationError from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens @@ -376,6 +377,65 @@ def test_load_config_file(tmp_path): os.remove(str(config_file_path)) +def test_compilation_mode_string_values(parser): + """Test that -O.mode accepts both integer and string mode values.""" + args = parser.parse_args(["-O.mode", "0"]) + assert args.compilation_config == {"mode": 0} + + args = parser.parse_args(["-O3"]) + assert args.compilation_config == {"mode": 3} + + args = parser.parse_args(["-O.mode=NONE"]) + assert args.compilation_config == {"mode": "NONE"} + + args = parser.parse_args(["-O.mode", "STOCK_TORCH_COMPILE"]) + assert args.compilation_config == {"mode": "STOCK_TORCH_COMPILE"} + + args = parser.parse_args(["-O.mode=DYNAMO_TRACE_ONCE"]) + assert args.compilation_config == {"mode": "DYNAMO_TRACE_ONCE"} + + args = parser.parse_args(["-O.mode", "VLLM_COMPILE"]) + assert args.compilation_config == {"mode": "VLLM_COMPILE"} + + args = parser.parse_args(["-O.mode=none"]) + assert args.compilation_config == {"mode": "none"} + + args = parser.parse_args(["-O.mode=vllm_compile"]) + assert args.compilation_config == {"mode": "vllm_compile"} + + +def test_compilation_config_mode_validator(): + """Test that CompilationConfig.mode field validator converts strings to integers.""" + from vllm.config.compilation import CompilationConfig, CompilationMode + + config = CompilationConfig(mode=0) + assert config.mode == CompilationMode.NONE + + config = CompilationConfig(mode=3) + assert config.mode == CompilationMode.VLLM_COMPILE + + config = CompilationConfig(mode="NONE") + assert config.mode == CompilationMode.NONE + + config = CompilationConfig(mode="STOCK_TORCH_COMPILE") + assert config.mode == CompilationMode.STOCK_TORCH_COMPILE + + config = CompilationConfig(mode="DYNAMO_TRACE_ONCE") + assert config.mode == CompilationMode.DYNAMO_TRACE_ONCE + + config = CompilationConfig(mode="VLLM_COMPILE") + assert config.mode == CompilationMode.VLLM_COMPILE + + config = CompilationConfig(mode="none") + assert config.mode == CompilationMode.NONE + + config = CompilationConfig(mode="vllm_compile") + assert config.mode == CompilationMode.VLLM_COMPILE + + with pytest.raises(ValidationError, match="Invalid compilation mode"): + CompilationConfig(mode="INVALID_MODE") + + def test_flat_product(): # Check regular itertools.product behavior result1 = list(flat_product([1, 2, 3], ["a", "b"])) diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 4b10c85209f63..4d26619bd128c 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -31,7 +31,9 @@ class TorchCompileWrapperWithCustomDispatcher: """ def __init__( - self, compiled_callable: Callable | None = None, compilation_mode: int = 0 + self, + compiled_callable: Callable | None = None, + compilation_mode: CompilationMode = CompilationMode.NONE, ): vllm_config = get_current_vllm_config() self.vllm_config = vllm_config diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 9c9557df4e738..e1d60ee84d89c 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -28,7 +28,7 @@ else: logger = init_logger(__name__) -class CompilationMode: +class CompilationMode(enum.IntEnum): """The compilation approach used for torch.compile-based compilation of the model.""" @@ -115,7 +115,7 @@ class PassConfig: """The threshold of the communicated tensor sizes under which vllm should use flashinfer fused allreduce. Specified as a float in MB. - Unspecified will fallback to default values + Unspecified will fallback to default values which are compute capability and world size dependent. FI_ALLREDUCE_FUSION_MAX_SIZE_MB = { 90: { @@ -244,7 +244,7 @@ class CompilationConfig: Please use mode. Currently all levels are mapped to mode. """ # Top-level Compilation control - mode: int | None = None + mode: CompilationMode | None = None """The compilation approach used for torch.compile-based compilation of the model. @@ -377,23 +377,23 @@ class CompilationConfig: FULL mode: Capture full cudagraph for all batches. Can be good for small models or workloads with small prompts; not supported by many backends. Generally for performance FULL_AND_PIECEWISE is better. - + FULL_DECODE_ONLY mode: Capture full cudagraph for decode batches only. Mixed prefill-decode batches are run without cudagraphs. Can be good for decode instances in a P/D setup where prefill is not as important so we can save some memory. - + FULL_AND_PIECEWISE mode: Capture full cudagraph for decode batches and piecewise cudagraph for prefill and mixed prefill-decode batches. This is the most performant mode for most models and is the default. Currently, the cudagraph mode is only used for the v1 engine. - Note that the cudagraph logic is generally orthogonal to the - compilation logic. While piecewise cudagraphs require piecewise + Note that the cudagraph logic is generally orthogonal to the + compilation logic. While piecewise cudagraphs require piecewise compilation (mode=VLLM_COMPILE and non-empty splitting_ops), full cudagraphs are supported with and without compilation. - - Warning: This flag is new and subject to change in addition + + Warning: This flag is new and subject to change in addition more modes may be added. """ use_cudagraph: bool = True @@ -422,7 +422,7 @@ class CompilationConfig: cudagraph. If the caller can guarantee that the same input buffers are always used, it can set this to False. Otherwise, it should set this to True, and the compiler will copy the input to an - internally managed buffer. Default is False. + internally managed buffer. Default is False. Note that this flag is only effective when cudagraph_mode is PIECEWISE. """ full_cuda_graph: bool | None = False @@ -451,7 +451,7 @@ class CompilationConfig: outside the partition functions. For a graph with N cudagraph-unsafe ops (e.g., Attention), there would be N+1 partitions. To mark an op as cudagraph unsafe, we can add `tags=(torch._C.Tag.cudagraph_unsafe)` when - register the custom op. + register the custom op. This config supports both full cudagraph and piecewise cudagraph without compiling twice. For piecewise cudagraph, it applies vLLM CUDAGraph wrapper @@ -468,8 +468,8 @@ class CompilationConfig: max_cudagraph_capture_size: int | None = field(default=None) """The maximum cudagraph capture size. - - If cudagraph_capture_sizes is specified, this will be set to the largest + + If cudagraph_capture_sizes is specified, this will be set to the largest size in that list (or checked for consistency if specified). If cudagraph_capture_sizes is not specified, the list of sizes is generated automatically following the pattern: @@ -478,7 +478,7 @@ class CompilationConfig: range(256, max_cudagraph_capture_size + 1, 16)) If not specified, max_cudagraph_capture_size is set to min(max_num_seqs*2, - 512) by default. This voids OOM in tight memory scenarios with small + 512) by default. This voids OOM in tight memory scenarios with small max_num_seqs, and prevents capture of many large graphs (>512) that would greatly increase startup time with limited performance benefit. """ @@ -579,6 +579,27 @@ class CompilationConfig: __str__ = __repr__ + @field_validator("mode", mode="before") + @classmethod + def validate_mode_before(cls, value: Any) -> Any: + """ + Enable parsing the `mode` field from string mode names. + Accepts both integers (0-3) and string names, like NONE, STOCK_TORCH_COMPILE, + DYNAMO_TRACE_ONCE, VLLM_COMPILE. + """ + if isinstance(value, str): + # Convert string mode name to integer value + mode_name = value.upper() + + if mode_name not in CompilationMode.__members__: + raise ValueError( + f"Invalid compilation mode: {value}. " + f"Valid modes are: {', '.join(CompilationMode.__members__.keys())}" + ) + + return CompilationMode[mode_name] + return value + @field_validator("cudagraph_mode", mode="before") @classmethod def validate_cudagraph_mode_before(cls, value: Any) -> Any: @@ -904,7 +925,7 @@ class CompilationConfig: return self.mode == CompilationMode.VLLM_COMPILE # Inductor partition case - return self.backend == "inductor" and self.mode > CompilationMode.NONE + return self.backend == "inductor" and self.mode != CompilationMode.NONE def custom_op_log_check(self): """ diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 0fca967d90838..df9a1fd08af6f 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -422,16 +422,13 @@ class VllmConfig: self.compilation_config.mode = CompilationMode.VLLM_COMPILE else: self.compilation_config.mode = CompilationMode.NONE - else: - assert self.compilation_config.mode >= CompilationMode.NONE - assert self.compilation_config.mode <= CompilationMode.VLLM_COMPILE # If user does not set custom ops via none or all set it here based on # compilation mode and backend. if all(s not in self.compilation_config.custom_ops for s in ("all", "none")): if ( self.compilation_config.backend == "inductor" - and self.compilation_config.mode > CompilationMode.NONE + and self.compilation_config.mode != CompilationMode.NONE ): self.compilation_config.custom_ops.append("none") else: diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 22fe2ae9280aa..62717a7eacdf0 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -23,6 +23,7 @@ from vllm.config import ( StructuredOutputsConfig, is_init_field, ) +from vllm.config.compilation import CompilationMode from vllm.config.model import ( ConvertOption, HfOverrides, @@ -259,7 +260,9 @@ class LLM: if compilation_config is not None: if isinstance(compilation_config, int): - compilation_config_instance = CompilationConfig(mode=compilation_config) + compilation_config_instance = CompilationConfig( + mode=CompilationMode(compilation_config) + ) elif isinstance(compilation_config, dict): compilation_config_instance = CompilationConfig( **{ From 3f770f4427cb926c24af540cc72d1b5901f7f702 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Wed, 12 Nov 2025 08:49:29 +0800 Subject: [PATCH 42/98] [Performance] Cache loaded custom logitsprocs to avoid overheads (#28462) Signed-off-by: Isotr0py --- vllm/v1/sample/logits_processor/__init__.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py index eb537eae6c904..5992c4066c9cb 100644 --- a/vllm/v1/sample/logits_processor/__init__.py +++ b/vllm/v1/sample/logits_processor/__init__.py @@ -5,7 +5,7 @@ import inspect import itertools from abc import abstractmethod from collections.abc import Sequence -from functools import partial +from functools import lru_cache, partial from typing import TYPE_CHECKING import torch @@ -216,11 +216,17 @@ def build_logitsprocs( ) +cached_load_custom_logitsprocs = lru_cache(_load_custom_logitsprocs) + + def validate_logits_processors_parameters( logits_processors: Sequence[str | type[LogitsProcessor]] | None, sampling_params: SamplingParams, ): - for logits_procs in _load_custom_logitsprocs(logits_processors): + logits_processors = ( + tuple(logits_processors) if logits_processors is not None else None + ) + for logits_procs in cached_load_custom_logitsprocs(logits_processors): logits_procs.validate_params(sampling_params) From e1710393c44cff20e481b632b86d157a9d694625 Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Wed, 12 Nov 2025 09:22:16 +0800 Subject: [PATCH 43/98] [[V0 deprecation]]Remove VLLM_USE_V1 env (#28204) Signed-off-by: wangxiyuan --- .../scripts/hardware_ci/run-cpu-test.sh | 2 +- examples/offline_inference/mlpspeculator.py | 3 +- .../offline_inference/qwen2_5_omni/README.md | 2 - .../qwen2_5_omni/only_thinker.py | 7 +-- .../others/lmcache/cpu_offload_lmcache.py | 43 ++++++------------- tests/entrypoints/openai/test_orca_metrics.py | 3 -- vllm/envs.py | 13 ------ vllm/usage/usage_lib.py | 1 - 8 files changed, 15 insertions(+), 59 deletions(-) diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index 7927aef19e4eb..7e0f720feaa71 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -76,7 +76,7 @@ function cpu_tests() { # Run AWQ test # docker exec cpu-test-"$NUMA_NODE" bash -c " # set -e - # VLLM_USE_V1=0 pytest -x -s -v \ + # pytest -x -s -v \ # tests/quantization/test_ipex_quant.py" # Run multi-lora tests diff --git a/examples/offline_inference/mlpspeculator.py b/examples/offline_inference/mlpspeculator.py index d5b1b4ad29a92..6a533eb5c937f 100644 --- a/examples/offline_inference/mlpspeculator.py +++ b/examples/offline_inference/mlpspeculator.py @@ -4,8 +4,7 @@ This file demonstrates the usage of text generation with an LLM model, comparing the performance with and without speculative decoding. -Note that still not support `v1`: -VLLM_USE_V1=0 python examples/offline_inference/mlpspeculator.py +Note that this example is out of date and not supported in vLLM v1. """ import gc diff --git a/examples/offline_inference/qwen2_5_omni/README.md b/examples/offline_inference/qwen2_5_omni/README.md index 16d44cbadbc98..d8fb50d7fe55c 100644 --- a/examples/offline_inference/qwen2_5_omni/README.md +++ b/examples/offline_inference/qwen2_5_omni/README.md @@ -11,12 +11,10 @@ python examples/offline_inference/qwen2_5_omni/only_thinker.py \ # Read vision and audio inputs from a single video file # NOTE: V1 engine does not support interleaved modalities yet. -VLLM_USE_V1=0 \ python examples/offline_inference/qwen2_5_omni/only_thinker.py \ -q use_audio_in_video # Multiple audios -VLLM_USE_V1=0 \ python examples/offline_inference/qwen2_5_omni/only_thinker.py \ -q multi_audios ``` diff --git a/examples/offline_inference/qwen2_5_omni/only_thinker.py b/examples/offline_inference/qwen2_5_omni/only_thinker.py index 6fbe1303f431a..ed005e6a69b80 100644 --- a/examples/offline_inference/qwen2_5_omni/only_thinker.py +++ b/examples/offline_inference/qwen2_5_omni/only_thinker.py @@ -7,7 +7,6 @@ with the correct prompt format on Qwen2.5-Omni (thinker only). from typing import NamedTuple -import vllm.envs as envs from vllm import LLM, SamplingParams from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset @@ -72,11 +71,7 @@ def get_use_audio_in_video_query() -> QueryResult: ) asset = VideoAsset(name="baby_reading", num_frames=16) audio = asset.get_audio(sampling_rate=16000) - assert not envs.VLLM_USE_V1, ( - "V1 does not support use_audio_in_video. " - "Please launch this example with " - "`VLLM_USE_V1=0`." - ) + return QueryResult( inputs={ "prompt": prompt, diff --git a/examples/others/lmcache/cpu_offload_lmcache.py b/examples/others/lmcache/cpu_offload_lmcache.py index e10ee4e2a9a9a..53036b3eb0ff3 100644 --- a/examples/others/lmcache/cpu_offload_lmcache.py +++ b/examples/others/lmcache/cpu_offload_lmcache.py @@ -37,7 +37,7 @@ from vllm.config import KVTransferConfig from vllm.engine.arg_utils import EngineArgs -def setup_environment_variables(vllm_version: str): +def setup_environment_variables(): # LMCache-related environment variables # Use experimental features in LMCache os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True" @@ -47,12 +47,10 @@ def setup_environment_variables(vllm_version: str): os.environ["LMCACHE_LOCAL_CPU"] = "True" # Set local CPU memory limit to 5.0 GB os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0" - if vllm_version == "v0": - os.environ["VLLM_USE_V1"] = "0" @contextlib.contextmanager -def build_llm_with_lmcache(lmcache_connector: str, model: str, vllm_version: str): +def build_llm_with_lmcache(lmcache_connector: str, model: str): ktc = KVTransferConfig( kv_connector=lmcache_connector, kv_role="kv_both", @@ -60,21 +58,12 @@ def build_llm_with_lmcache(lmcache_connector: str, model: str, vllm_version: str # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB # memory. Reduce the value if your GPU has less memory. # Note: LMCache supports chunked prefill (see vLLM#14505, LMCache#392). - if vllm_version == "v0": - llm_args = EngineArgs( - model=model, - kv_transfer_config=ktc, - max_model_len=8000, - gpu_memory_utilization=0.8, - enable_chunked_prefill=True, # Only in v0 - ) - else: - llm_args = EngineArgs( - model=model, - kv_transfer_config=ktc, - max_model_len=8000, - gpu_memory_utilization=0.8, - ) + llm_args = EngineArgs( + model=model, + kv_transfer_config=ktc, + max_model_len=8000, + gpu_memory_utilization=0.8, + ) llm = LLM(**asdict(llm_args)) try: @@ -116,18 +105,10 @@ def parse_args(): def main(): - args = parse_args() - - if args.version == "v0": - lmcache_connector = "LMCacheConnector" - model = "mistralai/Mistral-7B-Instruct-v0.2" - else: - lmcache_connector = "LMCacheConnectorV1" - model = "meta-llama/Meta-Llama-3.1-8B-Instruct" - - setup_environment_variables(args.version) - - with build_llm_with_lmcache(lmcache_connector, model, args.version) as llm: + lmcache_connector = "LMCacheConnectorV1" + model = "meta-llama/Meta-Llama-3.1-8B-Instruct" + setup_environment_variables() + with build_llm_with_lmcache(lmcache_connector, model) as llm: # This example script runs two requests with a shared prefix. # Define the shared prompt and specific prompts shared_prompt = "Hello, how are you?" * 1000 diff --git a/tests/entrypoints/openai/test_orca_metrics.py b/tests/entrypoints/openai/test_orca_metrics.py index d32cfde07c21e..1ed44a33bf81f 100644 --- a/tests/entrypoints/openai/test_orca_metrics.py +++ b/tests/entrypoints/openai/test_orca_metrics.py @@ -22,9 +22,6 @@ def monkeypatch_module(): @pytest.fixture(scope="module", params=[True]) def server(request, monkeypatch_module): - use_v1 = request.param - monkeypatch_module.setenv("VLLM_USE_V1", "1" if use_v1 else "0") - args = [ "--dtype", "bfloat16", diff --git a/vllm/envs.py b/vllm/envs.py index 46725efac70ef..2aa6afcabf288 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -100,7 +100,6 @@ if TYPE_CHECKING: VLLM_SKIP_P2P_CHECK: bool = False VLLM_DISABLED_KERNELS: list[str] = [] VLLM_DISABLE_PYNCCL: bool = False - VLLM_USE_V1: bool = True VLLM_ROCM_USE_AITER: bool = False VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False VLLM_ROCM_USE_AITER_LINEAR: bool = True @@ -884,8 +883,6 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_DISABLE_PYNCCL": lambda: ( os.getenv("VLLM_DISABLE_PYNCCL", "False").lower() in ("true", "1") ), - # If set, use the V1 code path. - "VLLM_USE_V1": lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))), # Disable aiter ops unless specifically enabled. # Acts as a parent switch to enable the rest of the other operations. "VLLM_ROCM_USE_AITER": lambda: ( @@ -1538,16 +1535,6 @@ def is_set(name: str): raise AttributeError(f"module {__name__!r} has no attribute {name!r}") -def set_vllm_use_v1(use_v1: bool): - if is_set("VLLM_USE_V1"): - raise ValueError( - "Should not call set_vllm_use_v1() if VLLM_USE_V1 is set " - "explicitly by the user. Please raise this as a Github " - "Issue and explicitly set VLLM_USE_V1=0 or 1." - ) - os.environ["VLLM_USE_V1"] = "1" if use_v1 else "0" - - def compute_hash() -> str: """ WARNING: Whenever a new key is added to this environment diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index c8bff8b7c80b6..4eddaf56d81ad 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -42,7 +42,6 @@ _USAGE_ENV_VARS_TO_COLLECT = [ "VLLM_USE_FLASHINFER_SAMPLER", "VLLM_PP_LAYER_PARTITION", "VLLM_USE_TRITON_AWQ", - "VLLM_USE_V1", "VLLM_ENABLE_V1_MULTIPROCESSING", ] From 7f829be7d3d734020606fcca520f3c500581beb8 Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Wed, 12 Nov 2025 09:43:06 +0800 Subject: [PATCH 44/98] [CPU] Refactor CPU attention backend (#27954) Signed-off-by: jiang1.li --- .buildkite/release-pipeline.yaml | 2 +- .../scripts/hardware_ci/run-cpu-test.sh | 3 +- cmake/cpu_extension.cmake | 28 +- csrc/cpu/attention.cpp | 798 ------- csrc/cpu/cache.cpp | 214 -- csrc/cpu/cpu_attn.cpp | 249 +++ csrc/cpu/cpu_attn_amx.hpp | 511 +++++ csrc/cpu/cpu_attn_impl.hpp | 1977 +++++++++++++++++ csrc/cpu/cpu_attn_macros.h | 63 + csrc/cpu/cpu_attn_vec.hpp | 248 +++ csrc/cpu/cpu_attn_vec16.hpp | 171 ++ csrc/cpu/cpu_types_x86.hpp | 50 +- csrc/cpu/dnnl_helper.cpp | 18 +- csrc/cpu/dnnl_helper.h | 24 - csrc/cpu/scratchpad_manager.cpp | 23 + csrc/cpu/scratchpad_manager.h | 31 + csrc/cpu/shm.cpp | 2 +- csrc/cpu/torch_bindings.cpp | 105 +- docker/Dockerfile.cpu | 4 + docs/getting_started/installation/cpu.md | 2 + .../attention/test_attention_selector.py | 6 +- tests/kernels/attention/test_cpu_attn.py | 575 +++++ tests/kernels/test_onednn.py | 1 - .../models/language/generation/test_common.py | 17 +- .../models/language/pooling/test_embedding.py | 3 +- tests/models/registry.py | 4 +- vllm/_custom_ops.py | 82 + vllm/attention/backends/registry.py | 3 +- vllm/engine/arg_utils.py | 3 - vllm/platforms/cpu.py | 37 +- vllm/utils/__init__.py | 1 - vllm/v1/attention/backends/cpu_attn.py | 985 +++----- vllm/v1/attention/backends/utils.py | 2 +- vllm/v1/worker/cpu_model_runner.py | 14 +- 34 files changed, 4354 insertions(+), 1902 deletions(-) delete mode 100644 csrc/cpu/attention.cpp delete mode 100644 csrc/cpu/cache.cpp create mode 100644 csrc/cpu/cpu_attn.cpp create mode 100644 csrc/cpu/cpu_attn_amx.hpp create mode 100644 csrc/cpu/cpu_attn_impl.hpp create mode 100644 csrc/cpu/cpu_attn_macros.h create mode 100644 csrc/cpu/cpu_attn_vec.hpp create mode 100644 csrc/cpu/cpu_attn_vec16.hpp create mode 100644 csrc/cpu/scratchpad_manager.cpp create mode 100644 csrc/cpu/scratchpad_manager.h create mode 100644 tests/kernels/attention/test_cpu_attn.py diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 12f730738b8a5..38c400ba1faf5 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -132,7 +132,7 @@ steps: queue: cpu_queue_postmerge commands: - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --build-arg VLLM_CPU_AMXBF16=true --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ." - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest" - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)" env: diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index 7e0f720feaa71..7479c43977d78 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -49,6 +49,7 @@ function cpu_tests() { # Run kernel tests docker exec cpu-test-"$NUMA_NODE" bash -c " set -e + pytest -x -v -s tests/kernels/attention/test_cpu_attn.py pytest -x -v -s tests/kernels/test_onednn.py" # Run basic model test @@ -116,4 +117,4 @@ function cpu_tests() { # All of CPU tests are expected to be finished less than 40 mins. export -f cpu_tests -timeout 2h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" +timeout 2.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index dbda19fbcbf20..51447cde0b294 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -15,6 +15,7 @@ endif() # set(ENABLE_AVX512BF16 $ENV{VLLM_CPU_AVX512BF16}) set(ENABLE_AVX512VNNI $ENV{VLLM_CPU_AVX512VNNI}) +set(ENABLE_AMXBF16 $ENV{VLLM_CPU_AMXBF16}) include_directories("${CMAKE_SOURCE_DIR}/csrc") @@ -140,6 +141,22 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) set(ENABLE_AVX512VNNI OFF) message(WARNING "Disable AVX512-VNNI ISA support, no avx512_vnni found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512VNNI=1.") endif() + + find_isa(${CPUINFO} "amx_bf16" AMXBF16_FOUND) + if (AMXBF16_FOUND OR ENABLE_AMXBF16) + if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND + CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) + list(APPEND CXX_COMPILE_FLAGS "-mamx-bf16" "-mamx-tile") + set(ENABLE_AMXBF16 ON) + add_compile_definitions(-DCPU_CAPABILITY_AMXBF16) + else() + set(ENABLE_AMXBF16 OFF) + message(WARNING "Disable AMX_BF16 ISA support, requires gcc/g++ >= 12.3") + endif() + else() + set(ENABLE_AMXBF16 OFF) + message(WARNING "Disable AMX_BF16 ISA support, no amx_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AMXBF16=1.") + endif() elseif (AVX2_FOUND) list(APPEND CXX_COMPILE_FLAGS "-mavx2") @@ -275,7 +292,10 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON set(ONEDNN_VERBOSE "OFF") set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) + set(VLLM_BUILD_TYPE ${CMAKE_BUILD_TYPE}) + set(CMAKE_BUILD_TYPE "Release") # remove oneDNN debug symbols to reduce size FetchContent_MakeAvailable(oneDNN) + set(CMAKE_BUILD_TYPE ${VLLM_BUILD_TYPE}) add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp") target_include_directories( dnnl_ext @@ -305,14 +325,14 @@ endif() # set(VLLM_EXT_SRC "csrc/cpu/activation.cpp" - "csrc/cpu/attention.cpp" - "csrc/cpu/cache.cpp" "csrc/cpu/utils.cpp" "csrc/cpu/layernorm.cpp" "csrc/cpu/mla_decode.cpp" "csrc/cpu/pos_encoding.cpp" - "csrc/cpu/torch_bindings.cpp" - "csrc/moe/dynamic_4bit_int_moe_cpu.cpp") + "csrc/moe/dynamic_4bit_int_moe_cpu.cpp" + "csrc/cpu/cpu_attn.cpp" + "csrc/cpu/scratchpad_manager.cpp" + "csrc/cpu/torch_bindings.cpp") if (AVX512_FOUND AND NOT AVX512_DISABLED) set(VLLM_EXT_SRC diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp deleted file mode 100644 index 82862fea7f2be..0000000000000 --- a/csrc/cpu/attention.cpp +++ /dev/null @@ -1,798 +0,0 @@ -#include "cpu_types.hpp" - -namespace { - -template -struct KernelVecType { - using q_load_vec_type = void; - using q_vec_type = void; - using k_load_vec_type = void; - using k_vec_type = void; - using qk_acc_vec_type = void; - using v_load_vec_type = void; -}; - -template <> -struct KernelVecType { - using q_load_vec_type = vec_op::FP32Vec4; - using q_vec_type = vec_op::FP32Vec16; - using k_load_vec_type = vec_op::FP32Vec16; - using k_vec_type = vec_op::FP32Vec16; - using qk_acc_vec_type = vec_op::FP32Vec16; - using v_load_vec_type = vec_op::FP32Vec16; -}; - -template <> -struct KernelVecType { -#if defined(__powerpc64__) || defined(__s390x__) - // Power and s390x architecture-specific vector types - using q_load_vec_type = vec_op::FP32Vec8; - using k_load_vec_type = vec_op::FP32Vec16; - using v_load_vec_type = vec_op::FP32Vec16; -#else - // Fallback for other architectures, including x86 - using q_load_vec_type = vec_op::FP16Vec8; - using k_load_vec_type = vec_op::FP16Vec16; - using v_load_vec_type = vec_op::FP16Vec16; -#endif - using q_vec_type = vec_op::FP32Vec16; - using k_vec_type = vec_op::FP32Vec16; - using qk_acc_vec_type = vec_op::FP32Vec16; -}; - -#ifdef __AVX512BF16__ -template <> -struct KernelVecType { - using q_load_vec_type = vec_op::BF16Vec8; - using q_vec_type = vec_op::BF16Vec32; - using k_load_vec_type = vec_op::BF16Vec32; - using k_vec_type = vec_op::BF16Vec32; - using qk_acc_vec_type = vec_op::FP32Vec16; - using v_load_vec_type = vec_op::BF16Vec16; -}; -#else - #ifdef __aarch64__ - #ifndef ARM_BF16_SUPPORT - // pass - #else -template <> -struct KernelVecType { - using q_load_vec_type = vec_op::BF16Vec8; - using q_vec_type = vec_op::FP32Vec16; - using k_load_vec_type = vec_op::BF16Vec16; - using k_vec_type = vec_op::FP32Vec16; - using qk_acc_vec_type = vec_op::FP32Vec16; - using v_load_vec_type = vec_op::BF16Vec16; -}; - #endif - #else -template <> -struct KernelVecType { - using q_load_vec_type = vec_op::BF16Vec8; - using q_vec_type = vec_op::FP32Vec16; - using k_load_vec_type = vec_op::BF16Vec16; - using k_vec_type = vec_op::FP32Vec16; - using qk_acc_vec_type = vec_op::FP32Vec16; - using v_load_vec_type = vec_op::BF16Vec16; -}; - #endif -#endif - -template -FORCE_INLINE std::pair reduceSoftmax(T* data, const int size, - const int capacity) { - T max = data[0]; - for (int i = 1; i < size; ++i) { - max = max >= data[i] ? max : data[i]; - } - - T sum = 0; - for (int i = 0; i < size; ++i) { - data[i] = std::exp(data[i] - max); - sum += data[i]; - } - - int i = 0; - for (; i < size; ++i) { - data[i] /= sum; - } - - for (; i < capacity; ++i) { - data[i] = 0; - } - - return {max, sum}; -} - -template -FORCE_INLINE std::pair reduceSoftmaxAlibi(T* data, const int size, - const int capacity, - const float alibi_slope, - const int start_index, - const int seq_len) { - data[0] += alibi_slope * (start_index - seq_len + 1); - T max = data[0]; - for (int i = 1; i < size; ++i) { - T qk = data[i] + alibi_slope * (start_index + i - seq_len + 1); - data[i] = qk; - max = max >= qk ? max : qk; - } - - T sum = 0; - for (int i = 0; i < size; ++i) { - data[i] = std::exp(data[i] - max); - sum += data[i]; - } - - int i = 0; - for (; i < size; ++i) { - data[i] /= sum; - } - - for (; i < capacity; ++i) { - data[i] = 0; - } - - return {max, sum}; -} - -template -FORCE_INLINE void reducePartitionSoftmax(const T* max_data, T* sum_data, - const int size) { - T max = max_data[0]; - for (int i = 1; i < size; ++i) { - max = max >= max_data[i] ? max : max_data[i]; - } - - T rescaled_sum = 0; - for (int i = 0; i < size; ++i) { - T rescale_factor = std::exp(max_data[i] - max); - rescaled_sum += rescale_factor * sum_data[i]; - sum_data[i] *= rescale_factor; - } - for (int i = 0; i < size; ++i) { - sum_data[i] /= rescaled_sum + 1e-8; - } -} - -template -struct reduceQKBlockKernel { - using q_load_vec_type = typename KernelVecType::q_load_vec_type; - using q_vec_type = typename KernelVecType::q_vec_type; - using k_load_vec_type = typename KernelVecType::k_load_vec_type; - using k_vec_type = typename KernelVecType::k_vec_type; - using qk_acc_vec_type = typename KernelVecType::qk_acc_vec_type; - - constexpr static int TOKEN_PER_GROUP = k_load_vec_type::get_elem_num() / x; - constexpr static int MAX_GROUP_NUM = 16 / TOKEN_PER_GROUP; - constexpr static int UNROLL_GROUP_NUM = MAX_GROUP_NUM / 4; - - static_assert(MAX_GROUP_NUM == 8 || MAX_GROUP_NUM == 4); - static_assert(k_load_vec_type::get_elem_num() % x == 0); - static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16); - - FORCE_INLINE static void call(const scalar_t* __restrict__ q, - const scalar_t* __restrict__ k_block, - float* __restrict__ logits, float scale, - const int token_num) { - const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP; - - qk_acc_vec_type group_accums[MAX_GROUP_NUM]; - if (token_num == BLOCK_SIZE) { - for (int q_offset = 0; q_offset < HEAD_SIZE; - q_offset += x, k_block += x * BLOCK_SIZE) { - q_load_vec_type q_load_group_vec(q + q_offset); - q_vec_type q_group_vec(q_load_group_vec); - - vec_op::unroll_loop( - [k_block, &q_group_vec, &group_accums](int token_group_idx) { - k_load_vec_type k_load_group_vec(k_block + token_group_idx * x * - TOKEN_PER_GROUP); - k_vec_type k_group_vec(k_load_group_vec); - vec_op::fma(group_accums[token_group_idx], q_group_vec, - k_group_vec); - vec_op::prefetch(k_block + x * BLOCK_SIZE + - token_group_idx * x * TOKEN_PER_GROUP); - }); - } - } else { - for (int q_offset = 0; q_offset < HEAD_SIZE; - q_offset += x, k_block += x * BLOCK_SIZE) { - q_load_vec_type q_load_group_vec(q + q_offset); - q_vec_type q_group_vec(q_load_group_vec); - for (int token_group_start = 0; token_group_start < group_num; - token_group_start += UNROLL_GROUP_NUM) { - vec_op::unroll_loop( - [token_group_start, k_block, &q_group_vec, - &group_accums](int token_group_idx) { - token_group_idx += token_group_start; - k_load_vec_type k_load_group_vec(k_block + token_group_idx * x * - TOKEN_PER_GROUP); - k_vec_type k_group_vec(k_load_group_vec); - vec_op::fma(group_accums[token_group_idx], q_group_vec, - k_group_vec); - vec_op::prefetch(k_block + x * BLOCK_SIZE + - token_group_idx * x * TOKEN_PER_GROUP); - }); - } - } - } - - for (int token_group_idx = 0; token_group_idx < group_num; - ++token_group_idx) { - vec_op::unroll_loop( - [&group_accums, logits, scale, token_group_idx](int token_idx) { - float dot_v = - group_accums[token_group_idx] - .template reduce_sub_sum(token_idx); - logits[token_group_idx * TOKEN_PER_GROUP + token_idx] = - dot_v * scale; - }); - } - } -}; - -template -FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block, - acc_t&& acc) { - using v_load_vec_type = typename KernelVecType::v_load_vec_type; - constexpr int ELEM_NUM = v_load_vec_type::get_elem_num(); - static_assert(BLOCK_SIZE == ELEM_NUM); - vec_op::FP32Vec16 prob_vec(prob); - - vec_op::unroll_loop([&](int head_elem_idx) { - v_load_vec_type v_vec(v_block + BLOCK_SIZE * head_elem_idx); - vec_op::FP32Vec16 fp32_v_vec(v_vec); - acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec; - }); -} -}; // namespace - -// Paged attention v1 -namespace { -template -struct paged_attention_v1_impl { - static void call( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] - const int num_kv_heads, const float scale, - const int* __restrict__ block_tables, // [num_seqs, - // max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - const int num_seqs, const int num_heads) { - constexpr int x = 16 / sizeof(scalar_t); - const int num_queries_per_kv = num_heads / num_kv_heads; - - static_assert(BLOCK_SIZE == 16); - - int max_seq_len = max_num_blocks_per_seq * BLOCK_SIZE; - int max_seq_len_padded = (max_seq_len + 15) & 0xFFFFFFF0; - TORCH_CHECK((max_seq_len_padded * sizeof(float)) % 64 == 0); - - const int parallel_work_item_num = omp_get_max_threads(); - - size_t logits_bytes = - parallel_work_item_num * max_seq_len_padded * sizeof(float); - float* logits = (float*)std::aligned_alloc( - 64, logits_bytes); // Cacheline alignment for each context token. - // [parallel_work_item_num, max_seq_len_padded] - -#pragma omp parallel for collapse(2) schedule(dynamic, 1) - for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { - for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - int seq_len = seq_lens[seq_idx]; - const int* seq_block_table = - block_tables + max_num_blocks_per_seq * seq_idx; - const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE; - const int64_t kv_head_idx = head_idx / num_queries_per_kv; - const scalar_t* __restrict__ q_vec_ptr = - q + seq_idx * q_stride + head_idx * HEAD_SIZE; - const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE; - float* __restrict__ thread_block_logits = - logits + omp_get_thread_num() * max_seq_len_padded; - - // Compute logits - for (int block_idx = 0; block_idx < block_num; ++block_idx) { - const int64_t physical_block_idx = seq_block_table[block_idx]; - const scalar_t* __restrict__ k_block_cache_ptr = - k_cache + physical_block_idx * kv_block_stride + - kv_head_idx * kv_head_stride; - float* __restrict__ head_block_logits = - thread_block_logits + block_idx * BLOCK_SIZE; - - reduceQKBlockKernel::call( - q_vec_ptr, k_block_cache_ptr, head_block_logits, scale, - block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE); - } - - // Compute softmax - if (alibi_slopes) { - reduceSoftmaxAlibi(thread_block_logits, seq_len, - block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, - seq_len); - } else { - reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE); - } - - // Compute value - constexpr int head_elem_num_per_partition = 16; - constexpr int head_partition_num = - HEAD_SIZE / head_elem_num_per_partition; - for (int head_part_idx = 0; head_part_idx < head_partition_num; - ++head_part_idx) { - vec_op::FP32Vec16 accums[head_elem_num_per_partition]; - scalar_t* __restrict__ out_ptr = - out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + - head_part_idx * head_elem_num_per_partition; - for (int block_idx = 0; block_idx < block_num; ++block_idx) { - const int64_t physical_block_idx = seq_block_table[block_idx]; - const float* __restrict__ prob_vec_ptr = - thread_block_logits + block_idx * BLOCK_SIZE; - const scalar_t* __restrict__ v_block_cache_ptr = - v_cache + physical_block_idx * kv_block_stride + - kv_head_idx * kv_head_stride + - BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; - reduceValueBlock( - prob_vec_ptr, v_block_cache_ptr, accums); - - if (block_idx != block_num - 1) { - const int64_t next_physical_block_idx = - seq_block_table[block_idx + 1]; - const scalar_t* __restrict__ next_v_block_cache_ptr = - v_cache + next_physical_block_idx * kv_block_stride + - kv_head_idx * kv_head_stride + - BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; - vec_op::unroll_loop( - [&](int head_elem_idx) { - if (head_elem_idx % 2 == 0) { - vec_op::prefetch(next_v_block_cache_ptr + - BLOCK_SIZE * head_elem_idx); - } - }); - } - } - - vec_op::unroll_loop( - [&](int head_elem_idx) { - float value = accums[head_elem_idx].reduce_sum(); - vec_op::storeFP32(value, out_ptr + head_elem_idx); - }); - } - } - } - std::free(logits); - } -}; - -#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ - paged_attention_v1_impl::call( \ - out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ - block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ - alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ - num_heads); - -template -void paged_attention_v1_impl_launcher( - torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const std::optional& alibi_slopes) { - int num_seqs = query.size(0); - int num_heads = query.size(1); - int head_size = query.size(2); - int max_num_blocks_per_seq = block_tables.size(1); - int q_stride = query.stride(0); - int kv_block_stride = key_cache.stride(0); - int kv_head_stride = key_cache.stride(1); - - // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = - alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; - - T* out_ptr = reinterpret_cast(out.data_ptr()); - T* query_ptr = reinterpret_cast(query.data_ptr()); - T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* block_tables_ptr = block_tables.data_ptr(); - int* seq_lens_ptr = seq_lens.data_ptr(); - - switch (head_size) { - case 32: - LAUNCH_V1_ATTENTION_KERNEL(T, 32, BLOCK_SIZE); - break; - case 64: - LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); - break; - case 80: - LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); - break; - case 96: - LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); - break; - case 112: - LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); - break; - case 128: - LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); - break; - case 192: - LAUNCH_V1_ATTENTION_KERNEL(T, 192, BLOCK_SIZE); - break; - case 256: - LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; - } -} - -#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v1_impl_launcher( \ - out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ - seq_lens, max_seq_len, alibi_slopes); - -#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 16: \ - CALL_V1_KERNEL_LAUNCHER(T, 16); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ - } -} // namespace - -void paged_attention_v1( - torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int64_t num_kv_heads, double scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, - const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, - const int64_t blocksparse_head_sliding_step) { - TORCH_CHECK(blocksparse_vert_stride <= 1, - "CPU backend does not support blocksparse attention yet."); - VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", - [&] { - CPU_KERNEL_GUARD_IN(paged_attention_v1_impl) - CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t); - CPU_KERNEL_GUARD_OUT(paged_attention_v1_impl) - }); -} - -// Paged attention v2 -namespace { -template -struct paged_attention_v2_impl { - static void call( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - float* __restrict__ exp_sums, // [num_seqs, num_heads, - // max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, - // max_num_partitions] - scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, - // max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] - const int num_kv_heads, const float scale, - const int* __restrict__ block_tables, // [num_seqs, - // max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - const int num_seqs, const int num_heads, const int max_num_partitions) { - constexpr int x = 16 / sizeof(scalar_t); - const int num_queries_per_kv = num_heads / num_kv_heads; - - static_assert(BLOCK_SIZE == 16); - static_assert(PARTITION_SIZE * sizeof(float) % 64 == 0); - static_assert(PARTITION_SIZE % BLOCK_SIZE == 0); - -#pragma omp parallel for collapse(3) schedule(static, 1) - for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { - for (int partition_idx = 0; partition_idx < max_num_partitions; - ++partition_idx) { - for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - const int seq_len = seq_lens[seq_idx]; - const int start_token_idx = partition_idx * PARTITION_SIZE; - - if (start_token_idx >= seq_len) continue; - - const int partition_num = - (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; - const bool no_reduce = (partition_num == 1); - const int token_num = - (std::min(seq_len, start_token_idx + PARTITION_SIZE) - - start_token_idx); - const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; - const int last_block_token_num = - token_num - (block_num - 1) * BLOCK_SIZE; - const int* seq_block_table = block_tables + - max_num_blocks_per_seq * seq_idx + - start_token_idx / BLOCK_SIZE; - const int64_t kv_head_idx = head_idx / num_queries_per_kv; - const scalar_t* __restrict__ q_vec_ptr = - q + seq_idx * q_stride + head_idx * HEAD_SIZE; - - float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0}; - - // Compute logits - for (int block_idx = 0; block_idx < block_num; ++block_idx) { - const int64_t physical_block_idx = seq_block_table[block_idx]; - const scalar_t* __restrict__ k_block_cache_ptr = - k_cache + physical_block_idx * kv_block_stride + - kv_head_idx * kv_head_stride; - float* __restrict__ head_block_logits = - logits + block_idx * BLOCK_SIZE; - - reduceQKBlockKernel::call( - q_vec_ptr, k_block_cache_ptr, head_block_logits, scale, - block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE); - } - - std::pair max_and_sum; - if (alibi_slopes) { - max_and_sum = reduceSoftmaxAlibi( - logits, token_num, block_num * BLOCK_SIZE, - alibi_slopes[head_idx], start_token_idx, seq_len); - } else { - max_and_sum = - reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE); - } - - auto&& [max_logit, exp_sum] = max_and_sum; - - scalar_t* __restrict__ output_buffer = nullptr; - if (!no_reduce) { - auto idx = seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions + partition_idx; - max_logits[idx] = max_logit; - exp_sums[idx] = exp_sum; - output_buffer = - tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - head_idx * max_num_partitions * HEAD_SIZE + - partition_idx * HEAD_SIZE; - } else { - output_buffer = - out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - } - - // Compute value - constexpr int head_elem_num_per_partition = 16; - constexpr int head_partition_num = - HEAD_SIZE / head_elem_num_per_partition; - for (int head_part_idx = 0; head_part_idx < head_partition_num; - ++head_part_idx) { - vec_op::FP32Vec16 accums[head_elem_num_per_partition]; - scalar_t* __restrict__ out_ptr = - output_buffer + head_part_idx * head_elem_num_per_partition; - for (int block_idx = 0; block_idx < block_num; ++block_idx) { - const int64_t physical_block_idx = seq_block_table[block_idx]; - const float* __restrict__ prob_vec_ptr = - logits + block_idx * BLOCK_SIZE; - const scalar_t* __restrict__ v_block_cache_ptr = - v_cache + physical_block_idx * kv_block_stride + - kv_head_idx * kv_head_stride + - BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; - reduceValueBlock( - prob_vec_ptr, v_block_cache_ptr, accums); - - if (block_idx != block_num - 1) { - const int64_t next_physical_block_idx = - seq_block_table[block_idx + 1]; - const scalar_t* __restrict__ next_v_block_cache_ptr = - v_cache + next_physical_block_idx * kv_block_stride + - kv_head_idx * kv_head_stride + - BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; - vec_op::unroll_loop( - [&](int head_elem_idx) { - if (head_elem_idx % 2 == 0) { - vec_op::prefetch(next_v_block_cache_ptr + - BLOCK_SIZE * head_elem_idx); - } - }); - } - } - - vec_op::unroll_loop( - [&](int head_elem_idx) { - float value = accums[head_elem_idx].reduce_sum(); - vec_op::storeFP32(value, out_ptr + head_elem_idx); - }); - } - } - } - } - - // Rescale partition softmax and store the factors to exp_sums -#pragma omp parallel for collapse(2) schedule(static, 1) - for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { - for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - const int seq_len = seq_lens[seq_idx]; - const int partition_num = - (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; - - if (partition_num == 1) continue; - - reducePartitionSoftmax( - max_logits + seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions, - exp_sums + seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions, - partition_num); - } - } - - // Reduce values - using v_load_vec_type = typename KernelVecType::v_load_vec_type; - static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE); - constexpr int head_elem_num_per_group = - 16; // Note: didn't align with the cacheline size, due to some - // HEAD_SIZE didn't align with 64 bytes - static_assert(HEAD_SIZE % head_elem_num_per_group == 0); - constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group; - const float* __restrict__ rescale_factors = exp_sums; -#pragma omp parallel for collapse(3) schedule(static, 1) - for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { - for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - for (int group_idx = 0; group_idx < head_group_num; ++group_idx) { - const int seq_len = seq_lens[seq_idx]; - const int partition_num = - (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; - - if (partition_num == 1) continue; - - const float* __restrict__ seq_head_rescale_factors = - rescale_factors + seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions; - const scalar_t* __restrict__ seq_head_tmp_out = - tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - head_idx * max_num_partitions * HEAD_SIZE + - group_idx * head_elem_num_per_group; - scalar_t* __restrict__ seq_head_output = - out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + - group_idx * head_elem_num_per_group; - - vec_op::FP32Vec16 acc; - for (int i = 0; i < partition_num; ++i) { - vec_op::FP32Vec16 rescale_factor(seq_head_rescale_factors[i]); - v_load_vec_type value(seq_head_tmp_out + i * HEAD_SIZE); - vec_op::FP32Vec16 fp32_value(value); - acc = acc + fp32_value * rescale_factor; - } - v_load_vec_type cast_acc(acc); - cast_acc.save(seq_head_output); - } - } - } - } -}; - -#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ - paged_attention_v2_impl::call( \ - out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ - key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ - seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, num_seqs, num_heads, \ - max_num_partitions); - -template -void paged_attention_v2_impl_launcher( - torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, - torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, - int max_seq_len, const std::optional& alibi_slopes) { - int num_seqs = query.size(0); - int num_heads = query.size(1); - int head_size = query.size(2); - int max_num_blocks_per_seq = block_tables.size(1); - int q_stride = query.stride(0); - int kv_block_stride = key_cache.stride(0); - int kv_head_stride = key_cache.stride(1); - int max_num_partitions = exp_sums.size(-1); - - // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = - alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; - - T* out_ptr = reinterpret_cast(out.data_ptr()); - float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); - float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); - T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); - T* query_ptr = reinterpret_cast(query.data_ptr()); - T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* block_tables_ptr = block_tables.data_ptr(); - int* seq_lens_ptr = seq_lens.data_ptr(); - - switch (head_size) { - case 32: - LAUNCH_V2_ATTENTION_KERNEL(T, 32, BLOCK_SIZE); - break; - case 64: - LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); - break; - case 80: - LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); - break; - case 96: - LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); - break; - case 112: - LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); - break; - case 128: - LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); - break; - case 192: - LAUNCH_V2_ATTENTION_KERNEL(T, 192, BLOCK_SIZE); - break; - case 256: - LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; - } -} - -#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v2_impl_launcher( \ - out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \ - alibi_slopes); - -#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 16: \ - CALL_V2_KERNEL_LAUNCHER(T, 16); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ - } -} // namespace - -void paged_attention_v2( - torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, - torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int64_t num_kv_heads, double scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, - const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, - const int64_t blocksparse_head_sliding_step) { - TORCH_CHECK(blocksparse_vert_stride <= 1, - "CPU backend does not support blocksparse attention yet."); - VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", - [&] { - CPU_KERNEL_GUARD_IN(paged_attention_v2_impl) - CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t); - CPU_KERNEL_GUARD_OUT(paged_attention_v2_impl) - }); -} \ No newline at end of file diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp deleted file mode 100644 index 69f6d06e3c967..0000000000000 --- a/csrc/cpu/cache.cpp +++ /dev/null @@ -1,214 +0,0 @@ -#include -#include - -#include "cpu_types.hpp" - -#if defined(__x86_64__) - #define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2 -#else - #define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES -#endif - -namespace { -template -void copy_blocks_cpu_impl(std::vector const& key_caches, - std::vector const& value_caches, - const torch::Tensor& mapping_pairs, - const int element_num_per_block, - const int layer_num) { - const size_t pair_num = mapping_pairs.size(0); - const size_t block_bytes = sizeof(scalar_t) * element_num_per_block; -#pragma omp parallel for collapse(2) - for (int layer = 0; layer < layer_num; ++layer) { - for (size_t pair = 0; pair < pair_num; ++pair) { - int64_t source_offset = - element_num_per_block * mapping_pairs[pair][0].item(); - int64_t target_offset = - element_num_per_block * mapping_pairs[pair][1].item(); - scalar_t* key_cache_ptr = key_caches[layer].data_ptr(); - scalar_t* source_ptr = key_cache_ptr + source_offset; - scalar_t* target_ptr = key_cache_ptr + target_offset; - std::memcpy(target_ptr, source_ptr, block_bytes); - - scalar_t* value_cache_ptr = value_caches[layer].data_ptr(); - source_ptr = value_cache_ptr + source_offset; - target_ptr = value_cache_ptr + target_offset; - std::memcpy(target_ptr, source_ptr, block_bytes); - } - } -} - -template -void reshape_and_cache_cpu_impl( - const scalar_t* __restrict__ key, const scalar_t* __restrict__ value, - scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache, - const int64_t* __restrict__ slot_mapping, const int num_tokens, - const int key_stride, const int value_stride, const int num_heads, - const int head_size, const int block_size, const int x) { - const int block_elem_num = num_heads * head_size * block_size; - -#pragma omp parallel for collapse(2) - for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { - for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - const int64_t slot_idx = slot_mapping[token_idx]; - if (slot_idx >= 0) { - int src_key_head_idx = token_idx * key_stride + head_idx * head_size; - int src_value_head_idx = - token_idx * value_stride + head_idx * head_size; - const scalar_t* src_key_head_ptr = key + src_key_head_idx; - const scalar_t* src_value_head_ptr = value + src_value_head_idx; - const int64_t block_index = slot_idx / block_size; - const int64_t block_offset = slot_idx % block_size; - scalar_t* target_key_head_ptr = key_cache + - block_elem_num * block_index + - head_idx * block_size * head_size; - scalar_t* target_value_head_ptr = value_cache + - block_elem_num * block_index + - head_idx * block_size * head_size; - - for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) { - const int64_t target_offset = - src_key_idx * block_size + block_offset * x; - for (int i = 0; i < x; ++i) { - target_key_head_ptr[target_offset + i] = - src_key_head_ptr[src_key_idx + i]; - } - } - - for (int src_value_idx = 0; src_value_idx < head_size; - ++src_value_idx) { - const int64_t target_offset = - src_value_idx * block_size + block_offset; - target_value_head_ptr[target_offset] = - src_value_head_ptr[src_value_idx]; - } - } - } - } -} -}; // namespace - -template -void concat_and_cache_mla_cpu_impl( - const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank] - const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim] - scalar_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank - // + pe_dim)] - const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int num_tokens, // - const int block_stride, // - const int entry_stride, // - const int kv_c_stride, // - const int k_pe_stride, // - const int kv_lora_rank, // - const int pe_dim, // - const int block_size // -) { -#pragma omp parallel for - for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { - const int64_t slot_idx = slot_mapping[token_idx]; - // NOTE: slot_idx can be -1 if the token is padded - if (slot_idx < 0) { - continue; - } - const int64_t block_idx = slot_idx / block_size; - const int64_t block_offset = slot_idx % block_size; - - auto copy = [&](const scalar_t* __restrict__ src, - scalar_t* __restrict__ dst, int src_stride, int dst_stride, - int size, int offset) { - for (int i = 0; i < size; i++) { - const int64_t src_idx = token_idx * src_stride + i; - const int64_t dst_idx = - block_idx * block_stride + block_offset * entry_stride + i + offset; - dst[dst_idx] = src[src_idx]; - } - }; - - copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0); - copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank); - } -} - -// Note: the key_caches and value_caches vectors are constant but -// not the Tensors they contain. The vectors need to be const refs -// in order to satisfy pytorch's C++ operator registration code. -void copy_blocks(std::vector const& key_caches, - std::vector const& value_caches, - const torch::Tensor& block_mapping) { - unsigned num_layers = key_caches.size(); - TORCH_CHECK(num_layers == value_caches.size()); - if (num_layers == 0) { - return; - } - - const int element_num_per_block = key_caches[0][0].numel(); - DISPATCH_MACRO(key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] { - CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl) - copy_blocks_cpu_impl(key_caches, value_caches, block_mapping, - element_num_per_block, num_layers); - CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl) - }); -} - -void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, - torch::Tensor& key_cache, torch::Tensor& value_cache, - torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, - torch::Tensor& k_scale, torch::Tensor& v_scale) { - int num_tokens = key.size(0); - int num_heads = key.size(1); - int head_size = key.size(2); - int block_size = key_cache.size(3); - int x = key_cache.size(4); - - int key_stride = key.stride(0); - int value_stride = value.stride(0); - - DISPATCH_MACRO(key.scalar_type(), "reshape_and_cache_cpu_impl", [&] { - CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl) - reshape_and_cache_cpu_impl( - key.data_ptr(), value.data_ptr(), - key_cache.data_ptr(), value_cache.data_ptr(), - slot_mapping.data_ptr(), num_tokens, key_stride, value_stride, - num_heads, head_size, block_size, x); - CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl) - }); -} - -void concat_and_cache_mla( - torch::Tensor& kv_c, // [num_tokens, kv_lora_rank] - torch::Tensor& k_pe, // [num_tokens, pe_dim] - torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank + - // pe_dim)] - torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] - const std::string& kv_cache_dtype, torch::Tensor& scale) { - int num_tokens = slot_mapping.size(0); - int kv_lora_rank = kv_c.size(1); - int pe_dim = k_pe.size(1); - int block_size = kv_cache.size(1); - - TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim); - TORCH_CHECK(kv_cache_dtype != "fp8"); - - int kv_c_stride = kv_c.stride(0); - int k_pe_stride = k_pe.stride(0); - int block_stride = kv_cache.stride(0); - int entry_stride = kv_cache.stride(1); - - VLLM_DISPATCH_FLOATING_TYPES( - kv_c.scalar_type(), "concat_and_cache_mla_cpu_impl", [&] { - CPU_KERNEL_GUARD_IN(concat_and_cache_mla_cpu_impl) - concat_and_cache_mla_cpu_impl( - kv_c.data_ptr(), k_pe.data_ptr(), - kv_cache.data_ptr(), slot_mapping.data_ptr(), - num_tokens, block_stride, entry_stride, kv_c_stride, k_pe_stride, - kv_lora_rank, pe_dim, block_size); - CPU_KERNEL_GUARD_OUT(concat_and_cache_mla_cpu_impl) - }); -} - -void swap_blocks(torch::Tensor& src, torch::Tensor& dst, - const torch::Tensor& block_mapping) { - TORCH_CHECK(false, "swap_blocks is unsupported on CPU.") -} diff --git a/csrc/cpu/cpu_attn.cpp b/csrc/cpu/cpu_attn.cpp new file mode 100644 index 0000000000000..50f17c758c148 --- /dev/null +++ b/csrc/cpu/cpu_attn.cpp @@ -0,0 +1,249 @@ +#include "cpu_attn_vec.hpp" +#include "cpu_attn_vec16.hpp" + +#ifdef CPU_CAPABILITY_AMXBF16 + #include "cpu_attn_amx.hpp" + #define AMX_DISPATCH(...) \ + case cpu_attention::ISA::AMX: { \ + using attn_impl = cpu_attention::AttentionImpl; \ + return __VA_ARGS__(); \ + } +#else + #define AMX_DISPATCH(...) case cpu_attention::ISA::AMX: +#endif + +#define CPU_ATTN_DISPATCH_CASE(HEAD_DIM, ...) \ + case HEAD_DIM: { \ + constexpr size_t head_dim = HEAD_DIM; \ + return __VA_ARGS__(); \ + } + +#define CPU_ATTN_DISPATCH_CASE_HEADDIM(HEAD_DIM, ...) \ + [&] { \ + switch (HEAD_DIM) { \ + CPU_ATTN_DISPATCH_CASE(32, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(64, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(96, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(128, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(160, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(192, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(224, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(256, __VA_ARGS__) \ + default: { \ + TORCH_CHECK(false, "Invalid CPU attention head_dim: " + \ + std::to_string(HEAD_DIM)); \ + } \ + } \ + }() + +#define CPU_ATTN_DISPATCH_IMPL(ISA_TYPE, ...) \ + [&] { \ + switch (ISA_TYPE) { \ + AMX_DISPATCH(__VA_ARGS__) \ + case cpu_attention::ISA::VEC: { \ + using attn_impl = \ + cpu_attention::AttentionImpl; \ + return __VA_ARGS__(); \ + } \ + case cpu_attention::ISA::VEC16: { \ + using attn_impl = \ + cpu_attention::AttentionImpl; \ + return __VA_ARGS__(); \ + } \ + default: { \ + TORCH_CHECK(false, "Invalid CPU attention ISA type."); \ + } \ + } \ + }() + +torch::Tensor get_scheduler_metadata( + const int64_t num_req, const int64_t num_heads_q, + const int64_t num_heads_kv, const int64_t head_dim, + const torch::Tensor& seq_lens, at::ScalarType dtype, + const torch::Tensor& query_start_loc, const bool casual, + const int64_t window_size, const std::string& isa_hint, + const bool enable_kv_split) { + cpu_attention::ISA isa; + if (isa_hint == "amx") { + isa = cpu_attention::ISA::AMX; + } else if (isa_hint == "vec") { + isa = cpu_attention::ISA::VEC; + } else if (isa_hint == "vec16") { + isa = cpu_attention::ISA::VEC16; + } else { + TORCH_CHECK(false, "Unsupported CPU attention ISA hint: " + isa_hint); + } + + cpu_attention::AttentionScheduler::ScheduleInput input; + input.num_reqs = num_req; + input.num_heads_q = num_heads_q; + input.num_heads_kv = num_heads_kv; + input.head_dim = head_dim; + input.query_start_loc = query_start_loc.data_ptr(); + input.seq_lens = seq_lens.data_ptr(); + if (window_size != -1) { + input.left_sliding_window_size = window_size - 1; + if (casual) { + input.right_sliding_window_size = 0; + } else { + input.right_sliding_window_size = window_size - 1; + } + } else { + input.left_sliding_window_size = -1; + if (casual) { + input.right_sliding_window_size = 0; + } else { + input.right_sliding_window_size = -1; + } + } + input.casual = casual; + input.isa = isa; + input.enable_kv_split = enable_kv_split; + TORCH_CHECK(casual, "Only supports casual mask for now."); + + VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() { + CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] { + CPU_ATTN_DISPATCH_IMPL(isa, [&]() { + input.elem_size = sizeof(scalar_t); + input.q_buffer_elem_size = sizeof(attn_impl::q_buffer_t); + input.logits_buffer_elem_size = sizeof(attn_impl::logits_buffer_t); + input.output_buffer_elem_size = + sizeof(attn_impl::partial_output_buffer_t); + input.max_num_q_per_iter = attn_impl::MaxQHeadNumPerIteration; + input.kv_block_alignment = attn_impl::BlockSizeAlignment; + }); + }); + }); + + cpu_attention::AttentionScheduler scheduler; + torch::Tensor metadata = scheduler.schedule(input); + return metadata; +} + +void cpu_attn_reshape_and_cache( + const torch::Tensor& key, // [token_num, head_num, head_size] + const torch::Tensor& value, // [token_num, head_num, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_kv_heads, block_size, head_size] + torch::Tensor& + value_cache, // [num_blocks, num_kv_heads, block_size, head_size] + const torch::Tensor& slot_mapping, const std::string& isa) { + TORCH_CHECK_EQ(key.dim(), 3); + TORCH_CHECK_EQ(value.dim(), 3); + TORCH_CHECK_EQ(key_cache.dim(), 4); + TORCH_CHECK_EQ(value_cache.dim(), 4); + TORCH_CHECK_EQ(key.stride(2), 1); + TORCH_CHECK_EQ(value.stride(2), 1); + + const int64_t token_num = key.size(0); + const int64_t key_token_num_stride = key.stride(0); + const int64_t value_token_num_stride = value.stride(0); + const int64_t head_num = value.size(1); + const int64_t key_head_num_stride = key.stride(1); + const int64_t value_head_num_stride = value.stride(1); + const int64_t num_blocks = key_cache.size(0); + const int64_t num_blocks_stride = key_cache.stride(0); + const int64_t cache_head_num_stride = key_cache.stride(1); + const int64_t block_size = key_cache.size(2); + const int64_t block_size_stride = key_cache.stride(2); + const int64_t head_dim = key.size(-1); + + cpu_attention::ISA isa_tag = [&]() { + if (isa == "amx") { + return cpu_attention::ISA::AMX; + } else if (isa == "vec") { + return cpu_attention::ISA::VEC; + } else if (isa == "vec16") { + return cpu_attention::ISA::VEC16; + } else { + TORCH_CHECK(false, "Invalid ISA type: " + isa); + } + }(); + + VLLM_DISPATCH_FLOATING_TYPES( + key.scalar_type(), "cpu_attn_reshape_and_cache", [&]() { + CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] { + CPU_ATTN_DISPATCH_IMPL(isa_tag, [&]() { + attn_impl::reshape_and_cache( + key.data_ptr(), value.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + slot_mapping.data_ptr(), token_num, + key_token_num_stride, value_token_num_stride, head_num, + key_head_num_stride, value_head_num_stride, num_blocks, + num_blocks_stride, cache_head_num_stride, block_size, + block_size_stride); + }); + }); + }); +} + +void cpu_attention_with_kv_cache( + const torch::Tensor& query, // [num_tokens, num_heads, head_size] + const torch::Tensor& + key_cache, // [num_blocks, num_kv_heads, block_size, head_size] + const torch::Tensor& + value_cache, // [num_blocks, num_kv_heads, block_size, head_size] + torch::Tensor& output, // [num_tokens, num_heads, head_size] + const torch::Tensor& query_start_loc, // [num_tokens + 1] + const torch::Tensor& seq_lens, // [num_tokens] + const double scale, const bool causal, + const std::optional& alibi_slopes, // [num_heads] + const int64_t sliding_window_left, const int64_t sliding_window_right, + const torch::Tensor& block_table, // [num_tokens, max_block_num] + const double softcap, const torch::Tensor& scheduler_metadata, + const std::optional& s_aux // [num_heads] +) { + TORCH_CHECK_EQ(query.dim(), 3); + TORCH_CHECK_EQ(query.stride(2), 1); + TORCH_CHECK_EQ(key_cache.dim(), 4); + TORCH_CHECK_EQ(value_cache.dim(), 4); + + cpu_attention::AttentionInput input; + input.metadata = reinterpret_cast( + scheduler_metadata.data_ptr()); + input.num_tokens = query.size(0); + input.num_heads = query.size(1); + input.num_kv_heads = key_cache.size(1); + input.block_size = key_cache.size(2); + input.query = query.data_ptr(); + input.query_num_tokens_stride = query.stride(0); + input.query_num_heads_stride = query.stride(1); + input.cache_num_blocks_stride = key_cache.stride(0); + input.cache_num_kv_heads_stride = key_cache.stride(1); + input.blt_num_tokens_stride = block_table.stride(0); + input.key_cache = key_cache.data_ptr(); + input.value_cache = value_cache.data_ptr(); + input.output = output.data_ptr(); + input.query_start_loc = query_start_loc.data_ptr(); + input.seq_lens = seq_lens.data_ptr(); + input.block_table = block_table.data_ptr(); + input.alibi_slopes = + alibi_slopes.has_value() ? alibi_slopes->data_ptr() : nullptr; + // For now sink must be bf16 + input.s_aux = s_aux.has_value() ? s_aux->data_ptr() : nullptr; + input.scale = scale; + input.causal = causal; + input.sliding_window_left = sliding_window_left; + input.sliding_window_right = sliding_window_right; + if (input.causal) { + // to make boundary calculation easier + input.sliding_window_right = 0; + } + float softcap_fp32 = softcap; + input.softcap = softcap_fp32; + + VLLM_DISPATCH_FLOATING_TYPES( + query.scalar_type(), "cpu_attention_with_kv_cache", [&]() { + CPU_ATTN_DISPATCH_CASE_HEADDIM(query.size(2), [&] { + CPU_ATTN_DISPATCH_IMPL(input.metadata->isa, [&]() { + TORCH_CHECK_EQ(input.block_size % attn_impl::BlockSizeAlignment, 0); + cpu_attention::AttentionMainLoop mainloop; + mainloop(&input); + }); + }); + }); +} diff --git a/csrc/cpu/cpu_attn_amx.hpp b/csrc/cpu/cpu_attn_amx.hpp new file mode 100644 index 0000000000000..8da458b99119c --- /dev/null +++ b/csrc/cpu/cpu_attn_amx.hpp @@ -0,0 +1,511 @@ +#ifndef CPU_ATTN_AMX_HPP +#define CPU_ATTN_AMX_HPP + +#include "cpu_attn_impl.hpp" + +namespace cpu_attention { +namespace { +// AMX specific +constexpr static int64_t AMX_TILE_ROW_BYTES = 64; +constexpr static int64_t AMX_TILE_ROW_NUM = 16; +constexpr static int64_t AMX_TILE_BYTES = AMX_TILE_ROW_BYTES * AMX_TILE_ROW_NUM; + +typedef struct __tile_config { + uint8_t palette_id = 1; + uint8_t start_row = 0; + uint8_t reserved_0[14] = {0}; + uint16_t colsb[16] = {0}; + uint8_t rows[16] = {0}; +} __tilecfg; + +// 2-2-4 pattern, for 16 < m <= 32 +// TILE 0, 1: load A matrix, row num should be 16, m - 16 +// TILE 2, 3: load B matrix, row num should be 16 +// TILE 4, 5, 6, 7: store results C matrix, row num should be 16, 16, m - 16, m +// - 16 +template +class TileGemm224 { + public: + template + FORCE_INLINE static void gemm(const int32_t m_size, void* __restrict__ a_tile, + void* __restrict__ b_tile, + float* __restrict__ c_tile, const int64_t lda, + const int64_t ldb, const int64_t ldc, + const int32_t block_size, + const int32_t dynamic_k_size, + const bool accum_c) { + TORCH_CHECK(false, "Unsupported kv cache type for TileGemm224"); + } + + FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) { + TORCH_CHECK(false, "Unsupported kv cache type for TileGemm224"); + } +}; + +template <> +class TileGemm224 { + public: + template + FORCE_INLINE static void gemm(const int32_t m_size, + c10::BFloat16* __restrict__ a_tile, + c10::BFloat16* __restrict__ b_tile, + float* __restrict__ c_tile, const int64_t lda, + const int64_t ldb, const int64_t ldc, + const int32_t block_size, + const int32_t dynamic_k_size, + const bool accum_c) { + const int32_t k_times = + dynamic_k_size / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16)); + c10::BFloat16* __restrict__ a_tile_0 = a_tile; + c10::BFloat16* __restrict__ a_tile_1 = a_tile + lda * AMX_TILE_ROW_NUM; + const int64_t a_tile_stride = [&]() { + if constexpr (phase == AttentionGemmPhase::QK) { + // q_buffer is prepacked + return AMX_TILE_ROW_BYTES; + } else if constexpr (phase == AttentionGemmPhase::PV) { + // logits_buffer is row-major + return lda * sizeof(c10::BFloat16); + } else { + TORCH_CHECK(false, "Unreachable"); + } + }(); + + c10::BFloat16* __restrict__ b_tile_2 = b_tile; + c10::BFloat16* __restrict__ b_tile_3 = [&]() { + if constexpr (phase == AttentionGemmPhase::QK) { + // k_cache is prepacked + return b_tile + (k_size * AMX_TILE_ROW_BYTES / 4); + } else if constexpr (phase == AttentionGemmPhase::PV) { + // v_cache is prepacked + return b_tile + (block_size * AMX_TILE_ROW_BYTES / 4); + } else { + TORCH_CHECK(false, "Unreachable"); + } + }(); + // k_cache, v_cache are prepacked + const int32_t b_tile_stride = AMX_TILE_ROW_BYTES; + + // logits_buffer, output_buffer are not prepacked + float* __restrict__ c_tile_4 = c_tile; + float* __restrict__ c_tile_5 = + c_tile_4 + AMX_TILE_ROW_BYTES / sizeof(float); + float* __restrict__ c_tile_6 = c_tile + AMX_TILE_ROW_NUM * ldc; + float* __restrict__ c_tile_7 = + c_tile_6 + AMX_TILE_ROW_BYTES / sizeof(float); + const int32_t c_tile_stride = ldc * sizeof(float); + + if (accum_c) { + _tile_loadd(4, c_tile_4, c_tile_stride); + _tile_loadd(5, c_tile_5, c_tile_stride); + _tile_loadd(6, c_tile_6, c_tile_stride); + _tile_loadd(7, c_tile_7, c_tile_stride); + } else { + _tile_zero(4); + _tile_zero(5); + _tile_zero(6); + _tile_zero(7); + } + + for (int32_t k = 0; k < k_times; ++k) { + _tile_loadd(0, a_tile_0, a_tile_stride); + _tile_stream_loadd(2, b_tile_2, b_tile_stride); + _tile_dpbf16ps(4, 0, 2); + _tile_stream_loadd(3, b_tile_3, b_tile_stride); + _tile_dpbf16ps(5, 0, 3); + _tile_loadd(1, a_tile_1, a_tile_stride); + _tile_dpbf16ps(6, 1, 2); + _tile_dpbf16ps(7, 1, 3); + + // update ptrs + if constexpr (phase == AttentionGemmPhase::QK) { + // Q buffer is prepacked + a_tile_0 += AMX_TILE_BYTES / sizeof(c10::BFloat16); + a_tile_1 += AMX_TILE_BYTES / sizeof(c10::BFloat16); + } else if constexpr (phase == AttentionGemmPhase::PV) { + // P buffer is not prepacked + a_tile_0 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + a_tile_1 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + } else { + TORCH_CHECK(false, "Unreachable"); + } + b_tile_2 += AMX_TILE_BYTES / sizeof(c10::BFloat16); + b_tile_3 += AMX_TILE_BYTES / sizeof(c10::BFloat16); + } + + _tile_stored(4, c_tile_4, c_tile_stride); + _tile_stored(5, c_tile_5, c_tile_stride); + _tile_stored(6, c_tile_6, c_tile_stride); + _tile_stored(7, c_tile_7, c_tile_stride); + } + + FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) { + const int32_t m_0 = AMX_TILE_ROW_NUM; + const int32_t m_1 = m - AMX_TILE_ROW_NUM; + config.rows[0] = m_0; + config.rows[1] = m_1; + config.rows[2] = AMX_TILE_ROW_NUM; + config.rows[3] = AMX_TILE_ROW_NUM; + config.rows[4] = m_0; + config.rows[5] = m_0; + config.rows[6] = m_1; + config.rows[7] = m_1; + _tile_loadconfig(&config); + } +}; + +// 1-2-2 pattern, for 0 < m <= 16 +// TILE 0, (1): load A matrix, use extra 1 tile for prefetch, row num should be +// m, m +// TILE 2, 3, (4, 5): load B matrix, use extra 2 tiles for prefetch, row +// num should be 16 +// TILE 6, 7, (6, 7): store results C matrix, row num should be +// m +template +class TileGemm122 { + public: + template + FORCE_INLINE static void gemm(const int32_t m_size, void* __restrict__ a_tile, + void* __restrict__ b_tile, + float* __restrict__ c_tile, const int64_t lda, + const int64_t ldb, const int64_t ldc, + const int32_t block_size, + const int32_t dynamic_k_size, + const bool accum_c) { + TORCH_CHECK(false, "Unsupported kv cache type for TileGemm122"); + } + + FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) { + TORCH_CHECK(false, "Unsupported kv cache type for TileGemm122"); + } +}; + +template <> +class TileGemm122 { + public: + template + FORCE_INLINE static void gemm(const int32_t m_size, + c10::BFloat16* __restrict__ a_tile, + c10::BFloat16* __restrict__ b_tile, + float* __restrict__ c_tile, const int64_t lda, + const int64_t ldb, const int64_t ldc, + const int32_t block_size, + const int32_t dynamic_k_size, + const bool accum_c) { + c10::BFloat16* __restrict__ a_tile_0 = a_tile; + c10::BFloat16* __restrict__ a_tile_1 = [&]() { + if constexpr (phase == AttentionGemmPhase::QK) { + // q_buffer is prepacked + return a_tile + AMX_TILE_BYTES / sizeof(c10::BFloat16); + } else if constexpr (phase == AttentionGemmPhase::PV) { + // logits_buffer is row-major + return a_tile + AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + } else { + TORCH_CHECK(false, "Unreachable"); + } + }(); + const int64_t a_tile_stride = [&]() { + if constexpr (phase == AttentionGemmPhase::QK) { + // q_buffer is prepacked + return AMX_TILE_ROW_BYTES; + } else if constexpr (phase == AttentionGemmPhase::PV) { + // logits_buffer is row-major + return lda * sizeof(c10::BFloat16); + } else { + TORCH_CHECK(false, "Unreachable"); + } + }(); + + c10::BFloat16* __restrict__ b_tile_2 = b_tile; + c10::BFloat16* __restrict__ b_tile_3 = [&]() { + if constexpr (phase == AttentionGemmPhase::QK) { + // k_cache is prepacked + return b_tile + (k_size * AMX_TILE_ROW_BYTES / 4); + } else if constexpr (phase == AttentionGemmPhase::PV) { + // v_cache is prepacked + return b_tile + (block_size * AMX_TILE_ROW_BYTES / 4); + } else { + TORCH_CHECK(false, "Unreachable"); + } + }(); + c10::BFloat16* __restrict__ b_tile_4 = + b_tile_2 + AMX_TILE_BYTES / sizeof(c10::BFloat16); + c10::BFloat16* __restrict__ b_tile_5 = + b_tile_3 + AMX_TILE_BYTES / sizeof(c10::BFloat16); + int64_t b_stride = AMX_TILE_ROW_BYTES; + + float* __restrict__ c_tile_6 = c_tile; + float* __restrict__ c_tile_7 = c_tile + AMX_TILE_ROW_BYTES / sizeof(float); + int64_t c_stride = ldc * sizeof(float); + + const int32_t k_times = + dynamic_k_size / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16)); + const int32_t k_group_times = k_times / 2; + const bool has_tail = (k_times % 2 == 1); + + if (accum_c) { + _tile_loadd(6, c_tile_6, c_stride); + _tile_loadd(7, c_tile_7, c_stride); + } else { + _tile_zero(6); + _tile_zero(7); + } + + for (int32_t k = 0; k < k_group_times; ++k) { + _tile_loadd(0, a_tile_0, a_tile_stride); + _tile_stream_loadd(2, b_tile_2, b_stride); + _tile_dpbf16ps(6, 0, 2); + _tile_stream_loadd(3, b_tile_3, b_stride); + _tile_dpbf16ps(7, 0, 3); + _tile_loadd(1, a_tile_1, a_tile_stride); + _tile_stream_loadd(4, b_tile_4, b_stride); + _tile_dpbf16ps(6, 1, 4); + _tile_stream_loadd(5, b_tile_5, b_stride); + _tile_dpbf16ps(7, 1, 5); + + // update ptrs + if constexpr (phase == AttentionGemmPhase::QK) { + // Q buffer is prepacked + a_tile_0 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + a_tile_1 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + } else if constexpr (phase == AttentionGemmPhase::PV) { + // P buffer is not prepacked + a_tile_0 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + a_tile_1 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + } + b_tile_2 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + b_tile_3 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + b_tile_4 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + b_tile_5 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + } + + if (has_tail) { + _tile_loadd(0, a_tile_0, a_tile_stride); + _tile_stream_loadd(2, b_tile_2, b_stride); + _tile_dpbf16ps(6, 0, 2); + _tile_stream_loadd(3, b_tile_3, b_stride); + _tile_dpbf16ps(7, 0, 3); + } + + _tile_stored(6, c_tile_6, c_stride); + _tile_stored(7, c_tile_7, c_stride); + } + + FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) { + config.rows[0] = m; + config.rows[1] = m; + config.rows[2] = AMX_TILE_ROW_NUM; + config.rows[3] = AMX_TILE_ROW_NUM; + config.rows[4] = AMX_TILE_ROW_NUM; + config.rows[5] = AMX_TILE_ROW_NUM; + config.rows[6] = m; + config.rows[7] = m; + _tile_loadconfig(&config); + } +}; +} // namespace + +template +class AttentionImpl { + public: + using query_t = scalar_t; + using q_buffer_t = scalar_t; + using kv_cache_t = scalar_t; + using logits_buffer_t = float; + using partial_output_buffer_t = float; + using prob_buffer_t = scalar_t; + + constexpr static int64_t BlockSizeAlignment = + AMX_TILE_ROW_BYTES / + sizeof(kv_cache_t); // KV token num unit of QK and PV phases + constexpr static int64_t HeadDimAlignment = + 2 * (AMX_TILE_ROW_BYTES / 4); // headdim num unit of PV phase + constexpr static int64_t MaxQHeadNumPerIteration = 32; + constexpr static int64_t HeadDim = head_dim; + constexpr static ISA ISAType = ISA::AMX; + constexpr static bool scale_on_logits = true; + + public: + AttentionImpl() : current_q_head_num_(0) { + // Use all columns in AMX tiles + vec_op::unroll_loop([&](int i) { amx_tile_config_.colsb[i] = 64; }); + } + + ~AttentionImpl() { _tile_release(); } + + template