From f5dfa0753163530117b4766c4e79e8cb2dc7066e Mon Sep 17 00:00:00 2001 From: noiji <52301388+noiji@users.noreply.github.com> Date: Mon, 30 Jun 2025 18:21:56 +0900 Subject: [PATCH 001/195] [Bugfix] Skip loading extra parameters for modelopt Qwen3 MoE model (#19598) Signed-off-by: noiji <> --- vllm/model_executor/models/qwen3_moe.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 417d7b22088bf..90a28192eccbc 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -386,6 +386,11 @@ class Qwen3MoeModel(nn.Module): ("gate_up_proj", "up_proj", 1), ] + # Skip loading extra parameters for GPTQ/modelopt models. + ignore_suffixes = (".bias", "_bias", ".k_scale", "_k_scale", + ".v_scale", "_v_scale", ".weight_scale", + "_weight_scale", ".input_scale", "_input_scale") + # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) expert_params_mapping = FusedMoE.make_expert_params_mapping( @@ -410,10 +415,11 @@ class Qwen3MoeModel(nn.Module): if "mlp.experts" in name: continue name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + + # Skip loading extra parameters for GPTQ/modelopt models. + if name.endswith(ignore_suffixes) and name not in params_dict: continue + # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue @@ -433,9 +439,9 @@ class Qwen3MoeModel(nn.Module): # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue - # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + # Skip loading extra parameters for GPTQ/modelopt models. 
+ if name.endswith( + ignore_suffixes) and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader @@ -446,9 +452,9 @@ class Qwen3MoeModel(nn.Module): expert_id=expert_id) break else: - # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + # Skip loading extra parameters for GPTQ/modelopt models. + if name.endswith( + ignore_suffixes) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): From e936e401debe7fba64d6462666d7dc632bc76357 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Mon, 30 Jun 2025 18:16:16 +0800 Subject: [PATCH 002/195] [Bugfix] Fix processor initialization in transformers 4.53.0 (#20244) Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/inputs/registry.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 66e78833f52af..fc6e190e54806 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -5,7 +5,9 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union import torch +from packaging.version import Version from transformers import BatchFeature, PretrainedConfig, ProcessorMixin +from transformers import __version__ as TRANSFORMERS_VERSION from typing_extensions import TypeVar from vllm.jsontree import JSONTree, json_map_leaves @@ -128,9 +130,13 @@ class InputProcessingContext(InputContext): /, **kwargs: object, ) -> _P: + # Transformers 4.53.0 has issue with passing tokenizer to + # initialize processor. We disable it for this version. 
+ # See: https://github.com/vllm-project/vllm/issues/20224 + if Version(TRANSFORMERS_VERSION) != Version("4.53.0"): + kwargs["tokenizer"] = self.tokenizer return super().get_hf_processor( typ, - tokenizer=self.tokenizer, **kwargs, ) From 8fe7fc863481a7d48c6f5bcc7bb40b2c7ebb5476 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Mon, 30 Jun 2025 18:22:09 +0800 Subject: [PATCH 003/195] [Quantization] Improve BitsAndBytesModelLoader (#20242) Signed-off-by: Jee Jee Li --- .../model_loader/bitsandbytes_loader.py | 123 ++++++++++-------- 1 file changed, 72 insertions(+), 51 deletions(-) diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py index 09857ef297f0a..0c46d170e88d5 100644 --- a/vllm/model_executor/model_loader/bitsandbytes_loader.py +++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py @@ -20,8 +20,6 @@ from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) # yapf: enable from vllm.logger import init_logger -# yapf conflicts with isort for this block -# yapf: disable from vllm.model_executor.layers.linear import (LinearBase, MergedColumnParallelLinear, QKVParallelLinear, @@ -39,6 +37,8 @@ from vllm.model_executor.utils import (get_packed_modules_mapping, set_weight_attrs) from vllm.platforms import current_platform +# yapf conflicts with isort for this block + logger = init_logger(__name__) @@ -54,11 +54,17 @@ class BitsAndBytesModelLoader(BaseModelLoader): self.unsharded_weights_modules: list[str] = [] # Save the module names that are sharded by column. self.column_sharded_weights_modules: list[str] = [] + # Modules whose weights might have fused on disk + # we need their output_sizes to make shard in flight correctly with TP + self.maybe_fused_weights_modules: dict[str, list[int]] = {} # Store all module names (from transformers) that support # BNB quantization. 
self.target_modules: list[str] = [] # mapping weight names from transformers to vllm. self.weight_mapper: Callable = lambda name: name + self.pre_quant: bool = False + self.load_8bit: bool = False + self.is_pool_model: bool = False def _get_weight_files( self, @@ -134,13 +140,14 @@ class BitsAndBytesModelLoader(BaseModelLoader): return hf_weights_files, use_safetensors def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): - def _maybe_pool_model(module_name:str): + + def _maybe_pool_model(module_name: str): # For pool model, we need to add the prefix `model.` # for the weight name if possible. if self.is_pool_model and self.target_modules[0]. \ startswith("model.") and not module_name.startswith( "model."): - return "model."+module_name + return "model." + module_name return module_name @@ -159,8 +166,7 @@ class BitsAndBytesModelLoader(BaseModelLoader): # mapping weight names from transformers to vllm while preserving # original names. mapped_name = self.weight_mapper(org_name) - mapped_name=_maybe_pool_model(mapped_name) - + mapped_name = _maybe_pool_model(mapped_name) yield org_name, mapped_name, param @@ -168,8 +174,6 @@ class BitsAndBytesModelLoader(BaseModelLoader): self, model_name_or_path: str, revision: Optional[str], - pre_quant: bool, - load_8bit: bool, ) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str, Any]]: """Get an iterator to the model weights with bitsandbytes quantization, @@ -192,8 +196,8 @@ class BitsAndBytesModelLoader(BaseModelLoader): quant_state_dict: dict[str, Any] = {} - if pre_quant: - if load_8bit: + if self.pre_quant: + if self.load_8bit: return self._quantized_8bit_generator( hf_weights_files, use_safetensors, quant_state_dict), quant_state_dict @@ -390,10 +394,13 @@ class BitsAndBytesModelLoader(BaseModelLoader): yield org_weight_name, processed_weight def _get_bnb_target_modules(self, model: nn.Module) -> None: - + """ + Identify and collect all modules that support BitsAndBytes + quantization. 
+ """ for name, module in model.named_modules(): - if (isinstance(module, LinearBase) and - hasattr(module.quant_method, "quant_config")): + if (isinstance(module, LinearBase) + and hasattr(module.quant_method, "quant_config")): if modules_info := self.modules_mapping.get_sub_modules(name): # Map vllm's names to transformers's names. rep_name, sub_modules = modules_info @@ -409,29 +416,11 @@ class BitsAndBytesModelLoader(BaseModelLoader): ), "vllm currently does not support BNB quantization for" f" {type(model).__name__}" - def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: - if not hasattr(model, "load_weights"): - raise AttributeError( - "The required method 'load_weights' is not defined in class" - f" {type(model).__name__}.") - - if not hasattr(model, "packed_modules_mapping"): - raise AttributeError( - f"Model {type(model).__name__} does not support BitsAndBytes " - "quantization yet. No 'packed_modules_mapping' found.") - self.is_pool_model=is_pooling_model(model) - - self.modules_mapping = ParamMapping(get_packed_modules_mapping(model)) - - # For some models like Molmo, we need to use hf_to_vllm_mapper - # to ensure correct loading of weights. - if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None): - self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name) - - # Modules whose weights might have fused on disk - # we need their output_sizes to make shard in flight correctly with TP - self.maybe_fused_weights_modules: dict[str, list[int]] = {} - self._get_bnb_target_modules(model) + def _classify_module_sharding(self, model: nn.Module): + """ + Categorize modules based on their weight sharding requirements + for tensor parallelism. + """ for name, module in model.named_modules(): # Some modules like `ReplicatedLinear` should not have their weights # sharded. 
The reason for implementing it this way is to avoid new @@ -449,19 +438,27 @@ class BitsAndBytesModelLoader(BaseModelLoader): elif isinstance(module, (RowParallelLinear, )): self.column_sharded_weights_modules.append(name) - self.model_type = type(model).__name__ + def _verify_model_compatibility(self, model: nn.Module, + model_config: ModelConfig) -> None: + """ + Verify that the model is compatible with BitsAndBytes quantization. + """ + if not hasattr(model, "load_weights"): + raise AttributeError( + "The required method 'load_weights' is not defined in class" + f" {type(model).__name__}.") - logger.info("Loading weights with BitsAndBytes quantization. " - "May take a while ...") + if not hasattr(model, "packed_modules_mapping"): + raise AttributeError( + f"Model {type(model).__name__} does not support BitsAndBytes " + "quantization yet. No 'packed_modules_mapping' found.") quant_config = getattr(model_config.hf_config, "quantization_config", None) - - pre_quant = False if quant_config is not None: quant_method = quant_config.get("quant_method") if quant_method == "bitsandbytes": - pre_quant = True + self.pre_quant = True else: raise ValueError( f"BitsAndBytes loader does not support {quant_method} " @@ -469,20 +466,43 @@ class BitsAndBytesModelLoader(BaseModelLoader): # The quant_states in pre_quantized models cannot work with a split # weight tensor. So TP does not work with pre_quantized bnb models. - if pre_quant and get_tensor_model_parallel_world_size() > 1: + if self.pre_quant and get_tensor_model_parallel_world_size() > 1: raise ValueError( "Prequant BitsAndBytes models with tensor parallelism is not " "supported. 
Please try with pipeline parallelism.") + if self.pre_quant: + self.load_8bit = quant_config.get("load_in_8bit", False) - load_8bit = False - if pre_quant: - load_8bit = quant_config.get("load_in_8bit", False) + def _initialize_loader_state(self, model: nn.Module, + model_config: ModelConfig) -> None: + """ + Initialize the loader's internal state based on the model and + configuration. + """ + self.is_pool_model = is_pooling_model(model) + self.modules_mapping = ParamMapping(get_packed_modules_mapping(model)) + # For some models like Molmo, we need to use hf_to_vllm_mapper + # to ensure correct loading of weights. + if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None): + self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name) + + self._get_bnb_target_modules(model) + self._classify_module_sharding(model) + + def load_weights(self, model: nn.Module, + model_config: ModelConfig) -> None: + + self._verify_model_compatibility(model, model_config) + self._initialize_loader_state(model, model_config) + + logger.info("Loading weights with BitsAndBytes quantization. " + "May take a while ...") qweight_iterator, quant_state_dict = ( - self._get_quantized_weights_iterator(model_config.model, - model_config.revision, - pre_quant, load_8bit)) - + self._get_quantized_weights_iterator( + model_config.model, + model_config.revision, + )) weights_to_load = {name for name, _ in model.named_parameters()} loaded_weights = model.load_weights(qweight_iterator) # Some models may have weights loading tracker unimplemented. 
@@ -562,10 +582,11 @@ class BitsAndBytesModelLoader(BaseModelLoader): offsets = torch.tensor(offsets).cpu() set_weight_attrs(param, {"bnb_shard_offsets": offsets}) - if load_8bit: + if self.load_8bit: set_weight_attrs( param, {"matmul_state": [None] * len(quant_states)}) torch.cuda.empty_cache() + def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model, model_config.revision) From 3ee56e26be4cfddc17f7d2e5f38f15ab74ede1c2 Mon Sep 17 00:00:00 2001 From: Michael Yao Date: Mon, 30 Jun 2025 19:20:51 +0800 Subject: [PATCH 004/195] [Docs] Fix 1-2-3 list in v1/prefix_caching.md (#20243) Signed-off-by: windsonsea --- docs/design/v1/prefix_caching.md | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/design/v1/prefix_caching.md b/docs/design/v1/prefix_caching.md index e87e4c6a48b73..2d3c8412894a6 100644 --- a/docs/design/v1/prefix_caching.md +++ b/docs/design/v1/prefix_caching.md @@ -117,8 +117,8 @@ There are two design points to highlight: 1. We allocate all KVCacheBlock when initializing the KV cache manager to be a block pool. This avoids Python object creation overheads and can easily track all blocks all the time. 2. We introduce doubly linked list pointers directly in the KVCacheBlock, so that we could construct a free queue directly. This gives us two benefits: - 1. We could have O(1) complexity moving elements in the middle to the tail. - 2. We could avoid introducing another Python queue (e.g., `deque`) which has a wrapper to the elements. + 1. We could have O(1) complexity moving elements in the middle to the tail. + 2. We could avoid introducing another Python queue (e.g., `deque`) which has a wrapper to the elements. 
As a result, we will have the following components when the KV cache manager is initialized: @@ -135,19 +135,19 @@ As a result, we will have the following components when the KV cache manager is **New request:** Workflow for the scheduler to schedule a new request with KV cache block allocation: -1. The scheduler calls `kv_cache_manager.get_computed_blocks()` to get a sequence of blocks that have already been computed. This is done by hashing the prompt tokens in the request and looking up Cache Blocks. +1. The scheduler calls `kv_cache_manager.get_computed_blocks()` to get a sequence of blocks that have already been computed. This is done by hashing the prompt tokens in the request and looking up cache blocks. 2. The scheduler calls `kv_cache_manager.allocate_slots()`. It does the following steps: - 1. Compute the number of new required blocks, and return if there are no sufficient blocks to allocate. - 2. “Touch” the computed blocks. It increases the reference count of the computed block by one, and removes the block from the free queue if the block wasn’t used by other requests. This is to avoid these computed blocks being evicted. See the example in the next section for illustration. - 3. Allocate new blocks by popping the heads of the free queue. If the head block is a cached block, this also “evicts” the block so that no other requests can reuse it anymore from now on. - 4. If an allocated block is already full of tokens, we immediately add it to the Cache Block, so that the block can be reused by other requests in the same batch. + 1. Compute the number of new required blocks, and return if there are no sufficient blocks to allocate. + 2. “Touch” the computed blocks. It increases the reference count of the computed block by one, and removes the block from the free queue if the block wasn’t used by other requests. This is to avoid these computed blocks being evicted. See the example in the next section for illustration. + 3. 
Allocate new blocks by popping the heads of the free queue. If the head block is a cached block, this also “evicts” the block so that no other requests can reuse it anymore from now on. + 4. If an allocated block is already full of tokens, we immediately add it to the cache block, so that the block can be reused by other requests in the same batch. **Running request:** Workflow for the scheduler to schedule a running request with KV cache block allocation: 1. The scheduler calls `kv_cache_manager.allocate_slots()`. It does the following steps: - 1. Compute the number of new required blocks, and return if there are no sufficient blocks to allocate. - 2. Allocate new blocks by popping the heads of the free queue. If the head block is a cached block, this also “evicts” the block so that no other requests can reuse it anymore from now on. - 3. Append token IDs to the slots in existing blocks as well as the new blocks. If a block is full, we add it to the Cache Block to cache it. + 1. Compute the number of new required blocks, and return if there are no sufficient blocks to allocate. + 2. Allocate new blocks by popping the heads of the free queue. If the head block is a cached block, this also “evicts” the block so that no other requests can reuse it anymore from now on. + 3. Append token IDs to the slots in existing blocks as well as the new blocks. If a block is full, we add it to the cache block to cache it. **Duplicated blocks** Assuming block size is 4 and you send a request (Request 1\) with prompt ABCDEF and decoding length 3: @@ -199,7 +199,7 @@ When a request is finished, we free all its blocks if no other requests are usin When the head block (least recently used block) of the free queue is cached, we have to evict the block to prevent it from being used by other requests. Specifically, eviction involves the following steps: 1. Pop the block from the head of the free queue. This is the LRU block to be evicted. -2. Remove the block ID from the Cache Block. +2. 
Remove the block ID from the cache block. 3. Remove the block hash. ## Example From 1c50e100a9c5dc439aceb9c4437b262d564baa53 Mon Sep 17 00:00:00 2001 From: li haoyang Date: Mon, 30 Jun 2025 21:24:50 +0800 Subject: [PATCH 005/195] [Bugfix] fix quark ptpc (#20251) Signed-off-by: Haoyang Li Co-authored-by: Haoyang Li <307790822@qq.com> --- .../layers/quantization/quark/quark.py | 6 +--- .../quark/schemes/quark_w8a8_fp8.py | 33 ++++++++++++------- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index 6ae5f5c9ad46b..05dff4bae3957 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -312,11 +312,7 @@ class QuarkConfig(QuantizationConfig): is_fp8_w8a8_supported = self._check_scheme_supported( QuarkW8A8Fp8.get_min_capability(), error=False) if is_fp8_w8a8_supported: - weight_qscheme = cast(str, weight_config.get("qscheme")) - input_static = (input_config is not None and - not cast(bool, input_config.get("is_dynamic"))) - return QuarkW8A8Fp8(qscheme=weight_qscheme, - is_static_input_scheme=input_static) + return QuarkW8A8Fp8(weight_config, input_config) elif self._is_static_tensor_w8a8(weight_config, input_config): weight_qscheme = cast(str, weight_config.get("qscheme")) return QuarkW8A8Int8(qscheme=weight_qscheme, diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index 47e0a492b23b9..c7bc98184d0eb 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, Optional +from typing import Any, Callable, Optional, 
cast import torch from torch.nn import Parameter @@ -19,10 +19,19 @@ __all__ = ["QuarkW8A8Fp8"] class QuarkW8A8Fp8(QuarkScheme): - def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]): - self.qscheme = qscheme - self.is_static_input_scheme = is_static_input_scheme - self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False) + def __init__(self, weight_config: dict[str, Any], + input_config: Optional[dict[str, Any]]): + self.weight_qscheme = cast(str, weight_config.get("qscheme")) + self.is_static_input_scheme: bool = False + self.input_qscheme: Optional[str] = None + if input_config is not None: + self.is_static_input_scheme = not cast( + bool, input_config.get("is_dynamic")) + self.input_qscheme = cast(str, input_config.get("qscheme")) + self.use_per_token_if_dynamic = (not self.is_static_input_scheme \ + and self.input_qscheme == "per_channel") + self.fp8_linear = Fp8LinearOp( + use_per_token_if_dynamic=self.use_per_token_if_dynamic) self.out_dtype = torch.get_default_dtype() @classmethod @@ -34,7 +43,7 @@ class QuarkW8A8Fp8(QuarkScheme): # If per tensor, when we have a fused module (e.g. QKV) with per # tensor scales (thus N scales being passed to the kernel), # requantize so we can always run per tensor - if self.qscheme == "per_tensor": + if self.weight_qscheme == "per_tensor": if current_platform.is_rocm(): input_scale = getattr(layer, 'input_scale', None) weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( @@ -58,7 +67,7 @@ class QuarkW8A8Fp8(QuarkScheme): layer.weight_scale = Parameter(max_w_scale, requires_grad=False) # If channelwise, scales are already lined up, so just transpose. 
- elif self.qscheme == "per_channel": + elif self.weight_qscheme == "per_channel": weight = layer.weight if current_platform.is_fp8_fnuz(): @@ -73,13 +82,15 @@ class QuarkW8A8Fp8(QuarkScheme): requires_grad=False) else: weight_scale = layer.weight_scale.data - + if self.use_per_token_if_dynamic: + weight_scale = weight_scale.view(-1, 1) layer.weight = Parameter(weight.t(), requires_grad=False) # required by torch.compile to be torch.nn.Parameter layer.weight_scale = Parameter(weight_scale, requires_grad=False) else: - raise ValueError(f"Unknown quantization scheme {self.qscheme}") + raise ValueError( + f"Unknown quantization scheme {self.weight_qscheme}") # INPUT SCALE if self.is_static_input_scheme: @@ -109,14 +120,14 @@ class QuarkW8A8Fp8(QuarkScheme): # WEIGHT SCALE # TODO: update create_xxx_parameter functions to return # the newly added parameters - if self.qscheme == "per_channel": + if self.weight_qscheme == "per_channel": weight_scale = ChannelQuantScaleParameter( data=torch.empty((sum(output_partition_sizes)), dtype=torch.float32), output_dim=0, weight_loader=weight_loader) else: - assert self.qscheme == "per_tensor" + assert self.weight_qscheme == "per_tensor" weight_scale = PerTensorScaleParameter(data=torch.empty( len(output_partition_sizes), dtype=torch.float32), weight_loader=weight_loader) From 2062c0723d38a8f4a8a7565b61a99e8c81b5cacd Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 30 Jun 2025 08:13:50 -0700 Subject: [PATCH 006/195] [Spec Decode] Refactor spec decoding into a separate function (#20238) Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_model_runner.py | 93 +++++++++++++++++++----------- 1 file changed, 60 insertions(+), 33 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 290b9a44a80e2..e063e44dabfa1 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1388,6 +1388,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): hidden_states, 
aux_hidden_states = model_output else: hidden_states = model_output + aux_hidden_states = None + # Broadcast PP output for external_launcher (torchrun) # to make sure we are synced across pp ranks # TODO: Support overlapping mirco-batches @@ -1510,25 +1512,67 @@ class GPUModelRunner(LoRAModelRunnerMixin): if not self.speculative_config: # Speculative decoding is not enabled. spec_token_ids = None - elif self.speculative_config.method == "ngram": + else: + spec_token_ids = self.propose_draft_token_ids( + scheduler_output, + valid_sampled_token_ids, + sampling_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + spec_decode_metadata, + attn_metadata, + ) + + # Clear KVConnector state after all KVs are generated. + if has_kv_transfer_group(): + get_kv_transfer_group().clear_connector_metadata() + + self.eplb_step() + + return ModelRunnerOutput( + req_ids=self.input_batch.req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=valid_sampled_token_ids, + spec_token_ids=spec_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=[], + finished_sending=finished_sending, + finished_recving=finished_recving, + num_nans_in_logits=num_nans_in_logits, + ) + + def propose_draft_token_ids( + self, + scheduler_output: "SchedulerOutput", + sampled_token_ids: list[list[int]], + sampling_metadata: SamplingMetadata, + hidden_states: torch.Tensor, + sample_hidden_states: torch.Tensor, + aux_hidden_states: Optional[torch.Tensor], + spec_decode_metadata: Optional[SpecDecodeMetadata], + attn_metadata: dict[str, Any], + ) -> list[list[int]]: + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + if self.speculative_config.method == "ngram": assert isinstance(self.drafter, NgramProposer) - spec_token_ids = self.generate_draft_token_ids( - valid_sampled_token_ids, sampling_metadata) + spec_token_ids = self.propose_ngram_draft_token_ids( + sampled_token_ids) elif 
self.speculative_config.method == "medusa": assert isinstance(self.drafter, MedusaProposer) - if max_gen_len == 1: + if sample_hidden_states.shape[0] == len(sampled_token_ids): + # The input to the target model does not include draft tokens. hidden_states = sample_hidden_states else: indices = [] offset = 0 for num_draft, tokens in zip( spec_decode_metadata.num_draft_tokens, - valid_sampled_token_ids): + sampled_token_ids): indices.append(offset + len(tokens) - 1) offset += num_draft + 1 - - indices = torch.tensor(indices, - device=sample_hidden_states.device) + indices = torch.tensor(indices, device=self.device) hidden_states = sample_hidden_states[indices] spec_token_ids = self.drafter.propose( @@ -1539,7 +1583,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): assert isinstance(self.drafter, EagleProposer) # TODO(woosuk): Refactor the loop. next_token_ids: list[int] = [] - for i, token_ids in enumerate(valid_sampled_token_ids): + for i, token_ids in enumerate(sampled_token_ids): if token_ids: # Common case. next_token_id = token_ids[-1] @@ -1569,7 +1613,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): if spec_decode_metadata is None: # input_ids can be None for multimodal models. target_token_ids = self.input_ids[:num_scheduled_tokens] - target_positions = positions[:num_scheduled_tokens] + # TODO(woosuk): Support M-RoPE. + target_positions = self.positions[:num_scheduled_tokens] if self.use_aux_hidden_state_outputs: target_hidden_states = torch.cat( [h[:num_scheduled_tokens] for h in aux_hidden_states], @@ -1582,7 +1627,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): # TODO(woosuk): Refactor this. 
num_draft_tokens = spec_decode_metadata.num_draft_tokens num_rejected_tokens = [ - n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 + n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 for i, n in enumerate(num_draft_tokens) ] num_rejected_tokens_tensor = async_tensor_h2d( @@ -1597,7 +1642,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): num_tokens, ) target_token_ids = self.input_ids[token_indices] - target_positions = positions[token_indices] + # TODO(woosuk): Support M-RoPE. + target_positions = self.positions[token_indices] if self.use_aux_hidden_state_outputs: target_hidden_states = torch.cat( [h[token_indices] for h in aux_hidden_states], dim=-1) @@ -1616,25 +1662,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): sampling_metadata=sampling_metadata, ) spec_token_ids = draft_token_ids.tolist() - - # Clear KVConnector state after all KVs are generated. - if has_kv_transfer_group(): - get_kv_transfer_group().clear_connector_metadata() - - self.eplb_step() - - return ModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=valid_sampled_token_ids, - spec_token_ids=spec_token_ids, - logprobs=logprobs_lists, - prompt_logprobs_dict=prompt_logprobs_dict, - pooler_output=[], - finished_sending=finished_sending, - finished_recving=finished_recving, - num_nans_in_logits=num_nans_in_logits, - ) + return spec_token_ids def kv_connector_no_forward( self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput: @@ -1682,10 +1710,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): scheduler_output.finished_req_ids) return None, None - def generate_draft_token_ids( + def propose_ngram_draft_token_ids( self, sampled_token_ids: list[list[int]], - sampling_metadata: SamplingMetadata, ) -> list[list[int]]: # TODO(woosuk): Optimize. 
draft_token_ids: list[list[int]] = [] From 2965c99c86b460ee819e4805764d769c7b7d3d8e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 30 Jun 2025 08:28:13 -0700 Subject: [PATCH 007/195] [Spec Decode] Clean up spec decode example (#20240) Signed-off-by: Woosuk Kwon --- examples/offline_inference/eagle.py | 144 ---------------------- examples/offline_inference/spec_decode.py | 40 +++--- 2 files changed, 21 insertions(+), 163 deletions(-) delete mode 100644 examples/offline_inference/eagle.py diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py deleted file mode 100644 index f4193fdb8bd38..0000000000000 --- a/examples/offline_inference/eagle.py +++ /dev/null @@ -1,144 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import argparse -import json -import os - -from transformers import AutoTokenizer - -from vllm import LLM, SamplingParams -from vllm.v1.metrics.reader import Counter, Vector - - -def load_prompts(dataset_path, num_prompts): - if os.path.exists(dataset_path): - prompts = [] - try: - with open(dataset_path) as f: - for line in f: - data = json.loads(line) - prompts.append(data["turns"][0]) - except Exception as e: - print(f"Error reading dataset: {e}") - return [] - else: - prompts = ["The future of AI is", "The president of the United States is"] - - return prompts[:num_prompts] - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--dataset", - type=str, - default="./examples/data/gsm8k.jsonl", - help="downloaded from the eagle repo " - "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/", - ) - parser.add_argument( - "--method", type=str, default="eagle", choices=["eagle", "eagle3"] - ) - parser.add_argument("--max_num_seqs", type=int, default=8) - parser.add_argument("--num_prompts", type=int, default=80) - parser.add_argument("--num_spec_tokens", type=int, default=2) - parser.add_argument("--tp", type=int, 
default=1) - parser.add_argument("--draft_tp", type=int, default=1) - parser.add_argument("--enforce_eager", action="store_true") - parser.add_argument("--enable_chunked_prefill", action="store_true") - parser.add_argument("--max_num_batched_tokens", type=int, default=2048) - parser.add_argument("--temp", type=float, default=0) - return parser.parse_args() - - -def main(): - args = parse_args() - - model_dir = "meta-llama/Llama-3.1-8B-Instruct" - - if args.method == "eagle": - eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" - elif args.method == "eagle3": - eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" - else: - raise ValueError(f"unknown method: {args.method}") - - max_model_len = 2048 - - tokenizer = AutoTokenizer.from_pretrained(model_dir) - - prompts = load_prompts(args.dataset, args.num_prompts) - - prompt_ids = [ - tokenizer.apply_chat_template( - [{"role": "user", "content": prompt}], add_generation_prompt=True - ) - for prompt in prompts - ] - - llm = LLM( - model=model_dir, - trust_remote_code=True, - tensor_parallel_size=args.tp, - enable_chunked_prefill=args.enable_chunked_prefill, - max_num_batched_tokens=args.max_num_batched_tokens, - enforce_eager=args.enforce_eager, - max_model_len=max_model_len, - max_num_seqs=args.max_num_seqs, - gpu_memory_utilization=0.8, - speculative_config={ - "method": args.method, - "model": eagle_dir, - "num_speculative_tokens": args.num_spec_tokens, - "draft_tensor_parallel_size": args.draft_tp, - "max_model_len": max_model_len, - }, - disable_log_stats=False, - ) - - sampling_params = SamplingParams(temperature=args.temp, max_tokens=256) - - outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params) - - # print the generated text - for output in outputs: - print("-" * 50) - print(f"prompt: {output.prompt}") - print(f"generated text: {output.outputs[0].text}") - print("-" * 50) - - try: - metrics = llm.get_metrics() - except AssertionError: - print("Metrics are not supported in the V0 
engine.") - return - - num_drafts = num_accepted = 0 - acceptance_counts = [0] * args.num_spec_tokens - for metric in metrics: - if metric.name == "vllm:spec_decode_num_drafts": - assert isinstance(metric, Counter) - num_drafts += metric.value - elif metric.name == "vllm:spec_decode_num_accepted_tokens": - assert isinstance(metric, Counter) - num_accepted += metric.value - elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos": - assert isinstance(metric, Vector) - for pos in range(len(metric.values)): - acceptance_counts[pos] += metric.values[pos] - - print("-" * 50) - print(f"mean acceptance length: {1 + (num_accepted / num_drafts):.2f}") - print("-" * 50) - - # print acceptance at each token position - for i in range(len(acceptance_counts)): - print(f"acceptance at token {i}:{acceptance_counts[i] / num_drafts:.2f}") - - -if __name__ == "__main__": - print( - "[WARNING] Use examples/offline_inference/spec_decode.py" - " instead of this script." - ) - main() diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 6fa68d2ecee1d..90d103e5cb05d 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -16,24 +16,17 @@ def parse_args(): parser = FlexibleArgumentParser() add_dataset_parser(parser) parser.add_argument( - "--dataset", + "--method", type=str, - default="./examples/data/gsm8k.jsonl", - help="downloaded from the eagle repo " - "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/", + default="eagle", + choices=["ngram", "eagle", "eagle3", "mtp"], ) - parser.add_argument( - "--method", type=str, default="eagle", choices=["ngram", "eagle", "eagle3"] - ) - parser.add_argument("--max-num-seqs", type=int, default=8) parser.add_argument("--num-spec-tokens", type=int, default=2) parser.add_argument("--prompt-lookup-max", type=int, default=5) parser.add_argument("--prompt-lookup-min", type=int, default=2) parser.add_argument("--tp", type=int, default=1) 
- parser.add_argument("--draft-tp", type=int, default=1) parser.add_argument("--enforce-eager", action="store_true") parser.add_argument("--enable-chunked-prefill", action="store_true") - parser.add_argument("--max-num-batched-tokens", type=int, default=2048) parser.add_argument("--temp", type=float, default=0) parser.add_argument("--top-p", type=float, default=1.0) parser.add_argument("--top-k", type=int, default=-1) @@ -41,7 +34,6 @@ def parse_args(): parser.add_argument("--output-len", type=int, default=256) parser.add_argument("--model-dir", type=str, default=None) parser.add_argument("--eagle-dir", type=str, default=None) - parser.add_argument("--max-model-len", type=int, default=2048) return parser.parse_args() @@ -71,8 +63,6 @@ def main(): "method": args.method, "model": eagle_dir, "num_speculative_tokens": args.num_spec_tokens, - "draft_tensor_parallel_size": args.draft_tp, - "max_model_len": args.max_model_len, } elif args.method == "ngram": speculative_config = { @@ -80,7 +70,6 @@ def main(): "num_speculative_tokens": args.num_spec_tokens, "prompt_lookup_max": args.prompt_lookup_max, "prompt_lookup_min": args.prompt_lookup_min, - "max_model_len": args.max_model_len, } else: raise ValueError(f"unknown method: {args.method}") @@ -92,7 +81,6 @@ def main(): enable_chunked_prefill=args.enable_chunked_prefill, max_num_batched_tokens=args.max_num_batched_tokens, enforce_eager=args.enforce_eager, - max_model_len=args.max_model_len, max_num_seqs=args.max_num_seqs, gpu_memory_utilization=0.8, speculative_config=speculative_config, @@ -116,27 +104,41 @@ def main(): print("Metrics are not supported in the V0 engine.") return - num_drafts = num_accepted = 0 + total_num_output_tokens = sum( + len(output.outputs[0].token_ids) for output in outputs + ) + num_drafts = 0 + num_draft_tokens = 0 + num_accepted_tokens = 0 acceptance_counts = [0] * args.num_spec_tokens for metric in metrics: if metric.name == "vllm:spec_decode_num_drafts": assert isinstance(metric, Counter) 
num_drafts += metric.value + elif metric.name == "vllm:spec_decode_num_draft_tokens": + assert isinstance(metric, Counter) + num_draft_tokens += metric.value elif metric.name == "vllm:spec_decode_num_accepted_tokens": assert isinstance(metric, Counter) - num_accepted += metric.value + num_accepted_tokens += metric.value elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos": assert isinstance(metric, Vector) for pos in range(len(metric.values)): acceptance_counts[pos] += metric.values[pos] print("-" * 50) - print(f"mean acceptance length: {1 + (num_accepted / num_drafts):.2f}") + print(f"total_num_output_tokens: {total_num_output_tokens}") + print(f"num_drafts: {num_drafts}") + print(f"num_draft_tokens: {num_draft_tokens}") + print(f"num_accepted_tokens: {num_accepted_tokens}") + acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1 + print(f"mean acceptance length: {acceptance_length:.2f}") print("-" * 50) # print acceptance at each token position for i in range(len(acceptance_counts)): - print(f"acceptance at token {i}:{acceptance_counts[i] / num_drafts:.2f}") + acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0 + print(f"acceptance at token {i}: {acceptance_rate:.2f}") if __name__ == "__main__": From 2863befce359ee1a82afe02d1953252866aa3e96 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 30 Jun 2025 09:07:50 -0700 Subject: [PATCH 008/195] [Optimization] Use Shared `CachedRequestData` Instance Across All Requests (#20232) Signed-off-by: Woosuk Kwon --- tests/v1/core/test_scheduler.py | 130 +++++++++--------- .../unit/test_remote_decode_lifecycle.py | 4 +- .../unit/test_remote_prefill_lifecycle.py | 12 +- tests/v1/kv_connector/unit/utils.py | 1 - tests/v1/tpu/worker/test_tpu_model_runner.py | 22 +-- tests/v1/worker/test_gpu_model_runner.py | 22 +-- .../kv_connector/v1/p2p/p2p_nccl_connector.py | 43 +++--- .../v1/shared_storage_connector.py | 19 ++- vllm/v1/core/sched/output.py | 34 +++-- 
vllm/v1/core/sched/scheduler.py | 106 ++++++-------- vllm/v1/worker/gpu_model_runner.py | 34 ++--- vllm/v1/worker/tpu_model_runner.py | 24 ++-- 12 files changed, 220 insertions(+), 231 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 8994816a3017c..652a556659fe3 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig) from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams -from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec) @@ -198,7 +198,7 @@ def test_schedule(enable_prefix_caching: Optional[bool], # Test initial scheduling output = scheduler.schedule() assert len(output.scheduled_new_reqs) == len(requests) - assert len(output.scheduled_cached_reqs) == 0 + assert output.scheduled_cached_reqs.num_reqs == 0 assert len(output.finished_req_ids) == 0 # Verify all requests are scheduled. 
for req_id, num_tokens in output.num_scheduled_tokens.items(): @@ -225,7 +225,7 @@ def test_schedule_multimodal_requests(): output = scheduler.schedule() assert len(output.scheduled_new_reqs) == len(requests) - assert len(output.scheduled_cached_reqs) == 0 + assert output.scheduled_cached_reqs.num_reqs == 0 assert len(output.finished_req_ids) == 0 for req_id, num_tokens in output.num_scheduled_tokens.items(): assert num_tokens == len(requests[int(req_id)].prompt_token_ids) @@ -259,7 +259,7 @@ def test_schedule_partial_requests(): output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 3 - assert len(output.scheduled_cached_reqs) == 0 + assert output.scheduled_cached_reqs.num_reqs == 0 assert len(output.finished_req_ids) == 0 assert scheduler.max_num_encoder_input_tokens == 1024 @@ -295,7 +295,7 @@ def test_schedule_partial_requests(): output = scheduler.schedule() assert len(scheduler.running) == 3 assert len(output.scheduled_new_reqs) == 0 - assert len(output.scheduled_cached_reqs) == 2 + assert output.scheduled_cached_reqs.num_reqs == 2 assert len(output.finished_req_ids) == 0 assert output.num_scheduled_tokens[requests[0].request_id] == 1 assert output.num_scheduled_tokens[requests[1].request_id] == 700 @@ -319,7 +319,7 @@ def test_no_mm_input_chunking(): output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 1 - assert len(output.scheduled_cached_reqs) == 0 + assert output.scheduled_cached_reqs.num_reqs == 0 assert len(output.finished_req_ids) == 0 # We want to only see the 400 text tokens at the start scheduled assert output.num_scheduled_tokens[requests[0].request_id] == 400 @@ -342,7 +342,7 @@ def test_no_mm_input_chunking(): output = scheduler.schedule() assert len(scheduler.running) == 1 assert len(output.scheduled_new_reqs) == 0 - assert len(output.scheduled_cached_reqs) == 1 + assert output.scheduled_cached_reqs.num_reqs == 1 assert len(output.finished_req_ids) == 0 assert output.num_scheduled_tokens[requests[0].request_id] 
== 800 @@ -379,7 +379,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 3 - assert len(output.scheduled_cached_reqs) == 0 + assert output.scheduled_cached_reqs.num_reqs == 0 assert len(output.finished_req_ids) == 0 # The first request is scheduled partially - 400. @@ -408,7 +408,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): output1 = scheduler.schedule() assert len(scheduler.running) == 3 assert len(output1.scheduled_new_reqs) == 0 - assert len(output1.scheduled_cached_reqs) == 3 + assert output1.scheduled_cached_reqs.num_reqs == 3 assert len(output1.finished_req_ids) == 0 assert output1.num_scheduled_tokens[requests[0].request_id] == 400 assert output1.num_scheduled_tokens[requests[1].request_id] == 400 @@ -430,7 +430,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): output2 = scheduler.schedule() assert len(scheduler.running) == 3 assert len(output2.scheduled_new_reqs) == 0 - assert len(output2.scheduled_cached_reqs) == 3 + assert output2.scheduled_cached_reqs.num_reqs == 3 assert len(output2.finished_req_ids) == 0 assert output2.num_scheduled_tokens[requests[0].request_id] == 1 assert output2.num_scheduled_tokens[requests[1].request_id] == 1 @@ -449,23 +449,24 @@ def test_stop_via_update_from_output(): scheduler.requests[req.request_id] = req scheduler.running.append(req) - scheduler_output = SchedulerOutput(scheduled_new_reqs=[], - scheduled_cached_reqs=[], - num_scheduled_tokens={ - requests[0].request_id: 1, - requests[1].request_id: 2 - }, - total_num_scheduled_tokens=3, - scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={ - requests[0].request_id: [], - requests[1].request_id: [10] - }, - num_common_prefix_blocks=0, - finished_req_ids=set(), - free_encoder_input_ids=[], - structured_output_request_ids={}, - grammar_bitmask=None) + scheduler_output = SchedulerOutput( + 
scheduled_new_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), + num_scheduled_tokens={ + requests[0].request_id: 1, + requests[1].request_id: 2 + }, + total_num_scheduled_tokens=3, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={ + requests[0].request_id: [], + requests[1].request_id: [10] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None) model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], @@ -501,23 +502,25 @@ def test_stop_via_update_from_output(): scheduler.requests[req.request_id] = req scheduler.running.append(req) - scheduler_output = SchedulerOutput(scheduled_new_reqs=[], - scheduled_cached_reqs=[], - num_scheduled_tokens={ - requests[0].request_id: 3, - requests[1].request_id: 2 - }, - total_num_scheduled_tokens=5, - scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={ - requests[0].request_id: [10, 42], - requests[1].request_id: [13] - }, - num_common_prefix_blocks=0, - finished_req_ids=set(), - free_encoder_input_ids=[], - structured_output_request_ids={}, - grammar_bitmask=None) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), + num_scheduled_tokens={ + requests[0].request_id: 3, + requests[1].request_id: 2 + }, + total_num_scheduled_tokens=5, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={ + requests[0].request_id: [10, 42], + requests[1].request_id: [13] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None, + ) model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], @@ -551,23 +554,25 @@ def test_stop_via_update_from_output(): scheduler.requests[req.request_id] = req scheduler.running.append(req) - scheduler_output = SchedulerOutput(scheduled_new_reqs=[], - scheduled_cached_reqs=[], - 
num_scheduled_tokens={ - requests[0].request_id: 3, - requests[1].request_id: 1 - }, - total_num_scheduled_tokens=4, - scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={ - requests[0].request_id: [10, 11], - requests[1].request_id: [] - }, - num_common_prefix_blocks=0, - finished_req_ids=set(), - free_encoder_input_ids=[], - structured_output_request_ids={}, - grammar_bitmask=None) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), + num_scheduled_tokens={ + requests[0].request_id: 3, + requests[1].request_id: 1 + }, + total_num_scheduled_tokens=4, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={ + requests[0].request_id: [10, 11], + requests[1].request_id: [] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None, + ) model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], @@ -603,7 +608,7 @@ def test_stop_via_update_from_output(): scheduler_output = SchedulerOutput( scheduled_new_reqs=[], - scheduled_cached_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens={requests[0].request_id: 3}, total_num_scheduled_tokens=3, scheduled_encoder_inputs={}, @@ -1208,7 +1213,6 @@ def assert_scheduler_empty(scheduler: Scheduler): assert len(scheduler.waiting) == 0 assert len(scheduler.running) == 0 assert len(scheduler.finished_req_ids) == 0 - assert len(scheduler._cached_reqs_data) == 0 # EncoderCacheManager. 
assert len(scheduler.encoder_cache_manager.freed) == 0 diff --git a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py index ff36a281c413d..12a71d97e8d29 100644 --- a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py @@ -66,7 +66,7 @@ def test_basic_lifecycle(): assert len(scheduler_output.finished_req_ids) == 1 assert request_id in scheduler_output.finished_req_ids assert len(scheduler_output.scheduled_new_reqs) == 0 - assert len(scheduler_output.scheduled_cached_reqs) == 0 + assert scheduler_output.scheduled_cached_reqs.num_reqs == 0 assert len(scheduler.finished_req_ids) == 0 # (2b): execute_model() @@ -81,7 +81,7 @@ def test_basic_lifecycle(): assert len(scheduler.running) == 0 assert len(scheduler_output.finished_req_ids) == 0 assert len(scheduler_output.scheduled_new_reqs) == 0 - assert len(scheduler_output.scheduled_cached_reqs) == 0 + assert scheduler_output.scheduled_cached_reqs.num_reqs == 0 assert len(scheduler.finished_req_ids) == 0 # (3b): execute_model() diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py index a1156306dc4bf..f89970bf2c807 100644 --- a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py @@ -36,7 +36,7 @@ def test_basic_lifecycle(): # Nothing running and empty scheduler output. 
assert len(scheduler.running) == 0 assert len(scheduler_output.scheduled_new_reqs) == 0 - assert len(scheduler_output.scheduled_cached_reqs) == 0 + assert scheduler_output.scheduled_cached_reqs.num_reqs == 0 assert len(scheduler_output.num_scheduled_tokens) == 0 assert scheduler_output.total_num_scheduled_tokens == 0 @@ -158,7 +158,7 @@ def test_interleaved_lifecycle(): assert len(scheduler.running) == 2 assert len(scheduler.waiting) == 1 assert len(scheduler_output.scheduled_new_reqs) == 1 - assert len(scheduler_output.scheduled_cached_reqs) == 1 + assert scheduler_output.scheduled_cached_reqs.num_reqs == 1 model_runner_output = create_model_runner_output( [request_local_a, request_local_b]) @@ -169,7 +169,7 @@ def test_interleaved_lifecycle(): assert len(scheduler.running) == 2 assert len(scheduler.waiting) == 1 assert len(scheduler_output.scheduled_new_reqs) == 0 - assert len(scheduler_output.scheduled_cached_reqs) == 2 + assert scheduler_output.scheduled_cached_reqs.num_reqs == 2 model_runner_output = create_model_runner_output( reqs=[request_local_a, request_local_b]) @@ -177,14 +177,14 @@ def test_interleaved_lifecycle(): assert len(scheduler.running) == 2 assert len(scheduler.waiting) == 1 assert len(scheduler_output.scheduled_new_reqs) == 0 - assert len(scheduler_output.scheduled_cached_reqs) == 2 + assert scheduler_output.scheduled_cached_reqs.num_reqs == 2 # STEP 4: KVs arrive. 
scheduler_output = scheduler.schedule() assert len(scheduler.running) == 2 assert len(scheduler.waiting) == 1 assert len(scheduler_output.scheduled_new_reqs) == 0 - assert len(scheduler_output.scheduled_cached_reqs) == 2 + assert scheduler_output.scheduled_cached_reqs.num_reqs == 2 model_runner_output = create_model_runner_output( [request_local_a, request_local_b], @@ -196,7 +196,7 @@ def test_interleaved_lifecycle(): assert len(scheduler.running) == 3 assert len(scheduler.waiting) == 0 assert len(scheduler_output.scheduled_new_reqs) == 1 - assert len(scheduler_output.scheduled_cached_reqs) == 2 + assert scheduler_output.scheduled_cached_reqs.num_reqs == 2 model_runner_output = create_model_runner_output( [request_local_a, request_local_b, request_remote]) diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 61f59f35f75b9..983d900606fc9 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -25,7 +25,6 @@ def assert_scheduler_empty(scheduler: Scheduler): assert len(scheduler.running) == 0 assert len(scheduler.finished_req_ids) == 0 assert len(scheduler.finished_recving_kv_req_ids) == 0 - assert len(scheduler._cached_reqs_data) == 0 # EncoderCacheManager. 
assert len(scheduler.encoder_cache_manager.freed) == 0 diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py index 25839d0897a4c..40db0b2afe0d9 100644 --- a/tests/v1/tpu/worker/test_tpu_model_runner.py +++ b/tests/v1/tpu/worker/test_tpu_model_runner.py @@ -82,7 +82,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: return SchedulerOutput( scheduled_new_reqs=new_reqs, - scheduled_cached_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens=num_scheduled_tokens, total_num_scheduled_tokens=total_num_scheduled_tokens, scheduled_spec_decode_tokens={}, @@ -161,7 +161,7 @@ def test_update_states_request_finished(model_runner): # finish req scheduler_output = SchedulerOutput( scheduled_new_reqs=[], - scheduled_cached_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens={}, total_num_scheduled_tokens=0, scheduled_spec_decode_tokens={}, @@ -191,7 +191,7 @@ def test_update_states_request_resumed(model_runner): # unschedule req scheduler_output = SchedulerOutput( scheduled_new_reqs=[], - scheduled_cached_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens={}, total_num_scheduled_tokens=0, scheduled_spec_decode_tokens={}, @@ -209,16 +209,16 @@ def test_update_states_request_resumed(model_runner): # resume req cached_req_data = CachedRequestData( - req_id=req_id, - resumed_from_preemption=False, - new_token_ids=[], - new_block_ids=([], ), - num_computed_tokens=0, + req_ids=[req_id], + resumed_from_preemption=[False], + new_token_ids=[[]], + new_block_ids=[([], )], + num_computed_tokens=[0], ) scheduler_output = SchedulerOutput( scheduled_new_reqs=[], - scheduled_cached_reqs=[cached_req_data], + scheduled_cached_reqs=cached_req_data, num_scheduled_tokens={req_id: 1}, total_num_scheduled_tokens=1, scheduled_spec_decode_tokens={}, @@ -249,7 +249,7 @@ def test_update_states_no_changes(model_runner): # schedule req 
scheduler_output = SchedulerOutput( scheduled_new_reqs=[], - scheduled_cached_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens={req_id: 1}, total_num_scheduled_tokens=1, scheduled_spec_decode_tokens={}, @@ -284,7 +284,7 @@ def test_update_states_request_unscheduled(model_runner): # unschedule req_1 scheduler_output = SchedulerOutput( scheduled_new_reqs=[], - scheduled_cached_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens={req_ids[0]: 1}, total_num_scheduled_tokens=1, scheduled_spec_decode_tokens={}, diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 583a88d8e6ec6..c739b23b90dc8 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -133,7 +133,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: return SchedulerOutput( scheduled_new_reqs=new_reqs, - scheduled_cached_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens=num_scheduled_tokens, total_num_scheduled_tokens=total_num_scheduled_tokens, scheduled_spec_decode_tokens={}, @@ -199,7 +199,7 @@ def test_update_states_request_finished(model_runner): # finish req scheduler_output = SchedulerOutput( scheduled_new_reqs=[], - scheduled_cached_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens={}, total_num_scheduled_tokens=0, scheduled_spec_decode_tokens={}, @@ -231,7 +231,7 @@ def test_update_states_request_resumed(model_runner): # unschedule req scheduler_output = SchedulerOutput( scheduled_new_reqs=[], - scheduled_cached_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens={}, total_num_scheduled_tokens=0, scheduled_spec_decode_tokens={}, @@ -249,16 +249,16 @@ def test_update_states_request_resumed(model_runner): # resume req cached_req_data = CachedRequestData( - req_id=req_id, - resumed_from_preemption=False, - new_token_ids=[], 
- new_block_ids=([], ), - num_computed_tokens=0, + req_ids=[req_id], + resumed_from_preemption=[False], + new_token_ids=[[]], + new_block_ids=([[0]], ), + num_computed_tokens=[0], ) scheduler_output = SchedulerOutput( scheduled_new_reqs=[], - scheduled_cached_reqs=[cached_req_data], + scheduled_cached_reqs=cached_req_data, num_scheduled_tokens={req_id: 1}, total_num_scheduled_tokens=1, scheduled_spec_decode_tokens={}, @@ -339,7 +339,7 @@ def test_update_states_no_changes(model_runner): # schedule req scheduler_output = SchedulerOutput( scheduled_new_reqs=[], - scheduled_cached_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens={req_id: 1}, total_num_scheduled_tokens=1, scheduled_spec_decode_tokens={}, @@ -376,7 +376,7 @@ def test_update_states_request_unscheduled(model_runner): # unschedule req_1 scheduler_output = SchedulerOutput( scheduled_new_reqs=[], - scheduled_cached_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens={req_ids[0]: 1}, total_num_scheduled_tokens=1, scheduled_spec_decode_tokens={}, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py index a47deaf91272e..2f870971ded70 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py @@ -371,45 +371,48 @@ class P2pNcclConnector(KVConnectorBase_V1): block_size=self._block_size) self._requests_need_load.pop(new_req.req_id) - for cached_req in scheduler_output.scheduled_cached_reqs: + cached_reqs = scheduler_output.scheduled_cached_reqs + for i, req_id in enumerate(cached_reqs.req_ids): + num_computed_tokens = cached_reqs.num_computed_tokens[i] + new_block_ids = cached_reqs.new_block_ids[i] + resumed_from_preemption = cached_reqs.resumed_from_preemption[i] + if self.is_producer: num_scheduled_tokens = ( - 
scheduler_output.num_scheduled_tokens)[cached_req.req_id] - num_tokens = (num_scheduled_tokens + - cached_req.num_computed_tokens) - assert cached_req.req_id in self.chunked_prefill - block_ids = cached_req.new_block_ids[0] - if not cached_req.resumed_from_preemption: - block_ids = (self.chunked_prefill[cached_req.req_id][0] + - block_ids) - prompt_token_ids = self.chunked_prefill[cached_req.req_id][1] + scheduler_output.num_scheduled_tokens)[req_id] + num_tokens = (num_scheduled_tokens + num_computed_tokens) + assert req_id in self.chunked_prefill + block_ids = new_block_ids[0] + if not resumed_from_preemption: + block_ids = (self.chunked_prefill[req_id][0] + block_ids) + prompt_token_ids = self.chunked_prefill[req_id][1] # the request's prompt is chunked prefill again if num_tokens < len(prompt_token_ids): - self.chunked_prefill[cached_req.req_id] = ( - block_ids, prompt_token_ids) + self.chunked_prefill[req_id] = (block_ids, + prompt_token_ids) continue # the request's prompt is all prefilled finally - meta.add_request(request_id=cached_req.req_id, + meta.add_request(request_id=req_id, token_ids=prompt_token_ids, block_ids=block_ids, block_size=self._block_size) - self.chunked_prefill.pop(cached_req.req_id, None) + self.chunked_prefill.pop(req_id, None) continue # NOTE(rob): here we rely on the resumed requests being # the first N requests in the list scheduled_cache_reqs. - if not cached_req.resumed_from_preemption: + if not resumed_from_preemption: break - if cached_req.req_id in self._requests_need_load: - request, _ = self._requests_need_load.pop(cached_req.req_id) - total_tokens = cached_req.num_computed_tokens + 1 + if req_id in self._requests_need_load: + request, _ = self._requests_need_load.pop(req_id) + total_tokens = num_computed_tokens + 1 token_ids = request.all_token_ids[:total_tokens] # NOTE(rob): For resumed req, new_block_ids is all # of the block_ids for the request. 
- block_ids = cached_req.new_block_ids[0] + block_ids = new_block_ids[0] - meta.add_request(request_id=cached_req.req_id, + meta.add_request(request_id=req_id, token_ids=token_ids, block_ids=block_ids, block_size=self._block_size) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index f86b92692a0e5..0bceee19f873d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -304,23 +304,28 @@ class SharedStorageConnector(KVConnectorBase_V1): block_size=self._block_size, is_store=True) - for cached_req in scheduler_output.scheduled_cached_reqs: + cached_reqs = scheduler_output.scheduled_cached_reqs + for i, req_id in enumerate(cached_reqs.req_ids): + num_computed_tokens = cached_reqs.num_computed_tokens[i] + new_token_ids = cached_reqs.new_token_ids[i] + new_block_ids = cached_reqs.new_block_ids[i] + resumed_from_preemption = cached_reqs.resumed_from_preemption[i] + # NOTE(rob): here we rely on the resumed requests being # the first N requests in the list scheduled_cache_reqs. - if not cached_req.resumed_from_preemption: + if not resumed_from_preemption: break - if cached_req.req_id in self._requests_need_load: + if req_id in self._requests_need_load: # NOTE(rob): cached_req_data does not have the full # list of token ids (only new tokens). So we look it # up in the actual request object. - request = self._requests_need_load[cached_req.req_id] - total_tokens = (len(cached_req.new_token_ids) + - cached_req.num_computed_tokens) + request = self._requests_need_load[req_id] + total_tokens = (len(new_token_ids) + num_computed_tokens) token_ids = request.all_token_ids[:total_tokens] # NOTE(rob): For resumed req, new_block_ids is all # of the block_ids for the request. 
- block_ids = cached_req.new_block_ids[0] + block_ids = new_block_ids[0] meta.add_request(token_ids=token_ids, block_ids=block_ids, diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 6f31031a1086e..efc5b3012ec2f 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -83,29 +83,27 @@ class NewRequestData: @dataclass class CachedRequestData: - req_id: str + req_ids: list[str] # If resumed_from_preemption is False, new_block_ids will be appended to # the request's block IDs. If True, new_block_ids will be used as the # request's block IDs instead of appending to the existing block IDs. - resumed_from_preemption: bool - new_token_ids: list[int] - new_block_ids: tuple[list[int], ...] - num_computed_tokens: int + resumed_from_preemption: list[bool] + new_token_ids: list[list[int]] + new_block_ids: list[tuple[list[int], ...]] + num_computed_tokens: list[int] + + @property + def num_reqs(self) -> int: + return len(self.req_ids) @classmethod - def from_request( - cls, - request: Request, - resumed_from_preemption: bool, - new_token_ids: list[int], - new_block_ids: tuple[list[int], ...], - ) -> CachedRequestData: + def make_empty(cls) -> CachedRequestData: return cls( - req_id=request.request_id, - resumed_from_preemption=resumed_from_preemption, - new_token_ids=new_token_ids, - new_block_ids=new_block_ids, - num_computed_tokens=request.num_computed_tokens, + req_ids=[], + resumed_from_preemption=[], + new_token_ids=[], + new_block_ids=[], + num_computed_tokens=[], ) @@ -119,7 +117,7 @@ class SchedulerOutput: # list of the requests that have been scheduled before. # Since the request's data is already cached in the worker processes, # we only send the diff to minimize the communication cost. - scheduled_cached_reqs: list[CachedRequestData] + scheduled_cached_reqs: CachedRequestData # req_id -> num_scheduled_tokens # Number of tokens scheduled for each request. 
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 00b0844a5660b..20a40d74f3118 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -3,8 +3,9 @@ from __future__ import annotations +import itertools import time -from collections import defaultdict, deque +from collections import defaultdict from collections.abc import Iterable from typing import Any, Optional, Union @@ -117,12 +118,6 @@ class Scheduler(SchedulerInterface): # KV Connector: requests in process of async KV loading or recving self.finished_recving_kv_req_ids: set[str] = set() - # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating - # them at each scheduling step. - # Request id -> deque of CachedRequestData - self._cached_reqs_data: dict[ - str, deque[CachedRequestData]] = defaultdict(deque) - # Encoder-related. # Calculate encoder cache size if applicable # NOTE: For now we use the same budget for both compute and space. @@ -547,27 +542,16 @@ class Scheduler(SchedulerInterface): req_to_new_block_ids[req.request_id]) for req in scheduled_new_reqs ] - resumed_reqs_data = [ - self._make_cached_request_data( - req, - num_scheduled_tokens[req.request_id], - len(scheduled_spec_decode_tokens.get(req.request_id, ())), - req_to_new_block_ids[req.request_id], - resumed_from_preemption=True, - ) for req in scheduled_resumed_reqs - ] - running_reqs_data = [ - self._make_cached_request_data( - req, - num_scheduled_tokens[req.request_id], - len(scheduled_spec_decode_tokens.get(req.request_id, ())), - req_to_new_block_ids[req.request_id], - resumed_from_preemption=False, - ) for req in scheduled_running_reqs - ] + cached_reqs_data = self._make_cached_request_data( + scheduled_running_reqs, + scheduled_resumed_reqs, + num_scheduled_tokens, + scheduled_spec_decode_tokens, + req_to_new_block_ids, + ) scheduler_output = SchedulerOutput( scheduled_new_reqs=new_reqs_data, - scheduled_cached_reqs=resumed_reqs_data + running_reqs_data, + 
scheduled_cached_reqs=cached_reqs_data, num_scheduled_tokens=num_scheduled_tokens, total_num_scheduled_tokens=total_num_scheduled_tokens, scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, @@ -613,34 +597,39 @@ class Scheduler(SchedulerInterface): def _make_cached_request_data( self, - request: Request, - num_scheduled_tokens: int, - num_scheduled_spec_tokens: int, - new_block_ids: tuple[list[int], ...], - resumed_from_preemption: bool, + running_reqs: list[Request], + resumed_reqs: list[Request], + num_scheduled_tokens: dict[str, int], + spec_decode_tokens: dict[str, list[int]], + req_to_new_block_ids: dict[str, tuple[list[int], ...]], ) -> CachedRequestData: - # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating - # them at each scheduling step. - num_computed_tokens = request.num_computed_tokens - num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens - new_token_ids = request.all_token_ids[ - num_computed_tokens:num_computed_tokens + num_regular_tokens] + req_ids: list[str] = [] + new_token_ids: list[list[int]] = [] + new_block_ids: list[tuple[list[int], ...]] = [] + num_computed_tokens: list[int] = [] - req_data_queue = self._cached_reqs_data.get(request.request_id) - if req_data_queue: - req_data = req_data_queue.popleft() - req_data.resumed_from_preemption = resumed_from_preemption - req_data.new_token_ids = new_token_ids - req_data.new_block_ids = new_block_ids - req_data.num_computed_tokens = num_computed_tokens - else: - # No cached request data, or all cached request data has been - # used by the scheduled requests. - req_data = CachedRequestData.from_request(request, - resumed_from_preemption, - new_token_ids, - new_block_ids) - return req_data + for req in itertools.chain(running_reqs, resumed_reqs): + req_id = req.request_id + req_ids.append(req_id) + num_tokens = (num_scheduled_tokens[req_id] - + len(spec_decode_tokens.get(req_id, ()))) + token_ids = req.all_token_ids[req.num_computed_tokens:req. 
+ num_computed_tokens + num_tokens] + new_token_ids.append(token_ids) + new_block_ids.append(req_to_new_block_ids[req_id]) + num_computed_tokens.append(req.num_computed_tokens) + # Because resumed_reqs is usually empty, it is more efficient to do + # in-place appending so that we don't need to allocate a new list. + resumed_from_preemption = [False] * len(running_reqs) + resumed_from_preemption += [True] * len(resumed_reqs) + + return CachedRequestData( + req_ids=req_ids, + resumed_from_preemption=resumed_from_preemption, + new_token_ids=new_token_ids, + new_block_ids=new_block_ids, + num_computed_tokens=num_computed_tokens, + ) def _try_schedule_encoder_inputs( self, @@ -870,19 +859,11 @@ class Scheduler(SchedulerInterface): if not stopped: new_running.append(request) + self.running = new_running # KV Connector: update state for finished KV Transfers. self._update_from_kv_xfer_finished(model_runner_output) - # Return the cached request data to the queue so they can be reused. - for req_data in scheduler_output.scheduled_cached_reqs: - # NOTE(rob): since we free stopped reqs above, adding stopped reqs - # to _cached_reqs_data will cause a memory leak. - if req_data.req_id not in self.finished_req_ids: - self._cached_reqs_data[req_data.req_id].append(req_data) - - self.running = new_running - # Create EngineCoreOutputs for all clients that have requests with # outputs in this step. 
engine_core_outputs = { @@ -965,13 +946,11 @@ class Scheduler(SchedulerInterface): self._free_request(request) def _free_request(self, request: Request) -> Optional[dict[str, Any]]: - assert request.is_finished() delay_free_blocks, kv_xfer_params = self._connector_finished(request) self.encoder_cache_manager.free(request) request_id = request.request_id - self._cached_reqs_data.pop(request_id, None) self.finished_req_ids.add(request_id) if self.finished_req_ids_dict is not None: self.finished_req_ids_dict[request.client_index].add(request_id) @@ -983,7 +962,6 @@ class Scheduler(SchedulerInterface): def _free_blocks(self, request: Request): assert request.is_finished() - assert request.request_id not in self._cached_reqs_data self.kv_cache_manager.free(request) self.kv_cache_manager.free_block_hashes(request) del self.requests[request.request_id] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e063e44dabfa1..29d39de212f88 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -470,34 +470,36 @@ class GPUModelRunner(LoRAModelRunnerMixin): req_ids_to_add.append(req_id) # Update the states of the running/resumed requests. - for req_data in scheduler_output.scheduled_cached_reqs: - req_id = req_data.req_id + req_data = scheduler_output.scheduled_cached_reqs + for i, req_id in enumerate(req_data.req_ids): req_state = self.requests[req_id] + num_computed_tokens = req_data.num_computed_tokens[i] + new_token_ids = req_data.new_token_ids[i] + new_block_ids = req_data.new_block_ids[i] + resumed_from_preemption = req_data.resumed_from_preemption[i] # Update the cached states. - num_computed_tokens = req_data.num_computed_tokens req_state.num_computed_tokens = num_computed_tokens # Add the sampled token(s) from the previous step (if any). # This doesn't include "unverified" tokens like spec decode tokens. 
- num_new_tokens = (num_computed_tokens + - len(req_data.new_token_ids) - + num_new_tokens = (num_computed_tokens + len(new_token_ids) - req_state.num_tokens) if num_new_tokens == 1: # Avoid slicing list in most common case. - req_state.output_token_ids.append(req_data.new_token_ids[-1]) + req_state.output_token_ids.append(new_token_ids[-1]) elif num_new_tokens > 0: req_state.output_token_ids.extend( - req_data.new_token_ids[-num_new_tokens:]) + new_token_ids[-num_new_tokens:]) # Update the block IDs. - if not req_data.resumed_from_preemption: + if not resumed_from_preemption: # Append the new blocks to the existing block IDs. - for block_ids, new_block_ids in zip(req_state.block_ids, - req_data.new_block_ids): - block_ids.extend(new_block_ids) + for block_ids, new_ids in zip(req_state.block_ids, + new_block_ids): + block_ids.extend(new_ids) else: # The request is resumed from preemption. # Replace the existing block IDs with the new ones. - req_state.block_ids = req_data.new_block_ids + req_state.block_ids = new_block_ids req_index = self.input_batch.req_id_to_index.get(req_id) if req_index is None: @@ -510,14 +512,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( num_computed_tokens) - self.input_batch.block_table.append_row(req_data.new_block_ids, - req_index) + self.input_batch.block_table.append_row(new_block_ids, req_index) # Add new_token_ids to token_ids_cpu. start_token_index = num_computed_tokens - end_token_index = num_computed_tokens + len(req_data.new_token_ids) + end_token_index = num_computed_tokens + len(new_token_ids) self.input_batch.token_ids_cpu[ - req_index, - start_token_index:end_token_index] = req_data.new_token_ids + req_index, start_token_index:end_token_index] = new_token_ids self.input_batch.num_tokens_no_spec[req_index] = end_token_index # Add spec_token_ids to token_ids_cpu. 
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index bc334419c4cec..0cc218bdb646f 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -418,21 +418,24 @@ class TPUModelRunner(LoRAModelRunnerMixin): req_ids_to_add.append(req_id) # Update the states of the running/resumed requests. - for req_data in scheduler_output.scheduled_cached_reqs: - req_id = req_data.req_id + req_data = scheduler_output.scheduled_cached_reqs + for i, req_id in enumerate(req_data.req_ids): req_state = self.requests[req_id] + num_computed_tokens = req_data.num_computed_tokens[i] + new_block_ids = req_data.new_block_ids[i] + resumed_from_preemption = req_data.resumed_from_preemption[i] # Update the cached states. - req_state.num_computed_tokens = req_data.num_computed_tokens - if not req_data.resumed_from_preemption: + req_state.num_computed_tokens = num_computed_tokens + if not resumed_from_preemption: # Append the new blocks to the existing block IDs. - for block_ids, new_block_ids in zip(req_state.block_ids, - req_data.new_block_ids): - block_ids.extend(new_block_ids) + for block_ids, new_ids in zip(req_state.block_ids, + new_block_ids): + block_ids.extend(new_ids) else: # The request is resumed from preemption. # Replace the existing block IDs with the new ones. - req_state.block_ids = req_data.new_block_ids + req_state.block_ids = new_block_ids req_index = self.input_batch.req_id_to_index.get(req_id) if req_index is None: @@ -444,9 +447,8 @@ class TPUModelRunner(LoRAModelRunnerMixin): # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( - req_data.num_computed_tokens) - self.input_batch.block_table.append_row(req_data.new_block_ids, - req_index) + num_computed_tokens) + self.input_batch.block_table.append_row(new_block_ids, req_index) # Add the new or resumed requests to the persistent batch. 
# The smaller empty indices are filled first. From 551ef1631a98d60fe9e82f0282e49c4a59a7887b Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Mon, 30 Jun 2025 12:26:42 -0400 Subject: [PATCH 009/195] [Unit Test] Add unit test for deep gemm (#20090) Signed-off-by: yewentao256 Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- tests/kernels/moe/test_deepgemm.py | 225 +++++++++++++++++++++++++++++ 1 file changed, 225 insertions(+) create mode 100644 tests/kernels/moe/test_deepgemm.py diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py new file mode 100644 index 0000000000000..5d2690904cea2 --- /dev/null +++ b/tests/kernels/moe/test_deepgemm.py @@ -0,0 +1,225 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Unit-test DeepGEMM FP8 kernels (no DeepEP). +Compare DeepGEMM path against the Triton fallback inside vLLM's fused_experts. +""" + +import importlib +import math + +import pytest +import torch + +# vLLM fused-expert reference (Triton fallback + DeepGEMM option) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) +from vllm.utils import cdiv + +has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None + +if has_deep_gemm: + import deep_gemm + BLOCK_M = deep_gemm.get_m_alignment_for_contiguous_layout() + BLOCK_SIZE = [BLOCK_M, BLOCK_M] + +requires_deep_gemm = pytest.mark.skipif( + not has_deep_gemm, + reason="Requires deep_gemm kernels", +) + + +def calc_diff(x: torch.Tensor, y: torch.Tensor): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +def per_block_cast_to_fp8( + x: torch.Tensor, + block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros( + (cdiv(m, 128) * 
128, cdiv(n, block_size_n) * block_size_n), + dtype=x.dtype, + device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) + x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous() + scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + return x_scaled_sub, scales + + +def make_block_quant_fp8_weights( + e: int, + n: int, + k: int, + block_size: list[int], +): + """ + Generate (w1, w2) expert weights and their per-block scale tensors + in FP8 block-quantized format. + + w1 shape: (E, 2N, K) + w2 shape: (E, K, N) + """ + dtype = torch.bfloat16 + fp8_max, fp8_min = torch.finfo(torch.float8_e4m3fn).max, torch.finfo( + torch.float8_e4m3fn).min + + # bf16 reference weights + w1_bf16 = torch.randn(e, 2 * n, k, device="cuda", dtype=dtype) / 10 + w2_bf16 = torch.randn(e, k, n, device="cuda", dtype=dtype) / 10 + w1_bf16.clamp_(fp8_min, fp8_max) + w2_bf16.clamp_(fp8_min, fp8_max) + + block_n, block_k = block_size + n_tiles_w1 = math.ceil((2 * n) / block_n) + k_tiles_w1 = math.ceil(k / block_k) + n_tiles_w2 = math.ceil(k / block_n) + k_tiles_w2 = math.ceil(n / block_k) + + w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn) + w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn) + w1_s = torch.empty(e, + n_tiles_w1, + k_tiles_w1, + device="cuda", + dtype=torch.float32) + w2_s = torch.empty(e, + n_tiles_w2, + k_tiles_w2, + device="cuda", + dtype=torch.float32) + + for i in range(e): + w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) + w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) + + return w1, w2, w1_s, w2_s + + +def run_single_case(m, n, k, topk, num_experts, block_size): + """ + Run one (M,N,K) configuration on a single GPU and assert DeepGEMM == + Triton baseline within tolerance. 
+ """ + tokens_bf16 = torch.randn( + m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1) + _, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1]) + + # expert weight tensors + w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k, + block_size) + + router_logits = torch.randn(m, + num_experts, + device="cuda", + dtype=torch.float32) + topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1) + topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1) + + # triton referrence + out_triton = fused_experts( + hidden_states=tokens_bf16, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + a1_scale=a1_scale, + block_shape=block_size, + allow_deep_gemm=False, + ) + + # DeepGemm + out_deepgemm = fused_experts( + hidden_states=tokens_bf16, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + a1_scale=a1_scale, + block_shape=block_size, + allow_deep_gemm=True, + ) + + base = out_triton.abs().mean() + atol = 0.1 * base.clamp(min=1e-2) # 10% of mean, but not lower than 1e-3 + rtol = 0.05 + # ----- Compare ----- + torch.testing.assert_close( + out_deepgemm.to(torch.float32), + out_triton.to(torch.float32), + rtol=rtol, + atol=float(atol), + ) + + +# Note: W1 has shape (E, 2N, K), so N = 512 +# can trigger the deepgemm path. 
+MNKs = [ + (1024, 512, 128), + (1024, 512, 512), + (2048, 512, 512), + (512, 1024, 1024), + (512, 2048, 2048), + (4096, 4096, 1024), +] + +TOPKS = [2, 6] +NUM_EXPERTS = [32] + + +@pytest.mark.parametrize("mnk", MNKs) +@pytest.mark.parametrize("topk", TOPKS) +@pytest.mark.parametrize("num_experts", NUM_EXPERTS) +@requires_deep_gemm +def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch): + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_DEEP_GEMM", "1") + + _fused_moe_mod = importlib.import_module( + "vllm.model_executor.layers.fused_moe.fused_moe") + + call_counter = {"cnt": 0} + + orig_fn = _fused_moe_mod.deep_gemm_moe_fp8 + + def _spy_deep_gemm_moe_fp8(*args, **kwargs): + call_counter["cnt"] += 1 + return orig_fn(*args, **kwargs) + + monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8", + _spy_deep_gemm_moe_fp8) + + m, n, k = mnk + + if topk > num_experts: + pytest.skip(f"topk={topk} > num_experts={num_experts}") + + run_single_case( + m=m, + n=n, + k=k, + topk=topk, + num_experts=num_experts, + block_size=BLOCK_SIZE, + ) + + # ensure that the DeepGEMM path was indeed taken. + assert call_counter["cnt"] == 1, \ + f"DeepGEMM path was not executed during the test. 
" \ + f"Call counter: {call_counter['cnt']}" From d8cf819a9a337a578b7dfc7642617921cc468c17 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 30 Jun 2025 13:26:49 -0400 Subject: [PATCH 010/195] [Core] [Bugfix] [Multimodal] Fix multimodal profiling and generation for SFT/PTQed models (#20058) Signed-off-by: Kyle Sayers --- docs/contributing/model/multimodal.md | 7 +++ tests/multimodal/test_processing.py | 1 + vllm/entrypoints/llm.py | 8 +++ vllm/entrypoints/utils.py | 4 ++ vllm/inputs/preprocess.py | 29 +++++++--- vllm/model_executor/models/aya_vision.py | 2 + vllm/model_executor/models/blip2.py | 2 + vllm/model_executor/models/chameleon.py | 2 + vllm/model_executor/models/deepseek_vl2.py | 6 ++- vllm/model_executor/models/florence2.py | 4 +- vllm/model_executor/models/fuyu.py | 2 + vllm/model_executor/models/gemma3_mm.py | 2 + vllm/model_executor/models/glm4v.py | 1 + vllm/model_executor/models/granite_speech.py | 2 + vllm/model_executor/models/h2ovl.py | 3 ++ vllm/model_executor/models/idefics3.py | 2 + vllm/model_executor/models/internvl.py | 5 +- vllm/model_executor/models/llava.py | 5 +- vllm/model_executor/models/llava_onevision.py | 7 +++ vllm/model_executor/models/minicpmo.py | 10 ++-- vllm/model_executor/models/minicpmv.py | 18 +++++-- vllm/model_executor/models/minimax_vl_01.py | 2 + vllm/model_executor/models/mistral3.py | 2 + vllm/model_executor/models/mllama.py | 6 ++- vllm/model_executor/models/mllama4.py | 2 + vllm/model_executor/models/ovis.py | 2 + vllm/model_executor/models/paligemma.py | 5 +- vllm/model_executor/models/phi3v.py | 2 + vllm/model_executor/models/phi4mm.py | 3 +- vllm/model_executor/models/pixtral.py | 7 ++- .../models/prithvi_geospatial_mae.py | 1 + .../models/qwen2_5_omni_thinker.py | 7 +++ vllm/model_executor/models/qwen2_audio.py | 2 + vllm/model_executor/models/qwen2_vl.py | 4 +- vllm/model_executor/models/qwen_vl.py | 3 ++ vllm/model_executor/models/skyworkr1v.py | 2 + vllm/model_executor/models/ultravox.py | 6 +++ 
vllm/model_executor/models/whisper.py | 4 +- vllm/multimodal/processing.py | 54 +++++++++++++++---- vllm/multimodal/profiling.py | 7 ++- vllm/utils.py | 2 + 41 files changed, 207 insertions(+), 38 deletions(-) diff --git a/docs/contributing/model/multimodal.md b/docs/contributing/model/multimodal.md index 6ff2abbae6329..670d747b9ee7d 100644 --- a/docs/contributing/model/multimodal.md +++ b/docs/contributing/model/multimodal.md @@ -538,11 +538,13 @@ return a schema of the tensors outputted by the HF processor that are related to prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: processed_outputs = super()._call_hf_processor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) image_patches = processed_outputs.get("image_patches") @@ -566,6 +568,11 @@ return a schema of the tensors outputted by the HF processor that are related to Our [actual code](gh-file:vllm/model_executor/models/fuyu.py) has special handling for text-only inputs to prevent unnecessary warnings from HF processor. + !!! note + The `_call_hf_processor` method specifies both `mm_kwargs` and `tok_kwargs` for + processing. `mm_kwargs` is used to both initialize and call the huggingface + processor, whereas `tok_kwargs` is only used to call the huggingface processor. 
+ This lets us override [_get_mm_fields_config][vllm.multimodal.processing.BaseMultiModalProcessor._get_mm_fields_config] as follows: ```python diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index 8b52911c6ccf3..2f97475f121a0 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -1086,6 +1086,7 @@ def test_hf_processor_kwargs(model_id, call_kwargs, expected_kwargs): prompt="", mm_data={}, mm_kwargs=call_kwargs, + tok_kwargs={}, ) assert out_kwargs == expected_kwargs diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 63967e4d2d4bc..f0404e0bc6eac 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -481,6 +481,13 @@ class LLM: # Use default sampling params. sampling_params = self.get_default_sampling_params() + tokenization_kwargs: dict[str, Any] = {} + truncate_prompt_tokens = None + if isinstance(sampling_params, SamplingParams): + truncate_prompt_tokens = sampling_params.truncate_prompt_tokens + _validate_truncation_size(self.llm_engine.model_config.max_model_len, + truncate_prompt_tokens, tokenization_kwargs) + self._validate_and_add_requests( prompts=parsed_prompts, params=sampling_params, @@ -488,6 +495,7 @@ class LLM: lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, guided_options=guided_options_request, + tokenization_kwargs=tokenization_kwargs, priority=priority, ) diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index 16ba2b4531acf..50f810afb8ccd 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -171,6 +171,10 @@ def _validate_truncation_size( tokenization_kwargs["truncation"] = True tokenization_kwargs["max_length"] = truncate_prompt_tokens + else: + if tokenization_kwargs is not None: + tokenization_kwargs["truncation"] = False + return truncate_prompt_tokens diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index a13e563f34a14..deda9bc23dafe 100644 --- 
a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -265,7 +265,8 @@ class InputPreprocessor: prompt: Union[str, list[int]], mm_data: MultiModalDataDict, mm_processor_kwargs: Optional[Mapping[str, object]], - lora_request: Optional[LoRARequest], + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[LoRARequest] = None, return_mm_hashes: bool = False, ) -> MultiModalInputs: """ @@ -280,15 +281,19 @@ class InputPreprocessor: if mm_processor_kwargs is None: mm_processor_kwargs = {} - return mm_processor.apply(prompt, mm_data, mm_processor_kwargs, - return_mm_hashes) + return mm_processor.apply(prompt, + mm_data, + hf_processor_mm_kwargs=mm_processor_kwargs, + tokenization_kwargs=tokenization_kwargs, + return_mm_hashes=return_mm_hashes) async def _process_multimodal_async( self, prompt: Union[str, list[int]], mm_data: MultiModalDataDict, mm_processor_kwargs: Optional[Mapping[str, object]], - lora_request: Optional[LoRARequest], + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[LoRARequest] = None, return_mm_hashes: bool = False, ) -> MultiModalInputs: """ @@ -302,8 +307,11 @@ class InputPreprocessor: if mm_processor_kwargs is None: mm_processor_kwargs = {} - return mm_processor.apply(prompt, mm_data, mm_processor_kwargs, - return_mm_hashes) + return mm_processor.apply(prompt, + mm_data, + hf_processor_mm_kwargs=mm_processor_kwargs, + tokenization_kwargs=tokenization_kwargs, + return_mm_hashes=return_mm_hashes) def _process_embeds( self, @@ -338,6 +346,7 @@ class InputPreprocessor: def _process_tokens( self, parsed_content: TokensPrompt, + tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, return_mm_hashes: bool = False, ) -> Union[TokenInputs, MultiModalInputs]: @@ -350,6 +359,7 @@ class InputPreprocessor: prompt_token_ids, multi_modal_data, parsed_content.get("mm_processor_kwargs"), + tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, 
return_mm_hashes=return_mm_hashes, ) @@ -367,6 +377,7 @@ class InputPreprocessor: async def _process_tokens_async( self, parsed_content: TokensPrompt, + tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, return_mm_hashes: bool = False, ) -> Union[TokenInputs, MultiModalInputs]: @@ -379,6 +390,7 @@ class InputPreprocessor: prompt_token_ids, multi_modal_data, parsed_content.get("mm_processor_kwargs"), + tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, return_mm_hashes=return_mm_hashes, ) @@ -408,6 +420,7 @@ class InputPreprocessor: prompt_text, multi_modal_data, parsed_content.get("mm_processor_kwargs"), + tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, return_mm_hashes=return_mm_hashes, ) @@ -442,6 +455,7 @@ class InputPreprocessor: prompt_text, multi_modal_data, parsed_content.get("mm_processor_kwargs"), + tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, return_mm_hashes=return_mm_hashes, ) @@ -860,7 +874,8 @@ class InputPreprocessor: "returned until they are supported on vLLM V1.") # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder - return self._process_encoder_decoder_prompt(prompt) + return self._process_encoder_decoder_prompt( + prompt, tokenization_kwargs) if is_explicit_encoder_decoder_prompt(prompt): raise ValueError("Cannot pass encoder-decoder prompt " diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py index a48631ad709f7..38daf995b8ca3 100644 --- a/vllm/model_executor/models/aya_vision.py +++ b/vllm/model_executor/models/aya_vision.py @@ -185,11 +185,13 @@ class AyaVisionMultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: processed_outputs = super()._call_hf_processor( prompt, mm_data, mm_kwargs, + tok_kwargs, ) hf_processor = self.info.get_hf_processor(**mm_kwargs) 
image_processor = hf_processor.image_processor diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 3c3955161daaa..ecc12fa8d3727 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -454,6 +454,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]): prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: if not mm_data: # HF processor always adds placeholders even when there's no image @@ -465,6 +466,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]): prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) def _get_mm_fields_config( diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index d538ba09c65cf..06e33ad7737e3 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -107,6 +107,7 @@ class ChameleonMultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: if not mm_data: prompt_ids = self.info.get_tokenizer().encode(prompt) @@ -117,6 +118,7 @@ class ChameleonMultiModalProcessor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) def _apply_hf_processor_tokens_only( diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index da5452409d2f9..cdda9fb5a7490 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -204,12 +204,13 @@ class DeepseekVL2MultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: if mm_data: processed_outputs = self.info.ctx.call_hf_processor( self.info.get_hf_processor(**mm_kwargs), dict(prompt=prompt, 
**mm_data), - mm_kwargs, + dict(**mm_kwargs, **tok_kwargs), ) pixel_values = processed_outputs["pixel_values"] # split pixel values into patches corresponding to each image @@ -278,6 +279,7 @@ class DeepseekVL2MultiModalProcessor( prompt: Union[str, list[int]], mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], *, return_mm_hashes: bool, ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: @@ -290,6 +292,7 @@ class DeepseekVL2MultiModalProcessor( prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, return_mm_hashes=return_mm_hashes, ) @@ -297,6 +300,7 @@ class DeepseekVL2MultiModalProcessor( prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, return_mm_hashes=return_mm_hashes, ) diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index 425407c19ab5d..bda552721eb23 100644 --- a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -794,6 +794,7 @@ class Florence2MultiModalProcessor( prompt_text: str, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], ) -> bool: return False @@ -828,10 +829,11 @@ class Florence2MultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: if mm_data: processed_outputs = super()._call_hf_processor( - prompt, mm_data, mm_kwargs) + prompt, mm_data, mm_kwargs, tok_kwargs) else: hf_processor = self.info.get_hf_processor() tokenizer = hf_processor.tokenizer diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 7e03982e78e69..b3e055b966b08 100644 --- a/vllm/model_executor/models/fuyu.py +++ 
b/vllm/model_executor/models/fuyu.py @@ -153,6 +153,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: if not mm_data: # Avoid warning from HF logger for text-only input @@ -164,6 +165,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) image_patches = processed_outputs.get("image_patches") diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 3a1c14978b45b..e9c27674b8457 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -259,11 +259,13 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: processed_outputs = super()._call_hf_processor( prompt, mm_data, mm_kwargs, + tok_kwargs, ) # HF processor pops the `num_crops` kwarg, which is needed by vLLM diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index 70916c45c0e09..95e3fcfc02fab 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -481,6 +481,7 @@ class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]): prompt_text: str, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], ) -> bool: return False diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py index f2dc5708028ba..77fbc4808b4a3 100644 --- a/vllm/model_executor/models/granite_speech.py +++ b/vllm/model_executor/models/granite_speech.py @@ -141,6 +141,7 @@ class GraniteSpeechMultiModalProcessor( prompt: str, mm_data: 
Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: mm_data = dict(mm_data) audios = mm_data.pop("audios", []) @@ -153,6 +154,7 @@ class GraniteSpeechMultiModalProcessor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) if "audio" in mm_data: diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py index 8f7f359b75521..467b074f37753 100644 --- a/vllm/model_executor/models/h2ovl.py +++ b/vllm/model_executor/models/h2ovl.py @@ -490,6 +490,7 @@ class H2OVLMultiModalProcessor( prompt: Union[str, list[int]], mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], *, return_mm_hashes: bool, ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: @@ -502,6 +503,7 @@ class H2OVLMultiModalProcessor( prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, return_mm_hashes=return_mm_hashes, ) @@ -509,6 +511,7 @@ class H2OVLMultiModalProcessor( prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, return_mm_hashes=return_mm_hashes, ) diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index b1d0626217a0a..36cfb5807d7d6 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -326,6 +326,7 @@ class Idefics3MultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: # Text-only input not supported in composite processor if not (images := mm_data.get("images", [])): @@ -337,6 +338,7 @@ class Idefics3MultiModalProcessor( prompt, mm_data, mm_kwargs, + tok_kwargs, ) parsed_images = (self._get_data_parser().parse_mm_data({ diff --git 
a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index bb71177ecad8e..6abe6cd6965c8 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -758,11 +758,13 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> Mapping[str, NestedTensors]: processed_outputs = super()._call_hf_processor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) hf_processor = self.info.get_hf_processor(**mm_kwargs) @@ -941,9 +943,10 @@ class InternVLMultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> Mapping[str, NestedTensors]: processed_outputs = super()._call_hf_processor(prompt, mm_data, - mm_kwargs) + mm_kwargs, tok_kwargs) hf_processor = self.info.get_hf_processor(**mm_kwargs) if self.info.supports_video and ( diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 1c35bf5206db7..7a7aefb267181 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -296,11 +296,13 @@ class PixtralHFMultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: processed_outputs = super()._call_hf_processor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) pixel_values = processed_outputs.get("pixel_values") @@ -797,6 +799,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): prompt: Union[str, list[int]], mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Optional[Mapping[str, object]] = None, return_mm_hashes: bool = False, ) -> MultiModalInputs: hf_config = self.info.get_hf_config() @@ -809,7 +812,7 @@ class 
MantisMultiModalProcessor(LlavaMultiModalProcessor): ) result = super().apply(prompt, mm_data, hf_processor_mm_kwargs, - return_mm_hashes) + tokenization_kwargs, return_mm_hashes) mm_items = self._to_mm_items(mm_data) mm_item_counts = mm_items.get_all_counts() diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index c5403762f5390..7ff1026bfc94d 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -286,6 +286,7 @@ class LlavaOnevisionMultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: mm_data = dict(mm_data) videos = mm_data.pop("videos", []) @@ -296,6 +297,7 @@ class LlavaOnevisionMultiModalProcessor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) # LLaVA-OneVision processor doesn't support multiple videos @@ -310,6 +312,7 @@ class LlavaOnevisionMultiModalProcessor( prompt=prompt, mm_data={}, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) images = mm_data.pop("images", []) @@ -319,6 +322,7 @@ class LlavaOnevisionMultiModalProcessor( prompt=image_token * len(images), mm_data={"images": images}, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) image_outputs = { k: v @@ -334,6 +338,7 @@ class LlavaOnevisionMultiModalProcessor( prompt=video_token, mm_data={"videos": video}, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) pixel_values_videos.append(item_outputs["pixel_values_videos"][0]) @@ -352,11 +357,13 @@ class LlavaOnevisionMultiModalProcessor( prompt_text: str, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], ) -> bool: base_result = super()._hf_processor_applies_updates( prompt_text=prompt_text, mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, ) return base_result and 
mm_items.get_count("video", strict=False) == 0 diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index ff5959ed196ea..112e0b91d3f17 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -260,6 +260,7 @@ class MiniCPMOMultiModalProcessor( self, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> Mapping[str, NestedTensors]: if (audios := mm_data.get("audios")) is None: return {} @@ -276,9 +277,9 @@ class MiniCPMOMultiModalProcessor( prompts=[self.info.audio_pattern] * len(parsed_audios), mm_data={"audios": [[audio] for audio in parsed_audios]}, mm_kwargs={ - **mm_kwargs, - "chunk_input": True, + **mm_kwargs, "chunk_input": True }, + tok_kwargs=tok_kwargs, out_keys={"audio_features", "audio_feature_lens"}, ) @@ -302,10 +303,11 @@ class MiniCPMOMultiModalProcessor( self, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> Mapping[str, NestedTensors]: return { - **super().process_mm_inputs(mm_data, mm_kwargs), - **self.process_audios(mm_data, mm_kwargs), + **super().process_mm_inputs(mm_data, mm_kwargs, tok_kwargs), + **self.process_audios(mm_data, mm_kwargs, tok_kwargs), } def _get_prompt_updates( diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 9dc03c8001824..1dba88be83500 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -534,6 +534,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): self, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> Mapping[str, NestedTensors]: if (images := mm_data.get("images")) is None: return {} @@ -550,6 +551,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): prompts=[self.info.image_pattern] * len(parsed_images), mm_data={"images": [[image] for image in 
parsed_images]}, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, ) @@ -563,6 +565,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): self, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> Mapping[str, NestedTensors]: if (videos := mm_data.get("videos")) is None: return {} @@ -586,6 +589,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): "max_slice_nums": self.info.get_video_max_slice_num(), }, + tok_kwargs=tok_kwargs, out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, ) @@ -601,10 +605,11 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): self, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> Mapping[str, NestedTensors]: return { - **self.process_images(mm_data, mm_kwargs), - **self.process_videos(mm_data, mm_kwargs), + **self.process_images(mm_data, mm_kwargs, tok_kwargs), + **self.process_videos(mm_data, mm_kwargs, tok_kwargs), } def _base_call_hf_processor( @@ -612,6 +617,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): prompts: list[str], mm_data: Mapping[str, Sequence[object]], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], *, out_keys: set[str], ) -> dict[str, NestedTensors]: @@ -621,6 +627,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): prompt=prompts, # type: ignore mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) else: inputs = defaultdict[str, list[torch.Tensor]](list) @@ -633,6 +640,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): for k, v in mm_data.items() }, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) for k, v in inputs_one.items(): @@ -646,11 +654,12 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> 
BatchFeature: tokenizer = self.info.get_tokenizer() - input_ids = torch.tensor([tokenizer.encode(prompt)]) - mm_inputs = self.process_mm_inputs(mm_data, mm_kwargs) + input_ids = torch.tensor([tokenizer.encode(prompt, **tok_kwargs)]) + mm_inputs = self.process_mm_inputs(mm_data, mm_kwargs, tok_kwargs) return BatchFeature({ "input_ids": input_ids, @@ -662,6 +671,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): prompt_text: str, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], ) -> bool: return False diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index 8ce94540e87fe..a125454c0c060 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -113,11 +113,13 @@ class MiniMaxVL01MultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: processed_outputs = super()._call_hf_processor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) pixel_values = processed_outputs.get("pixel_values") diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index 04d6d347cb84f..6840c672a3299 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -228,11 +228,13 @@ class Mistral3MultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: processed_outputs = super()._call_hf_processor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) pixel_values = processed_outputs.get("pixel_values") diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 1b7e93fafad93..ead5a8e950f0f 100644 --- a/vllm/model_executor/models/mllama.py +++ 
b/vllm/model_executor/models/mllama.py @@ -166,10 +166,11 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] prompt: Union[str, list[int]], mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Optional[Mapping[str, object]] = None, return_mm_hashes: bool = False, ) -> MultiModalEncDecInputs: mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs, - return_mm_hashes) + tokenization_kwargs, return_mm_hashes) image_token_id = self.info.get_hf_config().image_token_index # Check that the number of image tokens in the decoder prompt matches @@ -239,6 +240,7 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: tokenizer = self.info.get_tokenizer() if mm_data: @@ -247,7 +249,7 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] for img in mm_data["images"] ] processed_outputs = super()._call_hf_processor( - prompt, mm_data, mm_kwargs) + prompt, mm_data, mm_kwargs, tok_kwargs) processed_outputs["num_tiles"] = torch.tensor(num_tiles) for k in ('pixel_values', 'aspect_ratio_ids', "aspect_ratio_mask"): processed_outputs[k] = processed_outputs[k].squeeze(0) diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index a420e757e2194..ea781e18db272 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -574,6 +574,7 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: tokenizer = self.info.get_tokenizer() @@ -583,6 +584,7 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + 
tok_kwargs=tok_kwargs, ) processor = self.info.get_hf_processor(**mm_kwargs) diff --git a/vllm/model_executor/models/ovis.py b/vllm/model_executor/models/ovis.py index 6eecd4499fb96..5059b4e69f076 100644 --- a/vllm/model_executor/models/ovis.py +++ b/vllm/model_executor/models/ovis.py @@ -335,6 +335,7 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]): prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: if not mm_data: # Avoid warning from HF logger for text-only input @@ -346,6 +347,7 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]): prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) hf_processor = self.info.get_hf_processor() diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index e1de8cf458780..29ffb62eeafd0 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -121,6 +121,7 @@ class PaliGemmaMultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: tokenizer = self.info.get_tokenizer() if not mm_data: @@ -131,6 +132,7 @@ class PaliGemmaMultiModalProcessor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) def _get_mm_fields_config( @@ -191,10 +193,11 @@ class PaliGemmaMultiModalProcessor( prompt: Union[str, list[int]], mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Optional[Mapping[str, object]] = None, return_mm_hashes: bool = False, ) -> MultiModalInputs: mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs, - return_mm_hashes) + tokenization_kwargs, return_mm_hashes) prompt_token_ids = mm_inputs["prompt_token_ids"] tokenizer = self.info.get_tokenizer() diff --git a/vllm/model_executor/models/phi3v.py 
b/vllm/model_executor/models/phi3v.py index 0a7adf91e488f..a084e71f734c2 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -376,11 +376,13 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: processed_outputs = super()._call_hf_processor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) input_ids = processed_outputs["input_ids"] diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index 5d1f0775b07fb..3c4162507f03d 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -762,6 +762,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: if not mm_data: prompt_ids = self.info.get_tokenizer().encode(prompt) @@ -773,7 +774,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): mm_data['audios'] = [(data, sr) for data in audio_data] processed_outputs = super()._call_hf_processor(prompt, mm_data, - mm_kwargs) + mm_kwargs, tok_kwargs) num_img_tokens = [ self.info.get_num_image_tokens(image_width=img_size[0], diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 709ac1d9df945..a31c757f7d592 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -237,6 +237,7 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]): dummy_text = self.get_dummy_text(mm_counts) dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts) dummy_images = dummy_mm_data.get("image", []) + tokenization_kwargs = {"truncation": False} request = ChatCompletionRequest(messages=[ UserMessage(content=[ 
@@ -247,7 +248,9 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]): res = tokenizer.mistral.encode_chat_completion(request) dummy_tokens = res.tokens - return ProcessorInputs(prompt=dummy_tokens, mm_data=dummy_mm_data) + return ProcessorInputs(prompt=dummy_tokens, + mm_data=dummy_mm_data, + tokenization_kwargs=tokenization_kwargs) class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] @@ -297,6 +300,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] prompt: Union[str, list[int]], mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], *, return_mm_hashes: bool, ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: @@ -309,6 +313,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, return_mm_hashes=return_mm_hashes, ) diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index 4fdcae5de644a..f89cf1b5274cf 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -92,6 +92,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): prompt: Union[str, list[int]], mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Optional[Mapping[str, object]] = None, return_mm_hashes: bool = False, ) -> MultiModalInputs: mm_kwargs = {} diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index 9497f15984b75..8980f386502fc 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -244,6 +244,7 @@ class 
Qwen2_5OmniThinkerMultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: mm_data = dict(mm_data) audios = mm_data.pop("audios", []) @@ -258,6 +259,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) input_features = hf_inputs.pop('input_features', None) @@ -453,6 +455,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( prompt: Union[str, list[int]], mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], *, enable_hf_prompt_update: bool, ) -> tuple[list[int], MultiModalKwargs, bool]: @@ -465,6 +468,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( prompt_text=prompt, mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, ) tokenizer = self.info.get_tokenizer() prompt_ids = encode_tokens(tokenizer, prompt) @@ -474,6 +478,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( mm_kwargs = self._apply_hf_processor_mm_only( mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, ) return prompt_ids, mm_kwargs, False @@ -482,6 +487,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], ) -> MultiModalKwargs: """ Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`. 
@@ -498,6 +504,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( prompt_text=self.dummy_inputs.get_dummy_text(mm_counts), mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, ) return mm_kwargs diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index aefa1db24628d..31b25ef0bc731 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -150,6 +150,7 @@ class Qwen2AudioMultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, Any], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: # NOTE - we rename audios -> audio in mm data because transformers has # deprecated audios for the qwen2audio processor and will remove @@ -174,6 +175,7 @@ class Qwen2AudioMultiModalProcessor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) def _get_mm_fields_config( diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 899fc57c7a0e5..dc7b08c65bb13 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -1027,11 +1027,13 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: + mm_kwargs = self.info._get_image_processor_kwargs(**mm_kwargs) return self.info.ctx.call_hf_processor( self.info.get_hf_processor(**mm_kwargs), dict(text=prompt, **mm_data), - self.info._get_image_processor_kwargs(**mm_kwargs), + dict(**mm_kwargs, **tok_kwargs), ) def _get_prompt_updates( diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index fc29785af95a0..563650a4f162c 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -580,6 +580,7 @@ class 
QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: # Drops anything between / tags; encoding with the tokenizer # will automatically add the image pads for the context. @@ -600,6 +601,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) def _hf_processor_applies_updates( @@ -607,6 +609,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): prompt_text: str, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], ) -> bool: return False diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py index 28f181dde2154..d362838dbb398 100644 --- a/vllm/model_executor/models/skyworkr1v.py +++ b/vllm/model_executor/models/skyworkr1v.py @@ -534,11 +534,13 @@ class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]): prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> Mapping[str, NestedTensors]: processed_outputs = super()._call_hf_processor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) hf_processor = self.info.get_hf_processor(**mm_kwargs) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 94f5e03fd446e..5cccd6b8841b4 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -144,6 +144,7 @@ class UltravoxMultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: # Text-only input not supported in composite processor if not mm_data.get("audios", []): @@ -165,10 +166,15 @@ class 
UltravoxMultiModalProcessor( item_processor_data = dict(**mm_data, audios=audios) + # some tokenizer kwargs are incompatible with UltravoxProcessor + tok_kwargs.pop("padding", None) + tok_kwargs.pop("truncation", None) + output = super()._call_hf_processor( prompt=prompt, mm_data=item_processor_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) output['audio_features'] = output.pop('audio_values') diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 5a0094fa749fd..568b81c4bbfa8 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -700,9 +700,10 @@ class WhisperMultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: if mm_data: - feature_extractor = self.info.get_feature_extractor(**mm_kwargs) + feature_extractor = self.info.get_feature_extractor() mm_data = dict(audio=mm_data.pop("audios")) mm_kwargs = dict( **mm_kwargs, @@ -712,6 +713,7 @@ class WhisperMultiModalProcessor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) if "labels" in processed_outputs: processed_outputs["input_ids"] = processed_outputs.pop("labels") diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 38f3a7cb932f4..aa7889fc3cc59 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1267,6 +1267,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): # This refers to the data to be passed to HF processor. 
mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> "BatchFeature": """ Call the HF processor on the prompt text and @@ -1275,7 +1276,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): return self.info.ctx.call_hf_processor( self.info.get_hf_processor(**mm_kwargs), dict(text=prompt, **mm_data), - mm_kwargs, + dict(**mm_kwargs, **tok_kwargs), ) def _hf_processor_applies_updates( @@ -1283,6 +1284,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt_text: str, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], ) -> bool: """ Return whether the HF processor applies prompt updates. @@ -1300,6 +1302,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt_text: str, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], ) -> tuple[list[int], MultiModalKwargs, bool]: """ Apply the HF processor on the prompt text and multi-modal data @@ -1313,6 +1316,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt=prompt_text, mm_data=processor_data, mm_kwargs=hf_processor_mm_kwargs, + tok_kwargs=tokenization_kwargs, ) processed_data.update(passthrough_data) @@ -1327,11 +1331,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt_text=prompt_text, mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, ) return prompt_ids, mm_kwargs, is_update_applied - def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]: + def _apply_hf_processor_text_only( + self, prompt_text: str, + tokenization_kwargs: Mapping[str, object]) -> list[int]: """ Apply the HF processor on the prompt text only. 
@@ -1343,6 +1350,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt_text=prompt_text, mm_items=MultiModalDataItems({}), hf_processor_mm_kwargs={}, + tokenization_kwargs=tokenization_kwargs, ) return prompt_ids @@ -1368,6 +1376,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], ) -> MultiModalKwargs: """ Apply the HF processor on the multi-modal data only. @@ -1383,6 +1392,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt_text=self.dummy_inputs.get_dummy_text(mm_counts), mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, ) return mm_kwargs @@ -1392,6 +1402,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt: Union[str, list[int]], mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], *, enable_hf_prompt_update: bool, ) -> tuple[list[int], MultiModalKwargs, bool]: @@ -1412,15 +1423,18 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt_text=prompt, mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, ) - prompt_ids = self._apply_hf_processor_text_only(prompt) + prompt_ids = self._apply_hf_processor_text_only( + prompt, tokenization_kwargs) else: prompt_ids = self._apply_hf_processor_tokens_only(prompt) mm_kwargs = self._apply_hf_processor_mm_only( mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, ) return prompt_ids, mm_kwargs, False @@ -1430,14 +1444,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): cache: ProcessingCache, mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], ) -> tuple[dict[str, list[ProcessingCacheOptionalItem]], dict[ str, list[object]]]: model_id = self.info.model_id 
mm_cache_items = { modality: [ - cache.get_item(model_id, modality, item, - hf_processor_mm_kwargs) for item in items + cache.get_item( + model_id, modality, item, + dict(**hf_processor_mm_kwargs, **tokenization_kwargs)) + for item in items ] for modality, items in mm_data_items.items() } @@ -1457,10 +1474,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): return mm_cache_items, mm_missing_data def _hash_mm_items( - self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> MultiModalHashes: + self, mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object]) -> MultiModalHashes: """Create MM hashes to be returned (only used in V1).""" model_id = self.info.model_id @@ -1468,7 +1484,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): modality: [ MultiModalHasher.hash_kwargs(model_id=model_id, **{modality: item}, - **hf_processor_mm_kwargs) + **hf_processor_mm_kwargs, + **tokenization_kwargs) for item in items ] for modality, items in mm_items.items() @@ -1513,6 +1530,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt: Union[str, list[int]], mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], *, return_mm_hashes: bool, ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: @@ -1524,10 +1542,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt=prompt, mm_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, enable_hf_prompt_update=True, ) - mm_hashes = (self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs) + mm_hashes = (self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs, + tokenization_kwargs) if return_mm_hashes else None) return prompt_ids, mm_kwargs, mm_hashes, is_update_applied @@ -1537,6 +1557,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt: Union[str, list[int]], 
mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], *, return_mm_hashes: bool, ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: @@ -1552,6 +1573,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, return_mm_hashes=return_mm_hashes, ) @@ -1562,6 +1584,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): cache=cache, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, ) # NOTE: `prompt` does not correspond to `mm_missing_data_items`, @@ -1575,6 +1598,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt=prompt, mm_items=self._to_mm_items(mm_missing_data), hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, enable_hf_prompt_update=False, ) @@ -1783,6 +1807,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt: Union[str, list[int]], mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Optional[Mapping[str, object]] = None, return_mm_hashes: bool = False, ) -> MultiModalInputs: """ @@ -1800,6 +1825,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): """ mm_items = self._to_mm_items(mm_data) + if tokenization_kwargs is None: + tokenization_kwargs = {} + ( prompt_ids, mm_kwargs, @@ -1809,9 +1837,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt, mm_items, hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, return_mm_hashes=return_mm_hashes, ) + # NOTE: tokenization_kwargs are not required to init processor prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates( mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, @@ -1892,6 +1922,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): prompt: Union[str, 
list[int]], mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Optional[Mapping[str, object]] = None, return_mm_hashes: bool = False, ) -> MultiModalEncDecInputs: """ @@ -1906,6 +1937,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): encoder_prompt, mm_data, hf_processor_mm_kwargs, + tokenization_kwargs, return_mm_hashes, ) diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index 67bcb31f23f70..fb5a7b64c4199 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -30,6 +30,7 @@ class ProcessorInputs: prompt: Union[str, list[int]] mm_data: MultiModalDataDict hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict) + tokenization_kwargs: Mapping[str, object] = field(default_factory=dict) class DummyEncoderData(NamedTuple): @@ -90,8 +91,11 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]): """ dummy_text = self.get_dummy_text(mm_counts) dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts) + tokenization_kwargs = {"truncation": False} - return ProcessorInputs(prompt=dummy_text, mm_data=dummy_mm_data) + return ProcessorInputs(prompt=dummy_text, + mm_data=dummy_mm_data, + tokenization_kwargs=tokenization_kwargs) def _get_dummy_audios( self, @@ -170,6 +174,7 @@ class MultiModalProfiler(Generic[_I]): prompt=processor_inputs.prompt, mm_data=processor_inputs.mm_data, hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, + tokenization_kwargs=processor_inputs.tokenization_kwargs, ) def _get_mm_num_tokens( diff --git a/vllm/utils.py b/vllm/utils.py index 7eb3c1e347cde..689102281c54f 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1729,6 +1729,7 @@ def supports_kw( last_param = params[next(reversed(params))] # type: ignore return (last_param.kind == inspect.Parameter.VAR_KEYWORD and last_param.name != kw_name) + return False @@ -1771,6 +1772,7 @@ def resolve_mm_processor_kwargs( # Merge the final processor kwargs, prioritizing 
inference # time values over the initialization time values. mm_processor_kwargs = {**init_mm_kwargs, **runtime_mm_kwargs} + return mm_processor_kwargs From 97d9524fe90ad5799cc11db4b4216fe3a30a07d6 Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Mon, 30 Jun 2025 14:15:24 -0400 Subject: [PATCH 011/195] [Refactor] Remove useless pdb comment (#20266) Signed-off-by: yewentao256 --- vllm/model_executor/layers/fused_moe/deep_gemm_moe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 321fb0351ad93..818f6d345ba6d 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -141,7 +141,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, N // 2)) mm2_out = _resize_cache(workspace2, (M_sum, K)) - # import pdb; pdb.set_trace() dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids) From ded1fb635b7c1504a83fc7c195a5bf47d31c1bef Mon Sep 17 00:00:00 2001 From: Zhonghua Deng Date: Tue, 1 Jul 2025 07:45:14 +0800 Subject: [PATCH 012/195] [Bugfix][V1][P/D]Fix the issue of occasional garbled output for P2pNcclConnector (#20263) Signed-off-by: Abatom --- .../kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py index 81f7a2525896e..35c26897fe3f4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py @@ -310,10 +310,11 @@ class P2pNcclEngine: elif data["cmd"] == "PUT": tensor_id = data["tensor_id"] try: - tensor = 
torch.empty(data["shape"], - dtype=getattr( - torch, data["dtype"]), - device=self.device) + with torch.cuda.stream(self.recv_stream): + tensor = torch.empty(data["shape"], + dtype=getattr( + torch, data["dtype"]), + device=self.device) self.router_socket.send_multipart( [remote_address, b"0"]) comm, rank = self.comms[remote_address.decode()] From 6d42ce83155d42f04643c1fa54eaed8abf8170c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Mon, 30 Jun 2025 21:03:13 -0400 Subject: [PATCH 013/195] [CLI] Improve CLI arg parsing for `-O`/`--compilation-config` (#20156) Signed-off-by: luka --- tests/engine/test_arg_utils.py | 28 +++++++++------ tests/test_utils.py | 47 ++++++++++++++++++++++++ vllm/config.py | 19 +++++----- vllm/engine/arg_utils.py | 5 ++- vllm/utils.py | 65 ++++++++++++++++++++++++---------- 5 files changed, 124 insertions(+), 40 deletions(-) diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index cfbc7c245ffd4..847f150bd6443 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -239,32 +239,40 @@ def test_compilation_config(): assert args.compilation_config == CompilationConfig() # set to O3 - args = parser.parse_args(["-O3"]) - assert args.compilation_config.level == 3 + args = parser.parse_args(["-O0"]) + assert args.compilation_config.level == 0 # set to O 3 (space) - args = parser.parse_args(["-O", "3"]) - assert args.compilation_config.level == 3 + args = parser.parse_args(["-O", "1"]) + assert args.compilation_config.level == 1 # set to O 3 (equals) - args = parser.parse_args(["-O=3"]) + args = parser.parse_args(["-O=2"]) + assert args.compilation_config.level == 2 + + # set to O.level 3 + args = parser.parse_args(["-O.level", "3"]) assert args.compilation_config.level == 3 # set to string form of a dict args = parser.parse_args([ - "--compilation-config", - '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}', + "-O", + '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 
8], ' + '"use_inductor": false}', ]) assert (args.compilation_config.level == 3 and - args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]) + args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8] + and not args.compilation_config.use_inductor) # set to string form of a dict args = parser.parse_args([ "--compilation-config=" - '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}', + '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' + '"use_inductor": true}', ]) assert (args.compilation_config.level == 3 and - args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]) + args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8] + and args.compilation_config.use_inductor) def test_prefix_cache_default(): diff --git a/tests/test_utils.py b/tests/test_utils.py index 913188455d8e6..36db8202ba622 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,6 +5,7 @@ import asyncio import hashlib import json +import logging import pickle import socket from collections.abc import AsyncIterator @@ -142,6 +143,7 @@ def parser(): parser.add_argument('--batch-size', type=int) parser.add_argument('--enable-feature', action='store_true') parser.add_argument('--hf-overrides', type=json.loads) + parser.add_argument('-O', '--compilation-config', type=json.loads) return parser @@ -265,6 +267,11 @@ def test_dict_args(parser): "val2", "--hf-overrides.key2.key4", "val3", + # Test compile config and compilation level + "-O.use_inductor=true", + "-O.backend", + "custom", + "-O1", # Test = sign "--hf-overrides.key5=val4", # Test underscore to dash conversion @@ -281,6 +288,13 @@ def test_dict_args(parser): "true", "--hf_overrides.key12.key13", "null", + # Test '-' and '.' 
in value + "--hf_overrides.key14.key15", + "-minus.and.dot", + # Test array values + "-O.custom_ops+", + "-quant_fp8", + "-O.custom_ops+=+silu_mul,-rms_norm", ] parsed_args = parser.parse_args(args) assert parsed_args.model_name == "something.something" @@ -301,7 +315,40 @@ def test_dict_args(parser): "key12": { "key13": None, }, + "key14": { + "key15": "-minus.and.dot", + } } + assert parsed_args.compilation_config == { + "level": 1, + "use_inductor": True, + "backend": "custom", + "custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"], + } + + +def test_duplicate_dict_args(caplog_vllm, parser): + args = [ + "--model-name=something.something", + "--hf-overrides.key1", + "val1", + "--hf-overrides.key1", + "val2", + "-O1", + "-O.level", + "2", + "-O3", + ] + + parsed_args = parser.parse_args(args) + # Should be the last value + assert parsed_args.hf_overrides == {"key1": "val2"} + assert parsed_args.compilation_config == {"level": 3} + + assert len(caplog_vllm.records) == 1 + assert "duplicate" in caplog_vllm.text + assert "--hf-overrides.key1" in caplog_vllm.text + assert "-O.level" in caplog_vllm.text # yapf: enable diff --git a/vllm/config.py b/vllm/config.py index 57b9df2364775..46a5bf34f66e4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4140,9 +4140,9 @@ class CompilationConfig: @classmethod def from_cli(cls, cli_value: str) -> "CompilationConfig": - """Parse the CLI value for the compilation config.""" - if cli_value in ["0", "1", "2", "3"]: - return cls(level=int(cli_value)) + """Parse the CLI value for the compilation config. + -O1, -O2, -O3, etc. is handled in FlexibleArgumentParser. + """ return TypeAdapter(CompilationConfig).validate_json(cli_value) def __post_init__(self) -> None: @@ -4303,17 +4303,16 @@ class VllmConfig: """Quantization configuration.""" compilation_config: CompilationConfig = field( default_factory=CompilationConfig) - """`torch.compile` configuration for the model. 
+ """`torch.compile` and cudagraph capture configuration for the model. - When it is a number (0, 1, 2, 3), it will be interpreted as the - optimization level. + As a shorthand, `-O<n>` can be used to directly specify the compilation + level `n`: `-O3` is equivalent to `-O.level=3` (same as `-O='{"level":3}'`). + Currently, -O <n> and -O=<n> are supported as well but this will likely be + removed in favor of clearer -O<n> syntax in the future. NOTE: level 0 is the default level without any optimization. level 1 and 2 are for internal testing only. level 3 is the recommended level for - production. - - Following the convention of traditional compilers, using `-O` without space - is also supported. `-O3` is equivalent to `-O 3`. + production, also default in V1. You can specify the full compilation config like so: `{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}` diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 6c908f88b9a92..2d3783363c00b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -202,7 +202,10 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: passed individually. For example, the following sets of arguments are equivalent:\n\n - `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n - - `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n\n""" + - `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n + Additionally, list elements can be passed individually using '+': + - `--json-arg '{"key4": ["value3", "value4", "value5"]}'`\n + - `--json-arg.key4+ value3 --json-arg.key4+='value4,value5'`\n\n""" if dataclass_cls is not None: def parse_dataclass(val: str, cls=dataclass_cls) -> Any: diff --git a/vllm/utils.py b/vllm/utils.py index 689102281c54f..60e560c70ad3a 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -89,15 +89,15 @@ MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 STR_NOT_IMPL_ENC_DEC_SWA = \ "Sliding window attention for encoder/decoder models " + \ - "is not currently supported."
+ "is not currently supported." STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ "Prefix caching for encoder/decoder models " + \ - "is not currently supported." + "is not currently supported." STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \ "Chunked prefill for encoder/decoder models " + \ - "is not currently supported." + "is not currently supported." STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = ( "Models with logits_soft_cap " @@ -752,7 +752,7 @@ def _generate_random_fp8( # to generate random data for fp8 data. # For example, s.11111.00 in fp8e5m2 format represents Inf. # | E4M3 | E5M2 - #-----|-------------|------------------- + # -----|-------------|------------------- # Inf | N/A | s.11111.00 # NaN | s.1111.111 | s.11111.{01,10,11} from vllm import _custom_ops as ops @@ -840,7 +840,6 @@ def create_kv_caches_with_random( seed: Optional[int] = None, device: Optional[str] = "cuda", ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: - if cache_dtype == "fp8" and head_size % 16: raise ValueError( f"Does not support key cache of type fp8 with head_size {head_size}" @@ -1205,7 +1204,6 @@ def deprecate_args( is_deprecated: Union[bool, Callable[[], bool]] = True, additional_message: Optional[str] = None, ) -> Callable[[F], F]: - if not callable(is_deprecated): is_deprecated = partial(identity, is_deprecated) @@ -1355,7 +1353,7 @@ def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]: return weak_bound -#From: https://stackoverflow.com/a/4104188/2749989 +# From: https://stackoverflow.com/a/4104188/2749989 def run_once(f: Callable[P, None]) -> Callable[P, None]: def wrapper(*args: P.args, **kwargs: P.kwargs) -> None: @@ -1474,7 +1472,7 @@ class FlexibleArgumentParser(ArgumentParser): # Convert underscores to dashes and vice versa in argument names processed_args = list[str]() - for arg in args: + for i, arg in enumerate(args): if arg.startswith('--'): if '=' in arg: key, value = arg.split('=', 1) @@ -1483,10 +1481,17 @@ class FlexibleArgumentParser(ArgumentParser): else: key 
= pattern.sub(repl, arg, count=1) processed_args.append(key) - elif arg.startswith('-O') and arg != '-O' and len(arg) == 2: - # allow -O flag to be used without space, e.g. -O3 - processed_args.append('-O') - processed_args.append(arg[2:]) + elif arg.startswith('-O') and arg != '-O' and arg[2] != '.': + # allow -O flag to be used without space, e.g. -O3 or -Odecode + # -O.<...> handled later + # also handle -O=<level> here + level = arg[3:] if arg[2] == '=' else arg[2:] + processed_args.append(f'-O.level={level}') + elif arg == '-O' and i + 1 < len(args) and args[i + 1] in { + "0", "1", "2", "3" + }: + # Convert -O <n> to -O.level <n> + processed_args.append('-O.level') else: processed_args.append(arg) @@ -1504,27 +1509,44 @@ class FlexibleArgumentParser(ArgumentParser): def recursive_dict_update( original: dict[str, Any], update: dict[str, Any], - ): - """Recursively updates a dictionary with another dictionary.""" + ) -> set[str]: + """Recursively updates a dictionary with another dictionary. + Returns a set of duplicate keys that were overwritten. + """ + duplicates = set[str]() for k, v in update.items(): if isinstance(v, dict) and isinstance(original.get(k), dict): - recursive_dict_update(original[k], v) + nested_duplicates = recursive_dict_update(original[k], v) + duplicates |= {f"{k}.{d}" for d in nested_duplicates} + elif isinstance(v, list) and isinstance(original.get(k), list): + original[k] += v else: + if k in original: + duplicates.add(k) original[k] = v + return duplicates delete = set[int]() dict_args = defaultdict[str, dict[str, Any]](dict) + duplicates = set[str]() for i, processed_arg in enumerate(processed_args): - if processed_arg.startswith("--") and "." in processed_arg: + if i in delete: # skip if value from previous arg + continue + + if processed_arg.startswith("-") and "." in processed_arg: if "=" in processed_arg: processed_arg, value_str = processed_arg.split("=", 1) if "." not in processed_arg: - # False positive, .
was only in the value + # False positive, '.' was only in the value continue else: value_str = processed_args[i + 1] delete.add(i + 1) + if processed_arg.endswith("+"): + processed_arg = processed_arg[:-1] + value_str = json.dumps(list(value_str.split(","))) + key, *keys = processed_arg.split(".") try: value = json.loads(value_str) @@ -1533,12 +1555,17 @@ class FlexibleArgumentParser(ArgumentParser): # Merge all values with the same key into a single dict arg_dict = create_nested_dict(keys, value) - recursive_dict_update(dict_args[key], arg_dict) + arg_duplicates = recursive_dict_update(dict_args[key], + arg_dict) + duplicates |= {f'{key}.{d}' for d in arg_duplicates} delete.add(i) # Filter out the dict args we set to None processed_args = [ a for i, a in enumerate(processed_args) if i not in delete ] + if duplicates: + logger.warning("Found duplicate keys %s", ", ".join(duplicates)) + # Add the dict args back as if they were originally passed as JSON for dict_arg, dict_value in dict_args.items(): processed_args.append(dict_arg) @@ -2405,7 +2432,7 @@ def memory_profiling( The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.). The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.). 
- """ # noqa + """ # noqa gc.collect() torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() From e28533a16f73a4eae01c2b7b1b4ddf3fc1beedab Mon Sep 17 00:00:00 2001 From: fyuan1316 Date: Tue, 1 Jul 2025 09:30:14 +0800 Subject: [PATCH 014/195] [Bugfix] Fix include prompt in stream response when echo=true (#15233) Signed-off-by: Yuan Fang --- tests/entrypoints/openai/test_completion.py | 54 +++++++++++++++++++ vllm/entrypoints/openai/serving_completion.py | 21 ++++++-- 2 files changed, 71 insertions(+), 4 deletions(-) diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index 7e54143f6e1c3..7933ca5cd6c6f 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -779,3 +779,57 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI, prompt="Give an example string that fits this regex", extra_body=dict(guided_regex=sample_regex, guided_json=sample_json_schema)) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name,stream,echo", + [ + (MODEL_NAME, False, False), + (MODEL_NAME, False, True), + (MODEL_NAME, True, False), + (MODEL_NAME, True, True) # should not raise BadRequestError error + ], +) +async def test_echo_stream_completion(client: openai.AsyncOpenAI, + model_name: str, stream: bool, + echo: bool): + saying: str = "Hello, my name is" + result = await client.completions.create(model=model_name, + prompt=saying, + max_tokens=10, + temperature=0.0, + echo=echo, + stream=stream) + + stop_reason = "length" + + if not stream: + completion = result + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + + choice = completion.choices[0] + assert len(choice.text) >= 5 + assert choice.finish_reason == stop_reason + + if echo: + assert choice.text is not None and saying in choice.text + else: + assert choice.text is not None and saying not in choice.text + + else: + chunks: list[str] 
= [] + final_finish_reason = None + async for chunk in result: + if chunk.choices and chunk.choices[0].text: + chunks.append(chunk.choices[0].text) + if chunk.choices and chunk.choices[0].finish_reason: + final_finish_reason = chunk.choices[0].finish_reason + + assert final_finish_reason == stop_reason + content = "".join(chunks) + if echo: + assert content is not None and saying in content + else: + assert content is not None and saying not in content diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index a19fde8d70a83..8171b491aafcc 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -25,10 +25,13 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs, ErrorResponse, RequestResponseMetadata, UsageInfo) -# yapf: enable +from vllm.entrypoints.openai.serving_engine import ( + EmbedsPrompt as ServingEngineEmbedsPrompt) from vllm.entrypoints.openai.serving_engine import (OpenAIServing, + TextTokensPrompt, clamp_prompt_logprobs, is_text_tokens_prompt) +# yapf: enable from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt, is_tokens_prompt) @@ -223,6 +226,7 @@ class OpenAIServingCompletion(OpenAIServing): if stream: return self.completion_stream_generator( request, + request_prompts, result_generator, request_id, created_time, @@ -285,6 +289,8 @@ class OpenAIServingCompletion(OpenAIServing): async def completion_stream_generator( self, request: CompletionRequest, + request_prompts: list[Union[TextTokensPrompt, + ServingEngineEmbedsPrompt]], result_generator: AsyncIterator[tuple[int, RequestOutput]], request_id: str, created_time: int, @@ -313,7 +319,15 @@ class OpenAIServingCompletion(OpenAIServing): async for prompt_idx, res in result_generator: prompt_token_ids = res.prompt_token_ids prompt_logprobs = res.prompt_logprobs - prompt_text = res.prompt 
+ + if res.prompt is not None: + prompt_text = res.prompt + else: + request_prompt = request_prompts[prompt_idx] + if is_text_tokens_prompt(request_prompt): + prompt_text = request_prompt["prompt"] + else: + prompt_text = None # Prompt details are excluded from later streamed outputs if prompt_token_ids is not None: @@ -336,14 +350,13 @@ class OpenAIServingCompletion(OpenAIServing): delta_token_ids = prompt_token_ids out_logprobs = prompt_logprobs else: - assert prompt_logprobs is not None # echo the prompt and first token delta_text = prompt_text + output.text delta_token_ids = [ *prompt_token_ids, *output.token_ids ] out_logprobs = [ - *prompt_logprobs, + *(prompt_logprobs or []), *(output.logprobs or []), ] has_echoed[i] = True From 7151f92241db1bb6ef4eb0fcfed87256646d554e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 30 Jun 2025 21:01:48 -0700 Subject: [PATCH 015/195] [Misc] Fix spec decode example (#20296) Signed-off-by: Woosuk Kwon --- examples/offline_inference/spec_decode.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 90d103e5cb05d..3f38aa9fcaa60 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -79,9 +79,7 @@ def main(): trust_remote_code=True, tensor_parallel_size=args.tp, enable_chunked_prefill=args.enable_chunked_prefill, - max_num_batched_tokens=args.max_num_batched_tokens, enforce_eager=args.enforce_eager, - max_num_seqs=args.max_num_seqs, gpu_memory_utilization=0.8, speculative_config=speculative_config, disable_log_stats=False, From 92ee7baaf9a5bf6c8132dde56e4056933c61f50f Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 30 Jun 2025 21:03:55 -0700 Subject: [PATCH 016/195] [Example] add one-click runnable example for P2P NCCL XpYd (#20246) Signed-off-by: KuntaiDu --- .../disagg_example_p2p_nccl_xpyd.sh | 245 ++++++++++++++++++ .../disagg_proxy_p2p_nccl_xpyd.py} | 0 2 files changed, 245 
#!/bin/bash

# =============================================================================
# vLLM Disaggregated Serving Script - P2P NCCL XpYd Architecture
# =============================================================================
# This script demonstrates disaggregated prefill and decode serving using
# P2P NCCL communication. The architecture supports various XpYd configurations:
#
# - 1P3D: 1 Prefill server + 3 Decode servers (current default)
# - 3P1D: 3 Prefill servers + 1 Decode server
# - etc.
#
# Configuration can be customized via environment variables:
#   MODEL: Model to serve
#   PREFILL_GPUS: Comma-separated GPU IDs for prefill servers
#   DECODE_GPUS: Comma-separated GPU IDs for decode servers
#   PREFILL_PORTS: Comma-separated ports for prefill servers
#   DECODE_PORTS: Comma-separated ports for decode servers
#   PROXY_PORT: Proxy server port used to setup XpYd connection.
#   TIMEOUT_SECONDS: Server startup timeout
#
# Exit status: 0 after a completed benchmark run (or an interrupt),
# non-zero when a prerequisite check or a server start-up fails.
# =============================================================================

# Configuration - can be overridden via environment variables
MODEL=${MODEL:-meta-llama/Llama-3.1-8B-Instruct}
TIMEOUT_SECONDS=${TIMEOUT_SECONDS:-1200}
PROXY_PORT=${PROXY_PORT:-30001}

# Default 1P3D configuration (1 Prefill + 3 Decode)
PREFILL_GPUS=${PREFILL_GPUS:-0}
DECODE_GPUS=${DECODE_GPUS:-1,2,3}
PREFILL_PORTS=${PREFILL_PORTS:-20003}
DECODE_PORTS=${DECODE_PORTS:-20005,20007,20009}

echo "Warning: P2P NCCL disaggregated prefill XpYd support for vLLM v1 is experimental and subject to change."
echo ""
echo "Architecture Configuration:"
echo "  Model: $MODEL"
echo "  Prefill GPUs: $PREFILL_GPUS, Ports: $PREFILL_PORTS"
echo "  Decode GPUs: $DECODE_GPUS, Ports: $DECODE_PORTS"
echo "  Proxy Port: $PROXY_PORT"
echo "  Timeout: ${TIMEOUT_SECONDS}s"
echo ""

# PIDs of every background process we launch (proxy + all vLLM servers).
PIDS=()

# Switch to the directory of the current script; the proxy script and the
# relative path to the benchmarks directory both depend on it.
cd "$(dirname "${BASH_SOURCE[0]}")" || exit 1

# Abort early if a file this script needs is missing next to it.
check_required_files() {
    local files=("disagg_proxy_p2p_nccl_xpyd.py")
    local file
    for file in "${files[@]}"; do
        if [[ ! -f "$file" ]]; then
            echo "Required file $file not found in $(pwd)"
            exit 1
        fi
    done
}

# Require an HF_TOKEN that at least looks like a Hugging Face token.
check_hf_token() {
    if [ -z "$HF_TOKEN" ]; then
        echo "HF_TOKEN is not set. Please set it to your Hugging Face token."
        echo "Example: export HF_TOKEN=your_token_here"
        exit 1
    fi
    if [[ "$HF_TOKEN" != hf_* ]]; then
        echo "HF_TOKEN is not a valid Hugging Face token. Please set it to your Hugging Face token."
        exit 1
    fi
    echo "HF_TOKEN is set and valid."
}

check_num_gpus() {
    # Check if the number of GPUs are >=2 via nvidia-smi
    num_gpus=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
    if [ "$num_gpus" -lt 2 ]; then
        echo "You need at least 2 GPUs to run disaggregated prefill."
        exit 1
    else
        echo "Found $num_gpus GPUs."
    fi
}

# $1: importable Python module name (also used as the pip package hint).
ensure_python_library_installed() {
    echo "Checking if $1 is installed..."
    if ! python3 -c "import $1" > /dev/null 2>&1; then
        echo "$1 is not installed. Please install it via pip install $1."
        exit 1
    else
        echo "$1 is installed."
    fi
}

# Stop every launched process and exit.
# $1 (optional): exit status to terminate with; defaults to 0 so that the
# signal traps (which call `cleanup` without arguments) keep exiting 0.
cleanup() {
    local exit_code=${1:-0}
    echo "Stopping everything…"
    trap - INT USR1   # prevent re-entrancy
    # The group-wide TERM below is also delivered to this shell. Ignore it so
    # that we still reach `wait` and exit with the requested status.
    trap '' TERM
    # negative PID == "this whole process-group". That only works when the
    # script is its own group leader, so fall back to the PIDs we tracked.
    if ! kill -- -$$ 2>/dev/null && (( ${#PIDS[@]} > 0 )); then
        kill "${PIDS[@]}" 2>/dev/null
    fi
    wait              # reap children so we don't leave zombies
    exit "$exit_code"
}

# Poll the completions endpoint on port $1 until it answers.
# Returns 0 when the server responds, 1 after TIMEOUT_SECONDS.
wait_for_server() {
    local port=$1
    local timeout_seconds=$TIMEOUT_SECONDS
    local start_time now
    start_time=$(date +%s)

    echo "Waiting for server on port $port..."

    while true; do
        if curl -s "localhost:${port}/v1/completions" > /dev/null; then
            echo "Server on port $port is ready."
            return 0
        fi

        now=$(date +%s)
        if (( now - start_time >= timeout_seconds )); then
            echo "Timeout waiting for server on port $port"
            return 1
        fi

        sleep 1
    done
}

main() {
    check_required_files
    check_hf_token
    check_num_gpus
    ensure_python_library_installed pandas
    ensure_python_library_installed datasets
    ensure_python_library_installed vllm
    ensure_python_library_installed quart

    # Parse GPU and port arrays, and validate them before anything is started
    # so a mis-sized list does not leave half a deployment running.
    IFS=',' read -ra PREFILL_GPU_ARRAY <<< "$PREFILL_GPUS"
    IFS=',' read -ra DECODE_GPU_ARRAY <<< "$DECODE_GPUS"
    IFS=',' read -ra PREFILL_PORT_ARRAY <<< "$PREFILL_PORTS"
    IFS=',' read -ra DECODE_PORT_ARRAY <<< "$DECODE_PORTS"

    if (( ${#PREFILL_GPU_ARRAY[@]} != ${#PREFILL_PORT_ARRAY[@]} )); then
        echo "PREFILL_GPUS and PREFILL_PORTS must list the same number of entries."
        exit 1
    fi
    if (( ${#DECODE_GPU_ARRAY[@]} != ${#DECODE_PORT_ARRAY[@]} )); then
        echo "DECODE_GPUS and DECODE_PORTS must list the same number of entries."
        exit 1
    fi

    trap cleanup INT
    trap cleanup USR1
    trap cleanup TERM

    echo "Launching disaggregated serving components..."
    echo "Please check the log files for detailed output:"
    echo "  - prefill*.log: Prefill server logs"
    echo "  - decode*.log: Decode server logs"
    echo "  - proxy.log: Proxy server log"

    # =============================================================================
    # Launch Proxy Server
    # =============================================================================
    echo ""
    echo "Starting proxy server on port $PROXY_PORT..."
    # Send the proxy's output to proxy.log, as announced above.
    python3 disagg_proxy_p2p_nccl_xpyd.py > proxy.log 2>&1 &
    PIDS+=($!)

    # =============================================================================
    # Launch Prefill Servers (X Producers)
    # =============================================================================
    echo ""
    echo "Starting ${#PREFILL_GPU_ARRAY[@]} prefill server(s)..."
    for i in "${!PREFILL_GPU_ARRAY[@]}"; do
        local gpu_id=${PREFILL_GPU_ARRAY[$i]}
        local port=${PREFILL_PORT_ARRAY[$i]}
        local kv_port=$((21001 + i))

        echo "  Prefill server $((i+1)): GPU $gpu_id, Port $port, KV Port $kv_port"
        CUDA_VISIBLE_DEVICES=$gpu_id VLLM_USE_V1=1 vllm serve "$MODEL" \
        --enforce-eager \
        --host 0.0.0.0 \
        --port "$port" \
        --tensor-parallel-size 1 \
        --seed 1024 \
        --dtype float16 \
        --max-model-len 10000 \
        --max-num-batched-tokens 10000 \
        --max-num-seqs 256 \
        --trust-remote-code \
        --gpu-memory-utilization 0.9 \
        --disable-log-request \
        --kv-transfer-config \
        "{\"kv_connector\":\"P2pNcclConnector\",\"kv_role\":\"kv_producer\",\"kv_buffer_size\":\"1e1\",\"kv_port\":\"$kv_port\",\"kv_connector_extra_config\":{\"proxy_ip\":\"0.0.0.0\",\"proxy_port\":\"$PROXY_PORT\",\"http_port\":\"$port\",\"send_type\":\"PUT_ASYNC\",\"nccl_num_channels\":\"16\"}}" > prefill$((i+1)).log 2>&1 &
        PIDS+=($!)
    done

    # =============================================================================
    # Launch Decode Servers (Y Decoders)
    # =============================================================================
    echo ""
    echo "Starting ${#DECODE_GPU_ARRAY[@]} decode server(s)..."
    for i in "${!DECODE_GPU_ARRAY[@]}"; do
        local gpu_id=${DECODE_GPU_ARRAY[$i]}
        local port=${DECODE_PORT_ARRAY[$i]}
        local kv_port=$((22001 + i))

        echo "  Decode server $((i+1)): GPU $gpu_id, Port $port, KV Port $kv_port"
        VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=$gpu_id vllm serve "$MODEL" \
        --enforce-eager \
        --host 0.0.0.0 \
        --port "$port" \
        --tensor-parallel-size 1 \
        --seed 1024 \
        --dtype float16 \
        --max-model-len 10000 \
        --max-num-batched-tokens 10000 \
        --max-num-seqs 256 \
        --trust-remote-code \
        --gpu-memory-utilization 0.7 \
        --disable-log-request \
        --kv-transfer-config \
        "{\"kv_connector\":\"P2pNcclConnector\",\"kv_role\":\"kv_consumer\",\"kv_buffer_size\":\"8e9\",\"kv_port\":\"$kv_port\",\"kv_connector_extra_config\":{\"proxy_ip\":\"0.0.0.0\",\"proxy_port\":\"$PROXY_PORT\",\"http_port\":\"$port\",\"send_type\":\"PUT_ASYNC\",\"nccl_num_channels\":\"16\"}}" > decode$((i+1)).log 2>&1 &
        PIDS+=($!)
    done

    # =============================================================================
    # Wait for All Servers to Start
    # =============================================================================
    echo ""
    echo "Waiting for all servers to start..."
    for port in "${PREFILL_PORT_ARRAY[@]}" "${DECODE_PORT_ARRAY[@]}"; do
        if ! wait_for_server "$port"; then
            echo "Failed to start server on port $port"
            # Tear everything down and report the failure to the caller.
            cleanup 1
        fi
    done

    echo ""
    echo "All servers are up. Starting benchmark..."

    # =============================================================================
    # Run Benchmark
    # =============================================================================
    cd ../../../benchmarks/ || cleanup 1
    python3 benchmark_serving.py --port 10001 --seed "$(date +%s)" \
        --model "$MODEL" \
        --dataset-name random --random-input-len 7500 --random-output-len 200 \
        --num-prompts 200 --burstiness 100 --request-rate 2 | tee benchmark.log

    echo "Benchmarking done. Cleaning up..."

    cleanup
}

main
\ No newline at end of file From bd5038af076a2e299d4781c3885415639a1ed3a5 Mon Sep 17 00:00:00 2001 From: Ernest Wong Date: Mon, 30 Jun 2025 21:44:39 -0700 Subject: [PATCH 018/195] [Doc] add config and troubleshooting guide for NCCL & GPUDirect RDMA (#15897) Signed-off-by: Ernest Wong --- docs/serving/distributed_serving.md | 45 ++++++++++++++++++++++++++++- docs/usage/troubleshooting.md | 21 ++++++++++++++ 2 files changed, 65 insertions(+), 1 deletion(-) diff --git a/docs/serving/distributed_serving.md b/docs/serving/distributed_serving.md index 38dcb8c81caf7..6665955411ad5 100644 --- a/docs/serving/distributed_serving.md +++ b/docs/serving/distributed_serving.md @@ -100,7 +100,50 @@ vllm serve /path/to/the/model/in/the/container \ --tensor-parallel-size 16 ``` -To make tensor parallel performant, you should make sure the communication between nodes is efficient, e.g. using high-speed network cards like Infiniband. To correctly set up the cluster to use Infiniband, append additional arguments like `--privileged -e NCCL_IB_HCA=mlx5` to the `run_cluster.sh` script. Please contact your system administrator for more information on how to set up the flags. One way to confirm if the Infiniband is working is to run vLLM with `NCCL_DEBUG=TRACE` environment variable set, e.g. `NCCL_DEBUG=TRACE vllm serve ...` and check the logs for the NCCL version and the network used. If you find `[send] via NET/Socket` in the logs, it means NCCL uses raw TCP Socket, which is not efficient for cross-node tensor parallel. If you find `[send] via NET/IB/GDRDMA` in the logs, it means NCCL uses Infiniband with GPU-Direct RDMA, which is efficient. +To make tensor parallel performant, you should make sure the communication between nodes is efficient, e.g. using high-speed network cards like InfiniBand. To correctly set up the cluster to use InfiniBand, append additional arguments like `--privileged -e NCCL_IB_HCA=mlx5` to the `run_cluster.sh` script. 
Please contact your system administrator for more information on how to set up the flags. One way to confirm if the InfiniBand is working is to run vLLM with `NCCL_DEBUG=TRACE` environment variable set, e.g. `NCCL_DEBUG=TRACE vllm serve ...` and check the logs for the NCCL version and the network used. If you find `[send] via NET/Socket` in the logs, it means NCCL uses raw TCP Socket, which is not efficient for cross-node tensor parallel. If you find `[send] via NET/IB/GDRDMA` in the logs, it means NCCL uses InfiniBand with GPUDirect RDMA, which is efficient. + +### GPUDirect RDMA + +To enable GPUDirect RDMA with vLLM, specific configuration tweaks are needed. This setup ensures: + +- `IPC_LOCK` Security Context: Add the `IPC_LOCK` capability to the container’s security context to lock memory pages and prevent swapping to disk. +- Shared Memory with `/dev/shm`: Mount `/dev/shm` in the pod spec to provide shared memory for IPC. + +When using Docker, you can set up the container as follows: + +```bash +docker run --gpus all \ + --ipc=host \ + --shm-size=16G \ + -v /dev/shm:/dev/shm \ + vllm/vllm-openai +``` + +When using Kubernetes, you can set up the pod spec as follows: + +```yaml +... +spec: + containers: + - name: vllm + image: vllm/vllm-openai + securityContext: + capabilities: + add: ["IPC_LOCK"] + volumeMounts: + - mountPath: /dev/shm + name: dshm + resources: + limits: + nvidia.com/gpu: 8 + requests: + nvidia.com/gpu: 8 + volumes: + - name: dshm + emptyDir: + medium: Memory +... +``` !!! warning After you start the Ray cluster, you'd better also check the GPU-GPU communication between nodes. It can be non-trivial to set up. Please refer to the [sanity check script][troubleshooting-incorrect-hardware-driver] for more information. If you need to set some environment variables for the communication configuration, you can append them to the `run_cluster.sh` script, e.g. `-e NCCL_SOCKET_IFNAME=eth0`. Note that setting environment variables in the shell (e.g. 
`NCCL_SOCKET_IFNAME=eth0 vllm serve ...`) only works for the processes in the same node, not for the processes in the other nodes. Setting environment variables when you create the cluster is the recommended way. See for more information. diff --git a/docs/usage/troubleshooting.md b/docs/usage/troubleshooting.md index 82957d33b19e0..7f1f76ce3d2e3 100644 --- a/docs/usage/troubleshooting.md +++ b/docs/usage/troubleshooting.md @@ -273,6 +273,27 @@ But you are sure that the model is in the [list of supported models][supported-m If you see an error like `RuntimeError: Failed to infer device type`, it means that vLLM failed to infer the device type of the runtime environment. You can check [the code](gh-file:vllm/platforms/__init__.py) to see how vLLM infers the device type and why it is not working as expected. After [this PR](gh-pr:14195), you can also set the environment variable `VLLM_LOGGING_LEVEL=DEBUG` to see more detailed logs to help debug the issue. +## NCCL error: unhandled system error during `ncclCommInitRank` + +If your serving workload uses GPUDirect RDMA for distributed serving across multiple nodes and encounters an error during `ncclCommInitRank`, with no clear error message even with `NCCL_DEBUG=INFO` set, it might look like this: + +```text +Error executing method 'init_device'. This might cause deadlock in distributed execution. +Traceback (most recent call last): +... 
+ File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/pynccl.py", line 99, in __init__ + self.comm: ncclComm_t = self.nccl.ncclCommInitRank( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/pynccl_wrapper.py", line 277, in ncclCommInitRank + self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm), + File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/pynccl_wrapper.py", line 256, in NCCL_CHECK + raise RuntimeError(f"NCCL error: {error_str}") + RuntimeError: NCCL error: unhandled system error (run with NCCL_DEBUG=INFO for details) +... +``` + +This indicates vLLM failed to initialize the NCCL communicator, possibly due to a missing `IPC_LOCK` Linux capability or an unmounted `/dev/shm`. Refer to [Distributed Inference and Serving](../serving/distributed_serving.md#running-vllm-on-multiple-nodes) for guidance on properly configuring the environment for distributed serving. + ## Known Issues - In `v0.5.2`, `v0.5.3`, and `v0.5.3.post1`, there is a bug caused by [zmq](https://github.com/zeromq/pyzmq/issues/2000) , which can occasionally cause vLLM to hang depending on the machine configuration. The solution is to upgrade to the latest version of `vllm` to include the [fix](gh-pr:6759). 
From 27949354faa06035645aa908cc73922500a80b17 Mon Sep 17 00:00:00 2001 From: Alex Kogan <82225080+sakogan@users.noreply.github.com> Date: Tue, 1 Jul 2025 01:44:38 -0400 Subject: [PATCH 019/195] [Feature] A calibration-free RTN-based quantization for accurate and accelerated INT4/INT8 inference (#18768) Signed-off-by: Alex Kogan Co-authored-by: Michael Goin --- tests/quantization/test_rtn.py | 28 ++ .../layers/quantization/__init__.py | 3 + .../model_executor/layers/quantization/rtn.py | 288 ++++++++++++++++++ 3 files changed, 319 insertions(+) create mode 100644 tests/quantization/test_rtn.py create mode 100644 vllm/model_executor/layers/quantization/rtn.py diff --git a/tests/quantization/test_rtn.py b/tests/quantization/test_rtn.py new file mode 100644 index 0000000000000..04c1f98a709e2 --- /dev/null +++ b/tests/quantization/test_rtn.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright © 2025, Oracle and/or its affiliates. +"""Tests RTN quantization startup and generation, +doesn't test correctness +""" +import pytest + +from tests.quantization.utils import is_quant_method_supported + +MODELS = ["microsoft/Phi-3-mini-4k-instruct"] + + +@pytest.mark.skipif(not is_quant_method_supported("rtn"), + reason="RTN is not supported on this GPU type.") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [10]) +def test_model_rtn_startup( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + + with vllm_runner(model, dtype=dtype, quantization="rtn") as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 1cb23e7a18875..60217ee86ad1d 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -35,6 +35,7 @@ 
QuantizationMethods = Literal[ "moe_wna16", "torchao", "auto-round", + "rtn", ] QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) @@ -110,6 +111,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .neuron_quant import NeuronQuantConfig from .ptpc_fp8 import PTPCFp8Config from .qqq import QQQConfig + from .rtn import RTNConfig from .torchao import TorchAOConfig from .tpu_int8 import Int8TpuConfig @@ -142,6 +144,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "moe_wna16": MoeWNA16Config, "torchao": TorchAOConfig, "auto-round": AutoRoundConfig, + "rtn": RTNConfig } # Update the `method_to_config` with customized quantization methods. method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py new file mode 100644 index 0000000000000..7e7fd6d51fd32 --- /dev/null +++ b/vllm/model_executor/layers/quantization/rtn.py @@ -0,0 +1,288 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright © 2025, Oracle and/or its affiliates. + +import os +from typing import Any, Optional + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + +logger = init_logger(__name__) +"""By default, use 8 bit as target precision, but it can be +overridden by setting the RTN_NUM_BITS envvar +""" +NUM_BITS = os.getenv('RTN_NUM_BITS', "8") +"""By default, use group size of 128 parameters, but it can be +overridden by setting the RTN_GROUP_SIZE envvar +""" +GROUP_SIZE = os.getenv('RTN_GROUP_SIZE', "128") + + +class RTNConfig(QuantizationConfig): + """Config class for RTN. 
+ """ + + def __init__( + self, + weight_bits: int = int(NUM_BITS), + group_size: int = int(GROUP_SIZE), + ) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + + if self.weight_bits != 4 and self.weight_bits != 8: + raise ValueError( + "Currently, only 4-bit or 8-bit weight quantization is " + f"supported for RTN, but got {self.weight_bits} bits.") + + def __repr__(self) -> str: + return (f"RTNConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size})") + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "rtn" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "RTNConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + return cls(weight_bits, group_size) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["RTNLinearMethod"]: + if isinstance(layer, LinearBase): + return RTNLinearMethod(self) + return None + + +class RTNTensor: + """A wrapper over Tensor that enables quantization on-the-fly by + overloading the copy_ method. 
+ """ + + def __init__(self, data: torch.Tensor, scale: torch.Tensor, + quant_config: RTNConfig) -> None: + self.data = data + self.scale = scale + self.quant_config = quant_config + + def narrow(self, dim, start, length): + factor = 1 if self.quant_config.weight_bits == 8 else 2 + return RTNTensor( + self.data.narrow(dim, start // factor, length // factor), + self.scale.narrow(dim, start, length), self.quant_config) + + @property + def shape(self): + shape = self.data.shape + factor = 1 if self.quant_config.weight_bits == 8 else 2 + return torch.Size((shape[0] * factor, shape[1])) + + def copy_(self, loaded_weight: torch.Tensor) -> None: + qweight, weight_scale = rtn_quantize(loaded_weight.cuda(), + self.quant_config.weight_bits, + self.quant_config.group_size) + + self.data.copy_(qweight) + self.scale.data.copy_(weight_scale) + + +class RTNParameter(Parameter): + """A wrapper over Parameter that returns RTNTensor (a wrapper over Tensor) + when its data is accessed. We need this wrapper for the data loading phase + only, so we can intercept a weight copying function (torch.Tensor.copy_) + and apply quantization on-the-fly. + """ + + def __new__(cls, data: torch.Tensor, **kwargs): + return super().__new__(cls, data=data, requires_grad=False) + + def __init__(self, data: torch.Tensor, scale: torch.Tensor, + quant_config: RTNConfig) -> None: + self.scale = scale + self.quant_config = quant_config + + @property + def data(self): + return RTNTensor(super().data, self.scale, self.quant_config) + + +class RTNLinearMethod(LinearMethodBase): + """Linear method for RTN. + + Args: + quant_config: The RTN quantization config. 
+ """ + + def __init__(self, quant_config: RTNConfig): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + output_size_per_partition = sum(output_partition_sizes) + num_groups_per_col = (input_size_per_partition // + self.quant_config.group_size + if self.quant_config.group_size != -1 else 1) + + scale = Parameter( + torch.empty(output_size_per_partition, + num_groups_per_col, + dtype=params_dtype), + requires_grad=False, + ) + factor = 1 if self.quant_config.weight_bits == 8 else 2 + + weight = RTNParameter(data=torch.empty(output_size_per_partition // + factor, + input_size_per_partition, + dtype=torch.int8), + scale=scale, + quant_config=self.quant_config) + + layer.register_parameter("weight", weight) + set_weight_attrs(weight, { + **extra_weight_attrs, + "input_dim": 1, + "output_dim": 0, + }) + + layer.register_parameter("scale", scale) + layer.output_size_per_partition = output_size_per_partition + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + """torch.compile does not know how to deal with a Parameter subclass + (aka RTNParameter). As we don't really need RTNParameters for the + forward pass, we replace them with equivalent instances of Parameters. 
def rtn_quantize(tensor: torch.Tensor, num_bits: int,
                 group_size: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a tensor using per-group static scaling factor.

    Args:
        tensor: The input tensor.
        num_bits: Target precision for the result (supported values are
                  8 or 4).
        group_size: Quantization granularity.
                    If equal to -1, each row in the input tensor is treated
                    as one group.

    Returns:
        A ``(quantized, scale)`` pair. ``quantized`` is an int8 tensor with
        the shape of ``tensor``; for ``num_bits == 4`` two values are packed
        into each byte and the first dimension is halved. ``scale`` holds one
        factor per group, reshaped to ``(tensor.shape[0], -1)``.

    Raises:
        ValueError: If ``group_size`` is neither -1 nor a positive integer.
    """
    if group_size != -1 and group_size <= 0:
        raise ValueError(
            f"group_size must be -1 or a positive integer, got {group_size}.")

    q_range = 2**num_bits
    num_groups = (tensor.shape[0] * tensor.shape[1] //
                  group_size if group_size != -1 else tensor.shape[0])

    # Calculate a scaling factor per input group.
    input_flat = tensor.reshape(num_groups, -1)
    input_min = torch.min(input_flat, dim=1, keepdim=True)[0]
    input_max = torch.max(input_flat, dim=1, keepdim=True)[0]
    input_max_abs = torch.max(input_min.abs(), input_max.abs())
    scale = (input_max_abs * 2.0 / (q_range - 1))

    # Scale each input group, truncate and round to the nearest integer.
    # An all-zero group has a zero scale; divide by 1 there so the quantized
    # values are exact zeros instead of NaN being cast to int8. The returned
    # scale is left untouched, so dequantization still yields zeros.
    divisor = torch.where(scale == 0, torch.ones_like(scale), scale)
    scaled_input = input_flat / divisor
    scaled_input = scaled_input.clamp(-q_range // 2, q_range // 2 - 1)
    scaled_input = scaled_input.round()

    scale = scale.reshape(tensor.shape[0], -1).contiguous()
    inputs_q = scaled_input.reshape(tensor.shape).to(torch.int8)
    inputs_q = inputs_q.contiguous()

    if num_bits == 4:
        # Pack two 4-bit values into each byte: odd columns go to the high
        # nibble, even columns to the low nibble.
        inputs_q = (inputs_q[:, 1::2] << 4) | (inputs_q[:, ::2] & 0xf)
        inputs_q = inputs_q.reshape(tensor.shape[0] // 2, tensor.shape[1])
        inputs_q = inputs_q.contiguous()

    return inputs_q, scale


def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Dequantize a tensor using per-group static scaling factors.

    The bit width is inferred from the shapes: when the first dimension of
    ``tensor`` equals that of ``scale`` the data is treated as 8-bit,
    otherwise as 4-bit values packed two per byte.

    Args:
        tensor: The input tensor.
        scale: The tensor with per-group scale factors.

    Returns:
        A contiguous tensor of dtype ``scale.dtype`` on ``tensor``'s device,
        with the first dimension doubled for packed 4-bit input.
    """
    num_groups = scale.size(0) * scale.size(1)
    input_dim, output_dim = tensor.shape

    num_bits = 8 if input_dim == scale.size(0) else 4
    if num_bits == 4:
        input_dim *= 2

    data = torch.empty((input_dim, output_dim),
                       dtype=scale.dtype,
                       device=tensor.device)

    if num_bits == 8:
        data.copy_(tensor)
    else:
        # Unpack two 4-bit values from each byte. The arithmetic right shift
        # on int8 sign-extends each nibble.
        tensor = tensor.reshape(input_dim, output_dim // 2)
        data[:, 0::2] = (tensor << 4) >> 4  # low nibble -> even columns
        data[:, 1::2] = tensor >> 4  # high nibble -> odd columns

    # Scale each input group with its scaling factor.
    scale = scale.reshape(num_groups, -1)
    data = data.reshape(num_groups, -1)
    data = torch.mul(data, scale)

    input_deq = data.reshape((input_dim, output_dim)).contiguous()
    return input_deq
def is_global_first_rank() -> bool:
    """Tell whether this process is global rank 0.

    This looks at the rank across every parallelism dimension (PP, TP, DP,
    EP, etc.), unlike group-scoped checks such as
    `get_tensor_model_parallel_rank() == 0` or `get_pp_group().is_first_rank`.

    Resolution order:
      1. the world group (`_WORLD`), when it has been set up;
      2. torch.distributed's global rank, when torch.distributed is
         initialized;
      3. otherwise the process is treated as a single-process run.

    Returns:
        bool: True for the global first rank (rank 0), False otherwise.
        Also True when distributed is not initialized, or when any of the
        lookups above raises.
    """
    try:
        world = _WORLD
        if world is not None:
            # The world group gives the most accurate answer.
            return world.is_first_rank

        if torch.distributed.is_initialized():
            # No world group yet: use torch's global rank.
            return torch.distributed.get_rank() == 0

        # torch.distributed is not initialized: assume single process.
        return True
    except Exception:
        # Best effort: on any failure, behave as the first rank.
        return True
with graph_capture(device=self.device): full_cg = self.full_cuda_graph - for num_tokens in tqdm(reversed(self.cudagraph_batch_sizes), - desc="Capturing CUDA graphs", - total=len(self.cudagraph_batch_sizes)): + # Only rank 0 should print progress bar during capture + compilation_cases = reversed(self.cudagraph_batch_sizes) + if is_global_first_rank(): + compilation_cases = tqdm(list(compilation_cases), + desc="Capturing CUDA graph shapes") + for num_tokens in compilation_cases: # We skip EPLB here since we don't want to record dummy metrics for _ in range( self.compilation_config.cudagraph_num_of_warmups): From 86debab54c046232014b108d530a8c25d857e9a3 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Tue, 1 Jul 2025 00:48:10 -0600 Subject: [PATCH 021/195] Fix `numel()` downcast in vllm/csrc/moe/moe_align_sum_kernels.cu +2 (#17082) Co-authored-by: mgoin --- csrc/moe/moe_align_sum_kernels.cu | 2 +- csrc/moe/topk_softmax_kernels.cu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index 9335e2333b0d9..462dbd1f8b380 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -239,7 +239,7 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] torch::Tensor& output) // [num_tokens, hidden_size] { const int hidden_size = input.size(-1); - const int num_tokens = output.numel() / hidden_size; + const auto num_tokens = output.numel() / hidden_size; const int topk = input.size(1); dim3 grid(num_tokens); diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index dea5b1f21ec27..064b76c9cd427 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -492,7 +492,7 @@ void topk_softmax( torch::Tensor& gating_output) // [num_tokens, num_experts] { const int num_experts = gating_output.size(-1); - const int num_tokens = gating_output.numel() / num_experts; + const auto num_tokens = 
gating_output.numel() / num_experts; const int topk = topk_weights.size(-1); const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); From 22e9d42040f3ecf83da181cfd84ab4cea000c4af Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Tue, 1 Jul 2025 00:02:20 -0700 Subject: [PATCH 022/195] [Misc] add xgrammar for arm64 (#18359) Signed-off-by: Prashant Gupta --- requirements/common.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/common.txt b/requirements/common.txt index 6cc304e5b1f6d..97a35e05d38ab 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -23,7 +23,7 @@ lm-format-enforcer >= 0.10.11, < 0.11 llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64" outlines == 0.1.11 lark == 1.2.2 -xgrammar == 0.1.19; platform_machine == "x86_64" or platform_machine == "aarch64" +xgrammar == 0.1.19; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64" typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 partial-json-parser # used for parsing partial JSON outputs From 9909726d2a30d834d97efd7bf1c4fc0e52fa48b5 Mon Sep 17 00:00:00 2001 From: czhu-cohere Date: Tue, 1 Jul 2025 00:12:20 -0700 Subject: [PATCH 023/195] Enable ZP Support for Machete (#20268) Signed-off-by: czhu-cohere --- benchmarks/kernels/benchmark_machete.py | 2 ++ tests/kernels/quantization/test_machete_mm.py | 2 +- .../kernels/mixed_precision/machete.py | 20 +++++++++++++++---- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index 0f896f187ecb9..f73d0511e01fc 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -234,8 +234,10 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: fn = lambda: ops.gptq_marlin_gemm( a=bt.a, 
+ c=None, b_q_weight=w_q, b_scales=w_s, + global_scale=None, b_zeros=w_zp, g_idx=g_idx, perm=sort_indices, diff --git a/tests/kernels/quantization/test_machete_mm.py b/tests/kernels/quantization/test_machete_mm.py index 998171baaf2de..a4fb9874c4906 100644 --- a/tests/kernels/quantization/test_machete_mm.py +++ b/tests/kernels/quantization/test_machete_mm.py @@ -139,7 +139,7 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor): def group_size_valid(shape: tuple[int, int, int], group_size: Optional[int]) -> bool: - return group_size is None or group_size == -1 or group_size % shape[2] == 0 + return group_size is None or group_size == -1 or shape[2] % group_size == 0 def machete_quantize_and_pack(atype: torch.dtype, diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py index c7c45861875af..a75f3ac8d5033 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py @@ -33,8 +33,6 @@ class MacheteLinearKernel(MPLinearKernel): return False, "Act reordering currently not supported by Machete, "\ "when the input features are partitioned across "\ "devices" - if c.zero_points: - return False, "Zero points currently not supported by Machete" if c.weight_type not in query_machete_supported_quant_types( c.zero_points): @@ -53,6 +51,7 @@ class MacheteLinearKernel(MPLinearKernel): # note assumes that # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} # `weight_scale` is: {input_dim = 0, output_dim = 1} + # `weight_zp` is: {input_dim = 0, output_dim = 1, packed_dim = 1} def process_weights_after_loading(self, layer: torch.nn.Module): c = self.config @@ -90,16 +89,29 @@ class MacheteLinearKernel(MPLinearKernel): x.data = x.data.contiguous() return x + def transform_w_zp(x): + assert isinstance(x, BasevLLMParameter) + 
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=1) + x_unpacked = unpack_quantized_values_into_int32(x.data, + c.weight_type, + packed_dim=1) + w_s = getattr(layer, self.w_s_name).data + # pre-apply scales to zero-points + x.data = (-1.0 * w_s * (x_unpacked.to(w_s.dtype))).contiguous() + return x + # Repack weights and scales for Machete self._transform_param(layer, self.w_q_name, transform_w_q) self._transform_param(layer, self.w_s_name, transform_w_s) + if c.zero_points: + self._transform_param(layer, self.w_zp_name, transform_w_zp) def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: c = self.config - w_q, w_s, _, _ = self._get_weight_params(layer) + w_q, w_s, w_zp, _ = self._get_weight_params(layer) x_2d = x.reshape(-1, x.shape[-1]) out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) @@ -110,7 +122,7 @@ class MacheteLinearKernel(MPLinearKernel): output = ops.machete_mm(a=x_2d, b_q=w_q, b_type=c.weight_type, - b_group_zeros=None, + b_group_zeros=w_zp, b_group_scales=w_s, b_group_size=c.group_size) From 6cc1e7d96dab6b9c344ec87dec6dc9ab07ad5d21 Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Tue, 1 Jul 2025 15:25:03 +0800 Subject: [PATCH 024/195] [CPU] Update custom ops for the CPU backend (#20255) Signed-off-by: jiang1.li --- .../scripts/hardware_ci/run-cpu-test.sh | 3 +- cmake/cpu_extension.cmake | 20 + csrc/cpu/sgl-kernels/common.h | 238 +++ csrc/cpu/sgl-kernels/gemm.cpp | 464 ++++++ csrc/cpu/sgl-kernels/gemm.h | 266 ++++ csrc/cpu/sgl-kernels/gemm_fp8.cpp | 530 +++++++ csrc/cpu/sgl-kernels/gemm_int8.cpp | 440 ++++++ csrc/cpu/sgl-kernels/moe.cpp | 1330 +++++++++++++++++ csrc/cpu/sgl-kernels/moe_fp8.cpp | 502 +++++++ csrc/cpu/sgl-kernels/moe_int8.cpp | 769 ++++++++++ csrc/cpu/sgl-kernels/vec.h | 308 ++++ csrc/cpu/shm.cpp | 178 +-- csrc/cpu/torch_bindings.cpp | 43 + docs/getting_started/installation/cpu.md | 1 + .../models/language/generation/test_common.py | 3 +- 
vllm/_custom_ops.py | 49 + vllm/envs.py | 5 + .../layers/fused_moe/cpu_fused_moe.py | 214 +++ vllm/model_executor/layers/fused_moe/layer.py | 41 +- vllm/model_executor/layers/linear.py | 25 +- vllm/model_executor/layers/utils.py | 25 +- .../layers/vocab_parallel_embedding.py | 2 +- vllm/platforms/cpu.py | 2 + 23 files changed, 5357 insertions(+), 101 deletions(-) create mode 100644 csrc/cpu/sgl-kernels/common.h create mode 100644 csrc/cpu/sgl-kernels/gemm.cpp create mode 100644 csrc/cpu/sgl-kernels/gemm.h create mode 100644 csrc/cpu/sgl-kernels/gemm_fp8.cpp create mode 100644 csrc/cpu/sgl-kernels/gemm_int8.cpp create mode 100644 csrc/cpu/sgl-kernels/moe.cpp create mode 100644 csrc/cpu/sgl-kernels/moe_fp8.cpp create mode 100644 csrc/cpu/sgl-kernels/moe_int8.cpp create mode 100644 csrc/cpu/sgl-kernels/vec.h create mode 100644 vllm/model_executor/layers/fused_moe/cpu_fused_moe.py diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index 8db8c3a05fb30..42506730e868c 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -51,6 +51,7 @@ function cpu_tests() { pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model pytest -v -s tests/models/language/generation -m cpu_model + VLLM_CPU_SGL_KERNEL=1 pytest -v -s tests/models/language/generation -m cpu_model pytest -v -s tests/models/language/pooling -m cpu_model pytest -v -s tests/models/multimodal/generation \ --ignore=tests/models/multimodal/generation/test_mllama.py \ @@ -98,4 +99,4 @@ function cpu_tests() { # All of CPU tests are expected to be finished less than 40 mins. 
export -f cpu_tests -timeout 1h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" +timeout 1.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 5cd2c98f23438..264c970ef784a 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -96,12 +96,21 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16") + set(ENABLE_AVX512BF16 ON) else() + set(ENABLE_AVX512BF16 OFF) message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3") endif() else() + set(ENABLE_AVX512BF16 OFF) message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.") endif() + + find_isa(${CPUINFO} "avx512_vnni" AVX512VNNI_FOUND) + if (AVX512VNNI_FOUND) + list(APPEND CXX_COMPILE_FLAGS "-mavx512vnni") + set(ENABLE_AVX512VNNI ON) + endif() elseif (AVX2_FOUND) list(APPEND CXX_COMPILE_FLAGS "-mavx2") @@ -231,6 +240,17 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) "csrc/cpu/quant.cpp" "csrc/cpu/shm.cpp" ${VLLM_EXT_SRC}) + if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI) + set(VLLM_EXT_SRC + "csrc/cpu/sgl-kernels/gemm.cpp" + "csrc/cpu/sgl-kernels/gemm_int8.cpp" + "csrc/cpu/sgl-kernels/gemm_fp8.cpp" + "csrc/cpu/sgl-kernels/moe.cpp" + "csrc/cpu/sgl-kernels/moe_int8.cpp" + "csrc/cpu/sgl-kernels/moe_fp8.cpp" + ${VLLM_EXT_SRC}) + add_compile_definitions(-DCPU_CAPABILITY_AVX512) + endif() elseif(POWER10_FOUND) set(VLLM_EXT_SRC "csrc/cpu/quant.cpp" diff --git a/csrc/cpu/sgl-kernels/common.h b/csrc/cpu/sgl-kernels/common.h new file mode 100644 index 0000000000000..20261c1ef3e87 --- /dev/null +++ b/csrc/cpu/sgl-kernels/common.h @@ -0,0 +1,238 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#pragma once + +#include +#include +#include + +// clang-format 
off + +#if defined(_OPENMP) +#include +#endif + +namespace { + +// dispatch bool +#define AT_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) \ + [&] { \ + if (BOOL_V) { \ + constexpr bool BOOL_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool BOOL_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +// dispatch: bfloat16, float16, int8_t, fp8_e4m3 +#define CPU_DISPATCH_PACKED_TYPES(TYPE, ...) \ + [&] { \ + switch (TYPE) { \ + case at::ScalarType::BFloat16 : { \ + using packed_t = at::BFloat16; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::Half: { \ + using packed_t = at::Half; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::Char : { \ + using packed_t = int8_t; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::Float8_e4m3fn : { \ + using packed_t = at::Float8_e4m3fn; \ + return __VA_ARGS__(); \ + } \ + default: \ + TORCH_CHECK(false, "Unsupported floating data type.\n"); \ + } \ + }() + +#define UNUSED(x) (void)(x) + +#define CHECK_CPU(x) TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor") + +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_LAST_DIM_CONTIGUOUS(x) \ + TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimention") + +#define CHECK_INPUT(x) \ + CHECK_CPU(x); \ + CHECK_CONTIGUOUS(x) +#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \ + CHECK_CPU(x); \ + CHECK_LAST_DIM_CONTIGUOUS(x) + +#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") + +#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. 
", a, " vs ", b) + +// parallel routines +constexpr int GRAIN_SIZE = 1024; + +template ::value, int>::type = 0> +inline T div_up(T x, T y) { return (x + y - 1) / y; } + +template +inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) { +#if 0 + // onednn partition pattern + T& n_my = n_end; + if (nth <= 1 || n == 0) { + n_start = 0; + n_my = n; + } else { + T n1 = div_up(n, nth); + T n2 = n1 - 1; + T T1 = n - n2 * nth; + n_my = ith < T1 ? n1 : n2; + n_start = ith <= T1 ? ith*n1 : T1 * n1 + (ith - T1) * n2; + } + n_end += n_start; +#else + // pytorch aten partition pattern + T n_my = div_up(n, nth); + n_start = ith * n_my; + n_end = std::min(n_start + n_my, n); +#endif +} + +template +inline void parallel_for(int n, const func_t& f) { +#if defined(_OPENMP) +#pragma omp parallel +{ + int nth = omp_get_num_threads(); + int ith = omp_get_thread_num(); + int tbegin, tend; + balance211(n, nth, ith, tbegin, tend); + f(tbegin, tend); +} +#else + f(0, n); +#endif +} + +// for 1d parallel, use `actual_nth` +// for 2d parallel, use even nths, e.g. 
43->42 +int inline adjust_num_threads(int m) { + int actual_nth = at::get_num_threads(); + if (m == 1) { + return actual_nth; + } + return std::max(1, (actual_nth >> 1) * 2); +} + +template +inline void parallel_2d(int m, int n, const func_t& f) { + + // make sure we have even num_threads + int nth = adjust_num_threads(m); + + // [NOTE] thread blocking: + // + // 1) prefer square block per thread + // 2) use even number of CPU cores + // 3) use all `num_threads` cores + // + // we have: + // TM * TN = T + // BM / TM = BN / TN + // then: + // TM = ((BM / BN) * T) ^ 0.5 + // + float r = float(m) / n; + int nth_m = std::ceil(std::sqrt(r * nth)); + int nth_n = 1; + for (; nth_m > 0; --nth_m) { + nth_n = nth / nth_m; + if (nth_m * nth_n == nth) { + break; + } + } + +#if defined(_OPENMP) +#pragma omp parallel num_threads(nth) +{ + int ith = omp_get_thread_num(); + int ith_m = ith / nth_n; + int ith_n = ith % nth_n; + + int thread_block_m = div_up(m, nth_m); + int thread_block_n = div_up(n, nth_n); + + int begin_m = ith_m * thread_block_m; + int end_m = std::min(m, begin_m + thread_block_m); + int begin_n = ith_n * thread_block_n; + int end_n = std::min(n, begin_n + thread_block_n); + + f(begin_m, end_m, begin_n, end_n); +} +#else + f(0, m, 0, n); +#endif +} + +template +int get_cache_blocks(int BLOCK_SIZE, int K) { + // L2 2MB and ratio of 50% + const int L2_size = 2048 * 1024 >> 1; + return std::max(1, int(L2_size / (BLOCK_SIZE * K * sizeof(T)))); +} + +// data indexing for dimension collapse +template +inline T data_index_init(T offset) { + return offset; +} + +template +inline T data_index_init(T offset, T& x, const T& X, Args&&... args) { + offset = data_index_init(offset, std::forward(args)...); + x = offset % X; + return offset / X; +} + +inline bool data_index_step() { + return true; +} + +template +inline bool data_index_step(T& x, const T& X, Args&&... args) { + if (data_index_step(std::forward(args)...)) { + x = ((x + 1) == X) ? 
0 : (x + 1); + return x == 0; + } + return false; +} + +// forced unroll for perf critical path + +#if __has_attribute(always_inline) +#define ALWAYS_INLINE __attribute__((__always_inline__)) inline +#else +#define ALWAYS_INLINE inline +#endif + +template +struct Unroll { + template + ALWAYS_INLINE void operator()(const Func& f, Args... args) const { + Unroll{}(f, args...); + f(std::integral_constant{}, args...); + } +}; + +template <> +struct Unroll<1> { + template + ALWAYS_INLINE void operator()(const Func& f, Args... args) const { + f(std::integral_constant{}, args...); + } +}; + +} // anonymous namespace diff --git a/csrc/cpu/sgl-kernels/gemm.cpp b/csrc/cpu/sgl-kernels/gemm.cpp new file mode 100644 index 0000000000000..c122d07185ddb --- /dev/null +++ b/csrc/cpu/sgl-kernels/gemm.cpp @@ -0,0 +1,464 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#include "common.h" +#include "vec.h" +#include "gemm.h" + +// clang-format off + +namespace { + +// packed layout: +// quants {N, K} int8_t +// comp {N} int32_t +template +inline void s8s8_compensation(int8_t* __restrict__ packed, int K) { +#if defined(CPU_CAPABILITY_AVX512) + constexpr int COLS = BLOCK_N / 16; + __m512i vcomp[COLS]; + + for (int col = 0; col < COLS; ++col) { + vcomp[col] = _mm512_setzero_si512(); + } + + const int64_t offset = BLOCK_N * K; + const __m512i off = _mm512_set1_epi8(static_cast(0x80)); + for (int k = 0; k < K / 4; ++k) { + for (int col = 0; col < COLS; ++col) { + __m512i vb = _mm512_loadu_si512((const __m512i *)(packed + k * BLOCK_N * 4 + col * 64)); + vcomp[col] = _mm512_dpbusd_epi32(vcomp[col], off, vb); + } + } + + for (int col = 0; col < COLS; ++col) { + _mm512_storeu_si512((__m512i *)(packed + offset + col * 64), vcomp[col]); + } +#else + TORCH_CHECK(false, "s8s8_compensation not implemented!"); +#endif +} + +// convert to vnni format +// from [N, K] to [K/2, N, 2] for bfloat16 and float16 +template +inline void pack_vnni(packed_t* 
__restrict__ packed, const packed_t* __restrict__ weight, int N, int K) { + const int VNNI_BLK = 2; + for (int n = 0; n < N; ++n) { + for (int k = 0; k < K / VNNI_BLK; ++k) { + for (int d = 0; d < VNNI_BLK; ++d) { + packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d]; + } + } + } +} + +template <> +inline void pack_vnni(int8_t* __restrict__ packed, const int8_t* __restrict__ weight, int N, int K) { + constexpr int BLOCK_N = block_size_n(); + TORCH_CHECK(N == BLOCK_N); + + const int VNNI_BLK = 4; + for (int n = 0; n < N; ++n) { + for (int k = 0; k < K / VNNI_BLK; ++k) { + for (int d = 0; d < VNNI_BLK; ++d) { + packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d]; + } + } + } + s8s8_compensation(packed, K); +} + +template +inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d); + fVec data1 = fVec::loadu(input + d + fVec::size()); + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d]); + } +} + +template +inline void copy_add_stub(scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d); + fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size()); + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] + bias[d]); + } 
+} + +template +struct tinygemm_kernel_nn { + static inline void apply( + const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C, + const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_nn { + static inline void apply( + const at::BFloat16* __restrict__ A, const at::BFloat16* __restrict__ B, at::BFloat16* __restrict__ C, + const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 0; + + __m512bh va; + __m512bh vb[COLS]; + __m512 vc[ROWS * COLS]; + + auto loadc = [&](auto i) { + constexpr int col = i % COLS; + if constexpr (has_bias) { + vc[i] = _mm512_loadu_ps(bias + col * 16); + } else { + vc[i] = _mm512_set1_ps(0.f); + } + }; + Unroll{}(loadc); + + const int64_t K2 = K >> 1; + const int64_t lda2 = lda >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; + const float* a_ptr = reinterpret_cast(A); + const float* b_ptr = reinterpret_cast(B); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); + } + if constexpr (row == 0) { + vb[col] = (__m512bh)(_mm512_loadu_si512(b_ptr + k * ldb2 + col * 16)); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + } + } + vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]); + }; + for (int64_t k = 0; k < K2; ++k) { + Unroll{}(compute, k); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + // for COLS = 2, 4 use 512bit store + // for COLS = 1, 3 use 256bit store + if constexpr (COLS % 2 == 0) { + if constexpr (col % 2 
== 0) { + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), + (__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col]))); + } + } else { + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(C + row * ldc + col * 16), + (__m256i)(_mm512_cvtneps_pbh(vc[i]))); + } + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_nn::apply( \ + A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, \ + has_bias ? bias + nb_start : nullptr, K, lda, ldb, ldc); + +template +struct brgemm { + static inline void apply( + const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C, + float* __restrict__ Ctmp, const float* __restrict__ bias, + int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int BLOCK_N = block_size_n(); + at::native::cpublas::brgemm( + M, N, K, lda, ldb, BLOCK_N, /* add_C */false, + A, B, Ctmp); + + // copy from Ctmp to C + for (int64_t m = 0; m < M; ++m) { + if constexpr (has_bias) { + copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N); + } else { + copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N); + } + } + } +}; + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B, + scalar_t* __restrict__ C, + float* __restrict__ Ctmp, + const float* __restrict__ bias, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg) { + + if (brg) { + brgemm::apply( + A, B, C, Ctmp, bias, + M, N, K, lda, ldb, ldc); + return; + } + + // pattern: 1-4-16 + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 64; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = 
std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + // mb_size = 1 + case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break; + case 0x14: LAUNCH_TINYGEMM_KERNEL_NN(1, 64); break; + // mb_size = 2 + case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break; + case 0x24: LAUNCH_TINYGEMM_KERNEL_NN(2, 64); break; + // mb_size = 3 + case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break; + case 0x34: LAUNCH_TINYGEMM_KERNEL_NN(3, 64); break; + // mb_size = 4 + case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break; + case 0x44: LAUNCH_TINYGEMM_KERNEL_NN(4, 64); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +template +void weight_packed_linear_kernel_impl( + scalar_t* __restrict__ out, + const scalar_t* __restrict__ mat1, + const scalar_t* __restrict__ mat2, + const float* __restrict__ bias, + int64_t M, + int64_t N, + int64_t K, + int64_t mat1_strideM, + int64_t out_strideM) { + + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + // use avx512-bf16 when a) M is small; b) dtype is bfloat16, otherwise use amx + const bool use_brgemm = (M > 4) || (!std::is_same_v); + + // l2 cache block for n + int64_t cache_blocks_nb = get_cache_blocks(BLOCK_N, K); + + // parallel on [MB, NB] + AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { + parallel_2d(MB, NB, [&](int64_t begin_mb, int64_t end_mb, int64_t begin_nb, int64_t end_nb) { + + // for brgemm, use float32 for accumulate + alignas(64) float Ctmp[BLOCK_M * BLOCK_N]; + + for (int64_t nbb = begin_nb; nbb < end_nb; nbb += cache_blocks_nb) { + for (int64_t mb = begin_mb; mb < end_mb; ++mb) { + for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, end_nb); ++nb) { + + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(M - mb_start, BLOCK_M); + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(N - nb_start, BLOCK_N); 
+ + tinygemm_kernel( + /* A */ mat1 + mb_start * mat1_strideM, + /* B */ mat2 + nb_start * K /* nb * BLOCK_N * K */, + /* C */ out + mb_start * out_strideM + nb_start, + /* Ctmp*/ Ctmp, + /* bias*/ bias + nb_start, + /* M */ mb_size, + /* N */ nb_size, + /* K */ K, + /* lda */ mat1_strideM, + /* ldb */ nb_size, + /* ldc */ out_strideM, + /* brg */ use_brgemm); + }}} + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } + }); + }); +} + +} // anonymous namespace + +// tinygemm interface +template +void tinygemm_kernel(const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C, + float* __restrict__ Ctmp, int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg) { + tinygemm_kernel(A, B, C, Ctmp, nullptr, M, N, K, lda, ldb, ldc, brg); +} + +#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \ + template void tinygemm_kernel( \ + const TYPE* __restrict__ A, const TYPE* __restrict__ B, TYPE* __restrict__ C, \ + float* __restrict__ Ctmp, int64_t M, int64_t N, int64_t K, int64_t lda, \ + int64_t ldb, int64_t ldc, bool brg) + +INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16); +INSTANTIATE_TINYGEMM_TEMPLATE(at::Half); + +at::Tensor convert_weight_packed(at::Tensor& weight) { + // for 3d moe weights + // weight : [E, OC, IC] + // w1 : [E, 2N, K] + // w2 : [E, K, N] + CHECK_INPUT(weight); + + const int64_t ndim = weight.ndimension(); + TORCH_CHECK(ndim == 2 || ndim == 3, "expect weight to be 2d or 3d, got ", ndim, "d tensor."); + const auto st = weight.scalar_type(); + const int64_t E = ndim == 3 ? weight.size(0) : 1; + const int64_t OC = ndim == 3 ? weight.size(1) : weight.size(0); + const int64_t IC = ndim == 3 ? weight.size(2) : weight.size(1); + + // we handle 2 TILE_N at a time. 
+ TORCH_CHECK(OC % TILE_N == 0, "invalid weight out features ", OC); + TORCH_CHECK(IC % TILE_K == 0, "invalid weight input features ", IC); + + constexpr int64_t BLOCK_N = block_size_n(); + const int64_t NB = div_up(OC, BLOCK_N); + + // use phony sizes here [E, OC, IC], for each [E], [OC, IC] -> [IC / 2, OC, 2] + auto packed_weight = at::empty({}, weight.options()); + const int64_t stride = OC * IC; + + TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf || st == at::kChar || st == at::kFloat8_e4m3fn, + "expect weight to be bfloat16, float16, int8 or fp8_e4m3."); + + CPU_DISPATCH_PACKED_TYPES(st, [&] { + // adjust most inner dimension size + const int packed_row_size = get_row_size(IC); + auto sizes = weight.sizes().vec(); + sizes[ndim - 1] = packed_row_size; + packed_weight.resize_(sizes); + + const packed_t* w_data = weight.data_ptr(); + packed_t* packed_data = packed_weight.data_ptr(); + + // parallel on {E, NB} + at::parallel_for(0, E * NB, 0, [&](int64_t begin, int64_t end) { + int64_t e{0}, nb{0}; + data_index_init(begin, e, E, nb, NB); + + for (int64_t i = begin; i < end; ++i) { + UNUSED(i); + + int64_t n = nb * BLOCK_N; + int64_t n_size = std::min(BLOCK_N, OC - n); + pack_vnni( + packed_data + e * OC * packed_row_size + n * packed_row_size, + w_data + e * stride + n * IC, + n_size, + IC); + + // move to the next index + data_index_step(e, E, nb, NB); + } + }); + }); + return packed_weight; +} + +// mat1 : [M, K] +// mat2 : [N, K] +// bias : [N] +// out : [M, N] +// +at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, + const std::optional& bias, bool is_vnni) { + RECORD_FUNCTION( + "sgl-kernel::weight_packed_linear", std::vector({mat1, mat2, bias})); + + auto packed_w = is_vnni ? 
mat2 : convert_weight_packed(mat2); + + CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1); + CHECK_INPUT(mat2); + + int64_t M = mat1.size(0); + int64_t N = mat2.size(0); + int64_t K = mat2.size(1); + CHECK_EQ(mat1.size(1), K); + CHECK_DIM(2, mat1); + CHECK_DIM(2, mat2); + + auto out = at::empty({M, N}, mat1.options()); + + // strides + int64_t mat1_strideM = mat1.stride(0); + int64_t out_strideM = out.stride(0); + + const bool has_bias = bias.has_value(); + const float* bias_data = nullptr; + if (has_bias) { + CHECK_EQ(bias.value().size(0), N); + bias_data = bias.value().data_ptr(); + } + + AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "weight_packed_linear_kernel_impl", [&] { + weight_packed_linear_kernel_impl( + out.data_ptr(), + mat1.data_ptr(), + packed_w.data_ptr(), + bias_data, + M, + N, + K, + mat1_strideM, + out_strideM); + }); + + return out; +} diff --git a/csrc/cpu/sgl-kernels/gemm.h b/csrc/cpu/sgl-kernels/gemm.h new file mode 100644 index 0000000000000..afae19721ae96 --- /dev/null +++ b/csrc/cpu/sgl-kernels/gemm.h @@ -0,0 +1,266 @@ +#pragma once + +#include + +// clang-format off + +// amx-bf16 +#define TILE_M 16 +#define TILE_N 16 +#define TILE_K 32 + +// block size for AMX gemm +constexpr int block_size_m() { return 2 * TILE_M; } +constexpr int block_size_n() { return 2 * TILE_N; } + +// define threshold using brgemm (intel AMX) +template inline bool can_use_brgemm(int M); +template <> inline bool can_use_brgemm(int M) { return M > 4; } +template <> inline bool can_use_brgemm(int M) { return true; } +// TODO: add u8s8 brgemm, this requires PyTorch 2.7 +template <> inline bool can_use_brgemm(int M) { return false; } +template <> inline bool can_use_brgemm(int M) { return M > 4; } +template <> inline bool can_use_brgemm(int M) { return M > 4; } + +// work around compiler internal error +#define BLOCK_K 128 // 4 * TILE_K + +// adjust leading dimension size for K +template +inline int64_t get_row_size(int64_t K) { + return K; +} + +template <> +inline 
int64_t get_row_size(int64_t K) { + return K + sizeof(int32_t); +} + +inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) { + return use_int8_w8a8 ? K + sizeof(int32_t) : K; +} + +// pack weight to vnni format +at::Tensor convert_weight_packed(at::Tensor& weight); + +// moe implementations for int8 w8a8 +template +void fused_experts_int8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + uint8_t* __restrict__ A_tmp, + float* __restrict__ C_tmp, + uint8_t* __restrict__ Aq_tmp, + float* __restrict__ As_tmp, + const scalar_t* __restrict__ input, + const int8_t* __restrict__ packed_w1, + const int8_t* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad); + +// moe implementations for fp8 w8a16 +template +void fused_experts_fp8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic0, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + scalar_t* __restrict__ A_tmp, + scalar_t* __restrict__ B_tmp, + float* __restrict__ C_tmp, + const scalar_t* __restrict__ input, + const at::Float8_e4m3fn* __restrict__ packed_w1, + const at::Float8_e4m3fn* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + int64_t block_size_N, + int64_t block_size_K, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad); + +// moe implementations for int4 w4a16 +template +void fused_experts_int4_w4a16_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ 
ic0, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + scalar_t* __restrict__ A_tmp, + scalar_t* __restrict__ B_tmp, + float* __restrict__ C_tmp, + const scalar_t* __restrict__ input, + const at::quint4x2* __restrict__ packed_w1, + const at::quint4x2* __restrict__ packed_w2, + const uint8_t* __restrict__ w1z, + const uint8_t* __restrict__ w2z, + const scalar_t* __restrict__ w1s, + const scalar_t* __restrict__ w2s, + int group_size, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad); + +// shared expert implememntation for int8 w8a8 +template +void shared_expert_int8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + float* __restrict__ C_tmp, + uint8_t* __restrict__ Aq_tmp, + float* __restrict__ As_tmp, + const scalar_t* __restrict__ input, + const int8_t* __restrict__ packed_w1, + const int8_t* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + const scalar_t* __restrict__ fused_experts_out, + float routed_scaling_factor, + int64_t M, + int64_t N, + int64_t K); + +template +void shared_expert_fp8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic0, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ B_tmp, + float* __restrict__ C_tmp, + const scalar_t* __restrict__ input, + const at::Float8_e4m3fn* __restrict__ packed_w1, + const at::Float8_e4m3fn* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + int64_t block_size_N, + int64_t block_size_K, + const scalar_t* __restrict__ fused_experts_out, + float routed_scaling_factor, + int64_t M, + int64_t N, + int64_t K); + +// tinygemm interface +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B, + scalar_t* __restrict__ 
C, + float* __restrict__ Ctmp, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg); + +template +void tinygemm_kernel( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B, + scalar_t* __restrict__ C, + int32_t* __restrict__ Ctmp, + const float* __restrict__ As, + const float* __restrict__ Bs, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg); + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const at::Float8_e4m3fn* __restrict__ B, + scalar_t* __restrict__ C, + scalar_t* __restrict__ Btmp, + float* __restrict__ Ctmp, + const float* __restrict__ scale, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg, + int64_t block_size_K); + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const at::quint4x2* __restrict__ B, + scalar_t* __restrict__ C, + const uint8_t* __restrict__ Bz, + const scalar_t* __restrict__ Bs, + scalar_t* __restrict__ Btmp, + float* __restrict__ Ctmp, + int64_t M, + int64_t N, + int64_t K, + int group_size, + int64_t lda, + int64_t ldb, + int64_t ldc, + int64_t strideBz, + int64_t strideBs, + bool brg); + +// TODO: debug print, remove me later +inline void print_16x32i(const __m512i x) { + int32_t a[16]; + _mm512_storeu_si512((__m512i *)a, x); + + for (int i = 0; i < 16; i++){ + std::cout << a[i] << " "; + } + std::cout << std::endl; +} + +inline void print_16x32(const __m512 x) { + float a[16]; + _mm512_storeu_ps((__m512 *)a, x); + + for (int i = 0; i < 16; i++){ + std::cout << a[i] << " "; + } + std::cout << std::endl; +} + + +inline void print_32x8u(const __m256i x) { + uint8_t a[32]; + _mm256_storeu_si256((__m256i *)a, x); + + for (int i = 0; i < 32; ++i) { + std::cout << int32_t(a[i]) << " "; + } + std::cout << std::endl; +} diff --git a/csrc/cpu/sgl-kernels/gemm_fp8.cpp b/csrc/cpu/sgl-kernels/gemm_fp8.cpp new file mode 100644 index 0000000000000..b5f2f07bad623 
--- /dev/null +++ b/csrc/cpu/sgl-kernels/gemm_fp8.cpp @@ -0,0 +1,530 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#include "common.h" +#include "vec.h" +#include "gemm.h" + +// clang-format off + +// we use 4x32 for BLOCK_M +#define BLOCK_SIZE_M_SCALE 4 + +namespace { + +template +inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d); + fVec data1 = fVec::loadu(input + d + fVec::size()); + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d]); + } +} + +template +inline void copy_add_stub(scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d); + fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size()); + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] + bias[d]); + } +} + +inline void unpack_B( + at::BFloat16* __restrict__ Btmp, + const at::Float8_e4m3fn* __restrict__ packed_B, + int N, + int K, + int ldb, + int ldb_tmp, + float scale) { +#if defined(CPU_CAPABILITY_AVX512) + // [K/2, N, 2] + const int K2 = K >> 1; + const int ldb2 = ldb; // ldb * 2 >> 1; + const uint16_t* b_ptr = reinterpret_cast(packed_B); + const __m512 vd = _mm512_set1_ps(scale); + + constexpr int BLOCK_N = block_size_n(); + 
static_assert(BLOCK_N == 32); + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 64; + +#pragma GCC unroll 4 + for (int k = 0; k < K2; ++k) { + __m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2, _MM_HINT_T0); + } + + __m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0); + __m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1); + + __m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0); + __m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1); + + // Apply scale + __m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0)); + __m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1)); + __m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0)); + __m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1)); + + f0_lo = _mm512_mul_ps(f0_lo, vd); + f0_hi = _mm512_mul_ps(f0_hi, vd); + f1_lo = _mm512_mul_ps(f1_lo, vd); + f1_hi = _mm512_mul_ps(f1_hi, vd); + + bf16_0 = _mm512_cvtne2ps_pbh(f0_hi, f0_lo); + bf16_1 = _mm512_cvtne2ps_pbh(f1_hi, f1_lo); + + _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 0, (__m512i)bf16_0); + _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 32, (__m512i)bf16_1); + } +#else + TORCH_CHECK(false, "unpack_B: scalar path not implemented!"); +#endif +} + +template +struct tinygemm_kernel_nn { + static inline void apply( + const scalar_t* __restrict__ A, const packed_t* __restrict__ B, scalar_t* __restrict__ C, + const float* __restrict__ bias, const float* __restrict__ scale, int K, int lda, int ldb, int ldc, int64_t block_size_K) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_nn { + static inline void apply( + const at::BFloat16* __restrict__ A, const at::Float8_e4m3fn* __restrict__ B, at::BFloat16* __restrict__ C, + const float* __restrict__ bias, const float* __restrict__ scale, int K, int lda, int ldb, int ldc, 
int64_t block_size_K) { + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + + const int KB = div_up(K, BLOCK_K); + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 64; + constexpr int PREFETCH_SIZE_KB = 1; + + __m512bh va; + __m512bh vb[COLS]; + __m512 vc[ROWS * COLS]; + __m512 vsum[ROWS * COLS]; + + // block quant scale + __m512 vscale; + + auto loadc = [&](auto i) { + constexpr int col = i % COLS; + if constexpr (has_bias) { + vc[i] = _mm512_loadu_ps(bias + col * 16); + } else { + vc[i] = _mm512_setzero_ps(); + } + }; + Unroll{}(loadc); + + const int lda2 = lda >> 1; + const int ldb2 = ldb; // ldb * 2 >> 1; + const float* a_ptr = reinterpret_cast(A); + const uint16_t* b_ptr = reinterpret_cast(B); + + auto compute = [&](auto i, int k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(a_ptr + row * lda2 + k + PREFETCH_SIZE_K, _MM_HINT_T0); + } + } + if constexpr (row == 0) { + if constexpr (col % 2 == 0) { + __m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2 + col * 16); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + } + vb[col + 0] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 0)); + vb[col + 1] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 1)); + } + } + vsum[i] = _mm512_dpbf16_ps(vsum[i], va, vb[col]); + }; + + constexpr int BLOCK_K2 = BLOCK_K >> 1; + for (int kb = 0; kb < KB; ++kb) { + int kb_start = kb * BLOCK_K2; + int kb_end = std::min(K, kb_start + BLOCK_K2); + // 1. load scale vector + vscale = _mm512_set1_ps(scale[kb]); + if constexpr (PREFETCH_SIZE_KB > 0) { + _mm_prefetch(scale + kb + PREFETCH_SIZE_KB, _MM_HINT_T0); + } + // 2. zero vsum for each block + Unroll{}([&](auto i) { + vsum[i] = _mm512_setzero_ps(); + }); + // 3. 
accumulate across each block + for (int k = kb_start; k < kb_end; ++k) { + Unroll{}(compute, k); + } + // 4. apply scale + Unroll{}([&](auto i) { + vc[i] = _mm512_fmadd_ps(vsum[i], vscale, vc[i]); + }); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + // for COLS = 2,4 use 512bit store + if constexpr (col % 2 == 0) { + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), + (__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col]))); + } + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_nn::apply( \ + A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, \ + has_bias ? bias + nb_start : nullptr, scale, K, lda, ldb, ldc, block_size_K); + +template +struct brgemm { + static inline void apply( + const scalar_t* __restrict__ A, + const packed_t* __restrict__ B, + scalar_t* __restrict__ C, + scalar_t* __restrict__ Btmp, + float* __restrict__ Ctmp, + const float* __restrict__ bias, + const float* __restrict__ scale, + int M, + int N, + int K, + int lda, + int ldb, + int ldc) { + TORCH_CHECK(false, "struct brgemm: primary template not implemented!"); + } +}; + +template +struct brgemm { + static inline void apply( + const at::BFloat16* __restrict__ A, + const at::Float8_e4m3fn* __restrict__ B, + at::BFloat16* __restrict__ C, + at::BFloat16* __restrict__ Btmp, + float* __restrict__ Ctmp, + const float* __restrict__ bias, + const float* __restrict__ scale, + int M, + int N, + int K, + int lda, + int ldb, + int ldc) { + + constexpr int BLOCK_N = block_size_n(); + + // [K, BLOCK_N] -> [K / 2, BLOCK_N * 2] + const int ldb_tmp = BLOCK_N; + + for (int k = 0; k < K; k += BLOCK_K) { + int kb_size = std::min(BLOCK_K, K - k); + + int idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128 + unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]); + } + + at::native::cpublas::brgemm( + M, N, 
K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp); + + // copy from Ctmp to C + for (int m = 0; m < M; ++m) { + if constexpr (has_bias) { + copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N); + } else { + copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N); + } + } + } +}; + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const at::Float8_e4m3fn* __restrict__ B, + scalar_t* __restrict__ C, + scalar_t* __restrict__ Btmp, + float* __restrict__ Ctmp, + const float* __restrict__ scale, + const float* __restrict__ bias, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg, + int64_t block_size_K) { + + if (brg) { + brgemm::apply( + A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc); + return; + } + + // pattern: 1-4-16 + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 64; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int64_t mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break; + case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break; + case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break; + case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size); + } + } + } +} + +template +void fp8_scaled_mm_kernel_impl( + scalar_t* __restrict__ out, + const scalar_t* __restrict__ mat1, + const at::Float8_e4m3fn* __restrict__ mat2, + const float* __restrict__ scales2, + const float* __restrict__ bias, + scalar_t* __restrict__ buffer, + int64_t M, + int64_t N, + int64_t K, + int64_t mat1_strideM, + int64_t out_strideM, + int64_t block_size_N, + int64_t block_size_K, + int64_t buffer_size_per_thread) { + + 
constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE; + constexpr int64_t BLOCK_N = block_size_n(); + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + const int64_t scale_size_K = div_up(K, block_size_K); + const int64_t blocks_n_per_group = block_size_N / BLOCK_N; + + const bool use_brgemm = can_use_brgemm(M); + + // parallel on [MB, NB] + AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + int64_t mb{0}, nb{0}; + data_index_init(begin, mb, MB, nb, NB); + + int tid = at::get_thread_num(); + scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread; + float* __restrict__ Ctmp = (float*)((void*)(Btmp + BLOCK_N * K)); + + for (int64_t i = begin; i < end; ++i) { + UNUSED(i); + const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K; + + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(M - mb_start, BLOCK_M); + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(N - nb_start, BLOCK_N); + + tinygemm_kernel( + /* A */ mat1 + mb_start * mat1_strideM, + /* B */ mat2 + nb_start * K, // nb * BLOCK_N * K + /* C */ out + mb_start * out_strideM + nb_start, + /* Btmp */ Btmp, + /* Ctmp */ Ctmp, + /* scale */ scale_ptr, + /* bias */ bias + nb_start, + /* M */ mb_size, + /* N */ nb_size, + /* K */ K, + /* lda */ mat1_strideM, + /* ldb */ nb_size, + /* ldc */ out_strideM, + /* brg */ use_brgemm, + /* block_size_K */ block_size_K); + + // move to the next index + data_index_step(mb, MB, nb, NB); + } + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } + }); + }); +} + +} // anonymous namespace + +// tinygemm interface +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const at::Float8_e4m3fn* __restrict__ B, + scalar_t* __restrict__ C, + scalar_t* __restrict__ Btmp, + float* __restrict__ Ctmp, + const float* __restrict__ scale, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + 
int64_t ldb, + int64_t ldc, + bool brg, + int64_t block_size_K) { + tinygemm_kernel(A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K); +} + +#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \ + template void tinygemm_kernel( \ + const TYPE* __restrict__ A, \ + const at::Float8_e4m3fn* __restrict__ B, \ + TYPE* __restrict__ C, \ + TYPE* __restrict__ Btmp, \ + float* __restrict__ Ctmp, \ + const float* __restrict__ scale, \ + int64_t M, \ + int64_t N, \ + int64_t K, \ + int64_t lda, \ + int64_t ldb, \ + int64_t ldc, \ + bool brg, \ + int64_t block_size_K) + +INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16); +INSTANTIATE_TINYGEMM_TEMPLATE(at::Half); + +at::Tensor fp8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2, + std::vector block_size, std::optional& bias, + at::ScalarType out_dtype, bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::fp8_scaled_mm_cpu", std::vector({mat1, mat2, scales2, block_size, bias})); + + auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2); + + CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1); + CHECK_INPUT(mat2); + CHECK_INPUT(scales2); + TORCH_CHECK(scales2.scalar_type() == at::kFloat, + "fp8_scaled_mm_cpu: expect scales2 to be float32."); + + int64_t M = mat1.size(0); + int64_t N = mat2.size(0); + int64_t K = mat2.size(1); + + CHECK_EQ(mat1.size(1), K); + CHECK_DIM(2, mat1); + CHECK_DIM(2, mat2); + + TORCH_CHECK(block_size.size() == 2, + "fp8_scaled_mm_cpu: expect block_size.size() to be 2."); + + int64_t block_size_N = block_size[0]; + int64_t block_size_K = block_size[1]; + + constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE; + constexpr int64_t BLOCK_N = block_size_n(); + TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N"); + TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K"); + CHECK_EQ(scales2.size(0), div_up(N, block_size_N)); + CHECK_EQ(scales2.size(1), div_up(K, block_size_K)); 
+ + const auto st = mat1.scalar_type(); + TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, + "fp8_scaled_mm_cpu: expect A to be bfloat16 or half."); + TORCH_CHECK(st == out_dtype, + "fp8_scaled_mm_cpu: expect A has same dtype with out_dtype."); + TORCH_CHECK(mat2.scalar_type() == at::kFloat8_e4m3fn, + "fp8_scaled_mm_cpu: expect mat2 to be fp8_e4m3."); + TORCH_CHECK(scales2.scalar_type() == at::kFloat, + "fp8_scaled_mm_cpu: expect scales to be float32."); + auto out = at::empty({M, N}, mat1.options().dtype(out_dtype)); + + // strides + int64_t mat1_strideM = mat1.stride(0); + int64_t out_strideM = out.stride(0); + + const bool has_bias = bias.has_value(); + const float* bias_data = nullptr; + if (has_bias) { + CHECK_EQ(bias.value().size(0), N); + bias_data = bias.value().data_ptr(); + } + + // Btmp : [T, BLOCK_N * K] + // Ctmp : [T, BLOCK_M * BLOCK_N] + int num_threads = at::get_num_threads(); + int64_t size_per_thread = BLOCK_N * K + BLOCK_M * BLOCK_N * 2; + auto buffer = at::empty({num_threads, size_per_thread}, mat1.options()); + + AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] { + fp8_scaled_mm_kernel_impl( + out.data_ptr(), + mat1.data_ptr(), + packed_w.data_ptr(), + scales2.data_ptr(), + bias_data, + buffer.data_ptr(), + M, + N, + K, + mat1_strideM, + out_strideM, + block_size_N, + block_size_K, + size_per_thread); + }); + + return out; +} diff --git a/csrc/cpu/sgl-kernels/gemm_int8.cpp b/csrc/cpu/sgl-kernels/gemm_int8.cpp new file mode 100644 index 0000000000000..5a0f65a9200d4 --- /dev/null +++ b/csrc/cpu/sgl-kernels/gemm_int8.cpp @@ -0,0 +1,440 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#include "common.h" +#include "vec.h" +#include "gemm.h" + +// clang-format off + +namespace { + +template +struct tinygemm_kernel_nn { + static inline void apply( + const uint8_t* __restrict__ A, const int8_t* __restrict__ B, scalar_t* __restrict__ C, + const float* __restrict__ As, const 
float* __restrict__ Bs, const int32_t* __restrict__ Bcomp, + const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_nn { + static inline void apply( + const uint8_t* __restrict__ A, const int8_t* __restrict__ B, at::BFloat16* __restrict__ C, + const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp, + const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + static_assert(COLS % 2 == 0); + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 0; + + __m512i va; + __m512i vb[COLS]; + __m512i vc[ROWS * COLS]; + __m512i vcomp[COLS]; + __m512 vd0; + __m512 vd1[COLS]; + + // oops! 4x4 spills but luckly we use 4x2 + __m512 vbias[COLS]; + + // [NOTE]: s8s8 igemm compensation in avx512-vnni + // + // avx512-vnni has no s8s8, so we need to change s8s8 to u8s8 with compensate: + // + // a * b = (a + 128) * b - 128 * b + // s s u s u s + // + // 1) 128 * b is pre-computed when packing B to vnni formats + // 2) a + 128 is fused when dynamically quantize A + // + auto loadc = [&](auto i) { + vc[i] = _mm512_set1_epi32(0); + }; + Unroll{}(loadc); + + const int64_t K4 = K >> 2; + const int64_t lda4 = lda >> 2; + const int64_t ldb4 = ldb; // ldb * 4 >> 2; + const int32_t* a_ptr = reinterpret_cast(A); + const int32_t* b_ptr = reinterpret_cast(B); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = _mm512_set1_epi32(a_ptr[row * lda4 + k]); + } + if constexpr (row == 0) { + vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb4 + col * 16, _MM_HINT_T0); + } + } + vc[i] = 
_mm512_dpbusd_epi32(vc[i], va, vb[col]); + }; + for (int64_t k = 0; k < K4; ++k) { + Unroll{}(compute, k); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + // load a scale + if constexpr(col == 0) { + vd0 = _mm512_set1_ps(As[row]); + } + // load b scale and vcomp per 2 vectors + // also load bias if any + if constexpr (row == 0) { + if constexpr (col % 2 == 0) { + vd1[col + 0] = _mm512_loadu_ps(Bs + col * 16); + vd1[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16); + vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16); + vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16); + if constexpr (has_bias) { + vbias[col + 0] = _mm512_loadu_ps(bias + col * 16); + vbias[col + 1] = _mm512_loadu_ps(bias + col * 16 + 16); + } + } + } + + // for COLS = 2, 4 use 512bit store + if constexpr (col % 2 == 0) { + __m512 vc0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 0], vcomp[col + 0])); + __m512 vc1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 1], vcomp[col + 1])); + if constexpr (has_bias) { + vc0 = _mm512_fmadd_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0], vbias[col + 0]); + vc1 = _mm512_fmadd_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1], vbias[col + 1]); + } else { + vc0 = _mm512_mul_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0]); + vc1 = _mm512_mul_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1]); + } + + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), + (__m512i)(_mm512_cvtne2ps_pbh(vc1, vc0))); + } + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_nn::apply( \ + A + mb_start * lda, B + nb_start * 4, C + mb_start * ldc + nb_start, \ + As + mb_start, Bs + nb_start, Bcomp + nb_start, \ + has_bias ? 
bias + nb_start : nullptr, K, lda, ldb, ldc); + +template +void tinygemm_kernel( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B, + scalar_t* __restrict__ C, + int32_t* __restrict__ Ctmp, + const float* __restrict__ As, + const float* __restrict__ Bs, + const float* __restrict__ bias, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg) { + + // B compensation + const int32_t* Bcomp = reinterpret_cast(B + block_size_n() * K); + + // pattern: 1-4-16 + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 64; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int64_t mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + // mb_size = 1 + case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break; + case 0x14: LAUNCH_TINYGEMM_KERNEL_NN(1, 64); break; + // mb_size = 2 + case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break; + case 0x24: LAUNCH_TINYGEMM_KERNEL_NN(2, 64); break; + // mb_size = 3 + case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break; + case 0x34: LAUNCH_TINYGEMM_KERNEL_NN(3, 64); break; + // mb_size = 4 + case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break; + case 0x44: LAUNCH_TINYGEMM_KERNEL_NN(4, 64); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size); + } + } + } +} + +template +void int8_scaled_mm_kernel_impl( + scalar_t* __restrict__ out, + const uint8_t* __restrict__ mat1, + const int8_t* __restrict__ mat2, + const float* __restrict__ scales1, + const float* __restrict__ scales2, + const float* __restrict__ bias, + int64_t M, + int64_t N, + int64_t K) { + + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + const int64_t MB = div_up(M, BLOCK_M); + const 
int64_t NB = div_up(N, BLOCK_N); + + // TODO: brgemm u8s8 depends on PyTorch 2.7 release. + const bool use_brgemm = false; + + // K + 4 after compensation + const int64_t packed_row_size = get_row_size(K); + + AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + int64_t mb{0}, nb{0}; + data_index_init(begin, mb, MB, nb, NB); + + // for brgemm, use int32_t for accumulate + alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N]; + + for (int i = begin; i < end; ++i) { + UNUSED(i); + int mb_start = mb * BLOCK_M; + int mb_size = std::min(M - mb_start, BLOCK_M); + int nb_start = nb * BLOCK_N; + int nb_size = std::min(N - nb_start, BLOCK_N); + + tinygemm_kernel( + /* A */ mat1 + mb_start * K, + /* B */ mat2 + nb_start * packed_row_size /* nb * BLOCK_N * (K + 4) */, + /* C */ out + mb_start * N + nb_start, + /* Ctmp*/ Ctmp, + /* As */ scales1 + mb_start, + /* Bs */ scales2 + nb_start, + /* bias*/ bias + nb_start, + /* M */ mb_size, + /* N */ nb_size, + /* K */ K, + /* lda */ K, + /* ldb */ nb_size, + /* ldc */ N, + /* brg */ use_brgemm); + + // move to the next index + data_index_step(mb, MB, nb, NB); + } + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } + }); + }); +} + +} // anonymous namespace + +// tinygemm interface +template +void tinygemm_kernel(const uint8_t* __restrict__ A, const int8_t* __restrict__ B, scalar_t* __restrict__ C, + int32_t* __restrict__ Ctmp, const float* __restrict__ As, const float* __restrict__ Bs, + int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg) { + tinygemm_kernel(A, B, C, Ctmp, As, Bs, nullptr, M, N, K, lda, ldb, ldc, brg); +} + +#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \ + template void tinygemm_kernel( \ + const uint8_t* __restrict__ A, const int8_t* __restrict__ B, TYPE* __restrict__ C, \ + int32_t* __restrict__ Ctmp, const float* __restrict__ As, const float* __restrict__ Bs, \ + int64_t M, int64_t N, int64_t K, int64_t lda, 
int64_t ldb, int64_t ldc, bool brg) + +INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16); +INSTANTIATE_TINYGEMM_TEMPLATE(at::Half); + +std::tuple per_token_quant_int8_cpu(at::Tensor& A) { + RECORD_FUNCTION("sgl-kernel::per_token_quant_int8_cpu", std::vector({A})); + + CHECK_LAST_DIM_CONTIGUOUS_INPUT(A); + CHECK_DIM(2, A); + + int64_t M = A.size(0); + int64_t K = A.size(1); + int64_t lda = A.stride(0); + + const auto st = A.scalar_type(); + TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, + "per_token_quant_int8: expect A to be bfloat16 or half."); + + auto Aq = at::empty({M, K}, A.options().dtype(at::kByte)); + auto As = at::empty({M}, A.options().dtype(at::kFloat)); + + AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "per_token_quant_int8", [&] { + uint8_t* __restrict__ Aq_data = Aq.data_ptr(); + float* __restrict__ As_data = As.data_ptr(); + const scalar_t* __restrict__ A_data = A.data_ptr(); + + at::parallel_for(0, M, 0, [&] (int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8( + Aq_data + m * K, + As_data[m], + A_data + m * lda, + K); + } + }); + }); + return std::make_tuple(Aq, As); +} + +// weight : static, per-channel, symmetric +// activation : dynamic, per-token, symmetric +// +// mat1 : [M, K] +// mat2 : [N, K] +// scales1 : [M] +// scales2 : [N] +// bias : [N] +// out : [M, N] +// +at::Tensor int8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, + at::Tensor& scales1, at::Tensor& scales2, + std::optional& bias, at::ScalarType out_dtype, bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector({mat1, mat2, scales1, scales2, bias})); + + auto packed_w = is_vnni ? 
mat2 : convert_weight_packed(mat2); + + CHECK_INPUT(mat1); + CHECK_INPUT(mat2); + CHECK_INPUT(scales1); + CHECK_INPUT(scales2); + CHECK_DIM(2, mat1); + CHECK_DIM(2, mat2); + + int64_t M = mat1.size(0); + int64_t N = mat2.size(0); + int64_t K = mat1.size(1); + + // see [NOTE]: s8s8 igemm compensation in avx512-vnni + CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K)); + CHECK_EQ(scales1.numel(), M); + CHECK_EQ(scales2.numel(), N); + + TORCH_CHECK(mat1.scalar_type() == at::kByte, "int8_scaled_mm: expect mat1 to be uint8."); + TORCH_CHECK(mat2.scalar_type() == at::kChar, "int8_scaled_mm: expect mat2 to be int8."); + TORCH_CHECK(scales1.scalar_type() == at::kFloat && scales2.scalar_type() == at::kFloat, + "int8_scaled_mm: expect scales to be float32."); + + auto out = at::empty({M, N}, mat1.options().dtype(out_dtype)); + + const bool has_bias = bias.has_value(); + const float* bias_data = nullptr; + if (has_bias) { + CHECK_EQ(bias.value().size(0), N); + bias_data = bias.value().data_ptr(); + } + + AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_kernel_impl", [&] { + int8_scaled_mm_kernel_impl( + out.data_ptr(), + mat1.data_ptr(), + packed_w.data_ptr(), + scales1.data_ptr(), + scales2.data_ptr(), + bias_data, + M, + N, + K); + }); + return out; +} + +// fused `per_token_quant_int8_cpu` and `int8_scaled_mm_cpu` +at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2, + const std::optional& bias, at::ScalarType out_dtype, bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector({mat1, mat2, scales2, bias})); + + auto packed_w = is_vnni ? 
mat2 : convert_weight_packed(mat2); + + CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1); + CHECK_INPUT(mat2); + CHECK_INPUT(scales2); + CHECK_DIM(2, mat1); + CHECK_DIM(2, mat2); + + int64_t M = mat1.size(0); + int64_t N = mat2.size(0); + int64_t K = mat1.size(1); + int64_t lda = mat1.stride(0); + + // see [NOTE]: s8s8 igemm compensation in avx512-vnni + CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K)); + CHECK_EQ(scales2.numel(), N); + + const auto st = mat1.scalar_type(); + TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, + "int8_scaled_mm_with_quant: expect A to be bfloat16 or half."); + TORCH_CHECK(st == out_dtype, + "int8_scaled_mm_with_quant: expect A has same dtype with out_dtype."); + TORCH_CHECK(mat2.scalar_type() == at::kChar, + "int8_scaled_mm_with_quant: expect mat2 to be int8."); + TORCH_CHECK(scales2.scalar_type() == at::kFloat, + "int8_scaled_mm_with_quant: expect scales to be float32."); + + const int64_t buffer_size = M * K + M * sizeof(float); + auto buffer = at::empty({buffer_size}, mat1.options().dtype(at::kByte)); + auto out = at::empty({M, N}, mat1.options().dtype(out_dtype)); + + const bool has_bias = bias.has_value(); + const float* bias_data = nullptr; + if (has_bias) { + CHECK_EQ(bias.value().size(0), N); + bias_data = bias.value().data_ptr(); + } + + AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_with_quant_kernel_impl", [&] { + uint8_t* __restrict__ Aq_data = buffer.data_ptr(); + float* __restrict__ As_data = (float*)((void*)(Aq_data + M * K)); + const scalar_t* __restrict__ A_data = mat1.data_ptr(); + + at::parallel_for(0, M, 0, [&] (int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8( + Aq_data + m * K, + As_data[m], + A_data + m * lda, + K); + } + }); + + int8_scaled_mm_kernel_impl( + out.data_ptr(), + Aq_data, + packed_w.data_ptr(), + As_data, + scales2.data_ptr(), + bias_data, + M, + N, + K); + }); + return out; +} diff --git a/csrc/cpu/sgl-kernels/moe.cpp 
b/csrc/cpu/sgl-kernels/moe.cpp new file mode 100644 index 0000000000000..beeccff783ea0 --- /dev/null +++ b/csrc/cpu/sgl-kernels/moe.cpp @@ -0,0 +1,1330 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#include "common.h" +#include "vec.h" +#include "gemm.h" + +// clang-format off + +namespace { + +// [NOTE]: Fused MoE kernel with AMX +// +// This file contains implementations for +// * `moe_align_block_size` +// * `fused_moe` +// +// The functionality is identical to triton kernel, excepts: +// * fuse silu_and_mul with gemm1, therefore this kernel +// allocates 2 intermediate_caches instead of 3 +// * add `offsets` in `moe_align_block_size` which keeps track +// of starting offset for each M block. this is for keeping +// output of silu_and_mul in sorted order, thus load_A for +// the 2nd gemm would be contiguous, therefore we can directly +// load A from intermediate_cache1. +// +// TODO: +// 1. tune BLOCK_M and BLOCK_N (BLOCK_N * K fit L2) +// 2. add prefetch for load A which is indexed access +// 3. 
abstract at::native::cpublas::brgemm with WoQ gemm (M = 1 & M != 1) +// + +template +inline void fill_stub(scalar_t* __restrict__ out, scalar_t val, int64_t size) { + using Vec = at::vec::Vectorized; + const Vec data_vec(val); + at::vec::map([data_vec](Vec out) { return out = data_vec; }, out, out, size); +} + +template +inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) { + using Vec = at::vec::Vectorized; + // no remainder + #pragma GCC unroll 4 + for (int64_t d = 0; d < size; d += Vec::size()) { + Vec data = Vec::loadu(input + d); + data.store(out + d); + } +} + +template +inline void copy_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, float weight, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec weight_vec = fVec(weight); + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d) * weight_vec; + fVec data1 = fVec::loadu(input + d + fVec::size()) * weight_vec; + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] * weight); + } +} + +// acc from [topk, K] to [K] +template +inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + if (topk == 1) { + // do copy for topk = 1 + copy_stub(out, input, K); + } else { + // do sum for topk != 1 + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= K - kVecSize; d += kVecSize) { + fVec sum_fvec0 = fVec(0.f); + fVec sum_fvec1 = fVec(0.f); + for (int t = 0; t < topk; ++t) { + bVec x_bvec = bVec::loadu(input + t * K + d); + fVec x_fvec0, x_fvec1; + std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); + + sum_fvec0 += x_fvec0; + 
sum_fvec1 += x_fvec1; + } + bVec out_bvec = convert_from_float_ext(sum_fvec0, sum_fvec1); + out_bvec.store(out + d); + } + for (; d < K; ++d) { + float sum_val = 0.f; + for (int t = 0; t < topk; ++t) { + sum_val += static_cast(input[t * K + d]); + } + out[d] = static_cast(sum_val); + } + } +} + +// out = input + input2 * scale +template +inline void add_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, + const scalar_t* __restrict__ input2, float scale, int64_t size) { + + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec s_vec = fVec(scale); + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec x0 = fVec::loadu(input + d); + fVec x1 = fVec::loadu(input + d + fVec::size()); + + bVec y_bvec = bVec::loadu(input2 + d); + fVec y0, y1; + std::tie(y0, y1) = at::vec::convert_to_float(y_bvec); + + x0 = x0 + y0 * s_vec; + x1 = x1 + y1 * s_vec; + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] + float(input2[d]) * scale); + } +} + +template +int moe_align_block_size( + int32_t* __restrict__ sorted_ids, + int32_t* __restrict__ expert_ids, + int32_t* __restrict__ topk_ids, + int32_t* __restrict__ total_cnts, + int32_t* __restrict__ cumsums, + int32_t* __restrict__ offsets, + int num_experts, + int numel, + int num_threads) { + + #define T_INDEX(tt) total_cnts + (tt) * num_experts + + // accumulate count of expert ids locally + at::parallel_for(0, numel, 0, [&](int begin, int end) { + int tid = at::get_thread_num(); + int32_t* __restrict__ local_cnts = T_INDEX(tid + 1); + + for (int i = begin; i < end; ++i) { + local_cnts[topk_ids[i]]++; + } + }); + + using iVec = at::vec::Vectorized; + for (int t = 0; t < num_threads; ++t) { + at::vec::map2( + [](iVec x, iVec y) { return x + y; }, + T_INDEX(t + 1), T_INDEX(t + 1), T_INDEX(t), num_experts); + } + + // the 
last row holds sums of each experts + int32_t* total_cnts_t_1 = T_INDEX(num_threads); + + cumsums[0] = 0; + for (int e = 0; e < num_experts; ++e) { + // accumulate `num_tokens_post_pad`, also as the expert offset + cumsums[e + 1] = cumsums[e] + div_up(total_cnts_t_1[e], BLOCK_M) * BLOCK_M; + + for (int k = cumsums[e]; k < cumsums[e + 1]; k += BLOCK_M) { + expert_ids[k / BLOCK_M] = e; + } + } + int num_tokens_post_pad = cumsums[num_experts]; + + at::parallel_for(0, numel, 0, [&](int begin, int end) { + int tid = at::get_thread_num(); + // thread tid offsets in `total_cnts` + int32_t* __restrict__ offsets = T_INDEX(tid); + + for (int i = begin; i < end; ++i) { + int32_t expert_id = topk_ids[i]; + int32_t b_offset = cumsums[expert_id]; + int32_t t_offset = offsets[expert_id]; + sorted_ids[b_offset + t_offset] = i; + offsets[expert_id]++; + } + }); + + // debug: the offset for thread t_1 should be identical to t_2 + int32_t* total_cnts_t_2 = T_INDEX(num_threads - 1); + for (int e = 0; e < num_experts; ++e) { + TORCH_CHECK(total_cnts_t_1[e] == total_cnts_t_2[e]); + } + + // padding value for sorted_ids: numel + auto sorted_id_size = [=](const int32_t* sorted_ids_ptr) { + for (int d = 0; d < BLOCK_M; ++d) { + if (sorted_ids_ptr[d] == numel) { return d; } + } + return BLOCK_M; + }; + + // offsets holds starting offset for each valida M blocks + // shape : [num_token_blocks + 1] + offsets[0] = 0; + const int num_token_blocks = num_tokens_post_pad / BLOCK_M; + at::parallel_for(0, num_token_blocks, GRAIN_SIZE / BLOCK_M, [&](int begin, int end) { + for (int mb = begin; mb < end; ++mb) { + offsets[mb + 1] = sorted_id_size(sorted_ids + mb * BLOCK_M); + } + }); + // TODO: do we need to vecterize this ? 
+ for (int mb = 0; mb < num_token_blocks; ++mb) { + offsets[mb + 1] += offsets[mb]; + } + // debug: the last value of offsets should be `numel` + TORCH_CHECK(offsets[num_token_blocks] == numel); + + return num_tokens_post_pad; +} + +// silu : shape leading dimension +// input0 [m_size, BLOCK_N] BLOCK_N +// input1 [m_size, BLOCK_N] BLOCK_N +// output [M * topk, N] N +template +inline void silu_and_mul( + scalar_t* __restrict__ output, + const float* __restrict__ input0, // x: x0, x1 + const float* __restrict__ input1, // y: y0, y1 + int64_t m_size, + int64_t N) { + + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + + const fVec one = fVec(1.f); + + // no remainder + for (int64_t m = 0; m < m_size; ++m) { + scalar_t* __restrict__ out = output + m * N; + const float* __restrict__ x = input0 + m * BLOCK_N; + const float* __restrict__ y = input1 + m * BLOCK_N; + + for (int64_t d = 0; d < BLOCK_N; d += bVec::size()) { + fVec x0 = fVec::loadu(x + d); + fVec x1 = fVec::loadu(x + d + fVec::size()); + fVec y0 = fVec::loadu(y + d); + fVec y1 = fVec::loadu(y + d + fVec::size()); + // silu + x0 = x0 / (one + x0.neg().exp_u20()); + x1 = x1 / (one + x1.neg().exp_u20()); + // mul + x0 = x0 * y0; + x1 = x1 * y1; + // convert + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } + } +} + +template +struct tinygemm_kernel_nn2 { + static inline void apply( + const scalar_t* __restrict__ A, const scalar_t* __restrict__ B0, const scalar_t* __restrict__ B1, + scalar_t* __restrict__ C, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_nn2 { + static inline void apply( + const at::BFloat16* __restrict__ A, const at::BFloat16* __restrict__ B0, const at::BFloat16* __restrict__ B1, + at::BFloat16* __restrict__ C, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int ROWS = 
BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + + static_assert(COLS % 2 == 0); + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 0; + + __m512bh va; + __m512bh vb0[COLS]; + __m512bh vb1[COLS]; + __m512 vc0[ROWS * COLS]; + __m512 vc1[ROWS * COLS]; + + auto loadc = [&](auto i) { + vc0[i] = _mm512_set1_ps(0.f); + vc1[i] = _mm512_set1_ps(0.f); + }; + Unroll{}(loadc); + + const int64_t K2 = K >> 1; + const int64_t lda2 = lda >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; + const float* a_ptr = reinterpret_cast(A); + const float* b0_ptr = reinterpret_cast(B0); + const float* b1_ptr = reinterpret_cast(B1); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); + } + if constexpr (row == 0) { + vb0[col] = (__m512bh)(_mm512_loadu_si512(b0_ptr + k * ldb2 + col * 16)); + vb1[col] = (__m512bh)(_mm512_loadu_si512(b1_ptr + k * ldb2 + col * 16)); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b0_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + _mm_prefetch(b1_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + } + } + vc0[i] = _mm512_dpbf16_ps(vc0[i], va, vb0[col]); + vc1[i] = _mm512_dpbf16_ps(vc1[i], va, vb1[col]); + }; + for (int64_t k = 0; k < K2; ++k) { + Unroll{}(compute, k); + } + + using Vec = at::vec::Vectorized; + const Vec one = Vec(1.f); + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + // for COLS = 2, 4 use 512bit store + if constexpr (col % 2 == 0) { + Vec x0 = vc0[row * COLS + col + 0]; + Vec x1 = vc0[row * COLS + col + 1]; + Vec y0 = vc1[row * COLS + col + 0]; + Vec y1 = vc1[row * COLS + col + 1]; + // silu + x0 = x0 / (one + x0.neg().exp_u20()); + x1 = x1 / (one + x1.neg().exp_u20()); + // mul + x0 = x0 * y0; + x1 = x1 * y1; + + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), + 
(__m512i)(_mm512_cvtne2ps_pbh(__m512(x1), __m512(x0)))); + } + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_nn2::apply( \ + A + mb_start * lda, B0 + nb_start * 2, B1 + nb_start * 2, \ + C + mb_start * ldc + nb_start, K, lda, ldb, ldc); + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B0, + const scalar_t* __restrict__ B1, + scalar_t* __restrict__ C, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + + // pattern: 1-(2+2)-(8+8) + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 32; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + // mb_size = 1 + case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break; + // mb_size = 2 + case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break; + // mb_size = 3 + case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break; + // mb_size = 4 + case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +template +struct tinygemm_kernel_nn { + static inline void apply( + const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, float* __restrict__ C, + int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_nn { + static inline void apply( + const at::BFloat16* __restrict__ A, const at::BFloat16* __restrict__ B, float* __restrict__ C, + int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int ROWS = BLOCK_M; + 
constexpr int COLS = BLOCK_N / 16; + + static_assert(COLS % 2 == 0); + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 0; + + __m512bh va; + __m512bh vb[COLS]; + __m512 vc[ROWS * COLS]; + + auto loadc = [&](auto i) { + vc[i] = _mm512_set1_ps(0.f); + }; + Unroll{}(loadc); + + const int64_t K2 = K >> 1; + const int64_t lda2 = lda >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; + const float* a_ptr = reinterpret_cast(A); + const float* b_ptr = reinterpret_cast(B); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); + } + if constexpr (row == 0) { + vb[col] = (__m512bh)(_mm512_loadu_si512(b_ptr + k * ldb2 + col * 16)); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + } + } + vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]); + }; + for (int64_t k = 0; k < K2; ++k) { + Unroll{}(compute, k); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + _mm512_storeu_ps(reinterpret_cast<__m512*>(C + row * ldc + col * 16), vc[i]); + + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_NN2(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_nn::apply( \ + A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, \ + K, lda, ldb, ldc); + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B, + float* __restrict__ C, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + + // pattern: 1-2-8 + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 32; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * 
BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + // mb_size = 1 + case 0x12: LAUNCH_TINYGEMM_KERNEL_NN2(1, 32); break; + // mb_size = 2 + case 0x22: LAUNCH_TINYGEMM_KERNEL_NN2(2, 32); break; + // mb_size = 3 + case 0x32: LAUNCH_TINYGEMM_KERNEL_NN2(3, 32); break; + // mb_size = 4 + case 0x42: LAUNCH_TINYGEMM_KERNEL_NN2(4, 32); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +template +void fused_experts_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + scalar_t* __restrict__ A_tmp, + float* __restrict__ C_tmp, + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ packed_w1, + const scalar_t* __restrict__ packed_w2, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad) { + + // handle 2 tiles per block + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 1: intermediate_cache1 = silu(hidden_states @ w1) + const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + // strides for w1: [E, 2N, K] + TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N); + + const int64_t stride_e = 2 * N * K; + const int64_t stride_n = K; + + // here we only parallel on half of 2N to fuse silu_and_mul with gemm + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K; + float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + 
int64_t mb = i / NB; + int64_t nb = i % NB; + + // nb0 from top half and nb1 from bottom half + int64_t nb0 = nb, nb1 = nb + NB; + int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); + + // B shape [K, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const scalar_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n; + const scalar_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n; + + // 1.a load A + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + int64_t m_size = offsets[mb + 1] - offsets[mb]; + + const bool use_brgemm = can_use_brgemm(m_size); + is_brgemm_used = is_brgemm_used || use_brgemm; + + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m] / topk; + copy_stub(A + m * K, input + index * K, K); + } + + if (use_brgemm) { + // 1.b gemm: C0 = A @ B0 + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B0, + /* C */ C0); + + // 1.c gemm: C1 = A @ B1 + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B1, + /* C */ C1); + + // 1.d silu and mul + const int64_t offset = offsets[mb]; + silu_and_mul( + ic1 + offset * N + nb * BLOCK_N, + C0, + C1, + m_size, + N); + } else { + // fused 1.bcd: silu_and_mul(A @ B0, A @ B1) + const int64_t offset = offsets[mb]; + tinygemm_kernel( + /* A */ A, + /* B0 */ B0, + /* B1 */ B1, + /* C */ ic1 + offset * N + nb * BLOCK_N, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ N); + } + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 + // w2 : [E, K, N] as [E, OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + 
const int64_t MB2 = MB; + const int64_t NB2 = div_up(OC, BLOCK_N); + const int64_t stride_e2 = OC * IC; + const int64_t stride_oc = IC; + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + // we won't be using C1 for gemm2 + float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + + int64_t m_size = offsets[mb + 1] - offsets[mb]; + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + const bool use_brgemm = can_use_brgemm(m_size); + is_brgemm_used = is_brgemm_used || use_brgemm; + + // A ptr from ic1 of [M * topk, N] in sorted order + // so as to avoid copy A to tmp buffer again + const scalar_t* __restrict__ A = ic1 + offsets[mb] * N; + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + + // B shape [IC, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const scalar_t* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc; + + // 2.a gemm: C = A @ B + if (use_brgemm) { + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B, + /* C */ C); + } else { + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ C, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N); + } + + // 2.b copy from C to ic2 in original order + // and also mul topk_weights in float32 + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m]; + float weight = topk_weights[index]; + copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size); + } + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 3: out = intermediate_cache2.sum(dim=1) + // from [M, topk, K] to [M, K] + 
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + sum_stub(output + m * K, ic2 + m * topk * K, topk, K); + } + }); +} + +template +void shared_expert_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + float* __restrict__ C_tmp, + scalar_t* __restrict__ input, + const scalar_t* __restrict__ packed_w1, + const scalar_t* __restrict__ packed_w2, + const scalar_t* __restrict__ fused_experts_out, + float routed_scaling_factor, + int64_t M, + int64_t N, + int64_t K) { + + // handle 2 tiles per block + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 1: intermediate_cache1 = silu(hidden_states @ w1) + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N); + const int64_t stride_n = K; + + // here we only parallel on half of 2N to fuse silu_and_mul with gemm + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB; + int64_t nb = i % NB; + + // nb0 from top half and nb1 from bottom half + int64_t nb0 = nb, nb1 = nb + NB; + int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + + //int64_t mb_start = mb * BLOCK_M; + //int64_t mb_size = std::min(M - mb_start, BLOCK_M); + + // A shape [m_size, K] + const scalar_t* A = input + mb * BLOCK_M * K; + + // B shape [K, n_size] in vnni format + const scalar_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n; + const scalar_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n; + + const bool use_brgemm = can_use_brgemm(m_size); + is_brgemm_used = 
is_brgemm_used || use_brgemm; + + if (use_brgemm) { + // 1.b gemm: C0 = A @ B0 + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B0, + /* C */ C0); + + // 1.c gemm: C1 = A @ B1 + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B1, + /* C */ C1); + + // 1.d silu and mul + silu_and_mul( + ic1 + mb * BLOCK_M * N + nb * BLOCK_N, + C0, + C1, + m_size, + N); + } else { + // fused 1.bcd: silu_and_mul(A @ B0, A @ B1) + tinygemm_kernel( + /* A */ A, + /* B0 */ B0, + /* B1 */ B1, + /* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ N); + } + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 2: output = intermediate_cache1 @ w2 + // w2 : [K, N] as [OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(OC, BLOCK_N); + const int64_t stride_oc = IC; + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + // we won't be using C1 for gemm2 + float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + const bool use_brgemm = can_use_brgemm(m_size); + is_brgemm_used = is_brgemm_used || use_brgemm; + + // A shape [m_size, IC] + const scalar_t* __restrict__ A = ic1 + mb * BLOCK_M * N; + + // B shape [IC, n_size] in vnni format + const scalar_t* __restrict__ B = 
packed_w2 + nb * BLOCK_N * stride_oc; + + // 2.a gemm: C = A @ B + if (use_brgemm) { + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B, + /* C */ C); + } else { + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ C, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N); + } + + // 2.b copy from C to output and add fused_experts_out + scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N; + const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N; + for (int64_t m = 0; m < m_size; ++m) { + add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size); + } + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); +} + +} // anonymous namespace + +// common checks +static inline void check_moe_scales( + bool use_int8_w8a8, + bool use_fp8_w8a16, + const std::optional& w1_scale, + const std::optional& w2_scale, + const std::optional> block_size, + const std::optional& a1_scale, + const std::optional& a2_scale) { + if (use_int8_w8a8) { + TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for int8 w8a8."); + TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for int8 w8a8."); + TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported."); + TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported."); + } + if (use_fp8_w8a16) { + TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for fp8 w8a16."); + TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for fp8 w8a16."); + TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16."); + TORCH_CHECK(block_size.value().size() == 2, "expect block_size.size() to be 2."); + } +} + +#define CHECK_MOE_SCALES_FP8(DIM0, DIM1) \ + auto w1s = w1_scale.value(); \ + auto w2s = 
w2_scale.value(); \ + auto block_size_val = block_size.value(); \ + int64_t block_size_N = block_size_val[0]; \ + int64_t block_size_K = block_size_val[1]; \ + TORCH_CHECK(w1s.size(DIM0) == 2 * N / block_size_N); \ + TORCH_CHECK(w1s.size(DIM1) == K / block_size_K); \ + TORCH_CHECK(w2s.size(DIM0) == K / block_size_N); \ + TORCH_CHECK(w2s.size(DIM1) == N / block_size_K) + +// hidden_states: [M, K] +// w1: [E, 2N, K] +// w2: [E, K, N] +// topk_weights: [M, topk] +// topk_ids: [M, topk] (int32_t) +// +at::Tensor fused_experts_cpu( + at::Tensor& hidden_states, + at::Tensor& w1, + at::Tensor& w2, + at::Tensor& topk_weights, + at::Tensor& topk_ids, + bool inplace, + bool use_int8_w8a8, + bool use_fp8_w8a16, + const std::optional& w1_scale, + const std::optional& w2_scale, + const std::optional> block_size, + const std::optional& a1_scale, + const std::optional& a2_scale, + bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::fused_experts_cpu", std::vector({hidden_states, w1, w2, topk_weights, topk_ids})); + + auto packed_w1 = is_vnni ? w1 : convert_weight_packed(w1); + auto packed_w2 = is_vnni ? 
w2 : convert_weight_packed(w2); + + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + const auto st = hidden_states.scalar_type(); + CHECK_INPUT(hidden_states); + CHECK_INPUT(w1); + CHECK_INPUT(w2); + CHECK_EQ(topk_weights.sizes(), topk_ids.sizes()); + CHECK_DIM(2, hidden_states); + CHECK_DIM(3, w1); + CHECK_DIM(3, w2); + CHECK_DIM(2, topk_weights); + CHECK_DIM(2, topk_ids); + + CHECK_EQ(topk_ids.scalar_type(), at::kInt); + CHECK_EQ(topk_weights.scalar_type(), at::kFloat); + + int64_t M = hidden_states.size(0); + int64_t K = hidden_states.size(1); + int64_t N = w1.size(1) / 2; + int64_t E = w1.size(0); + int64_t topk = topk_weights.size(1); + + // we use int32_t compensation for int8 w8a8 + int64_t packed_K = get_row_size(K, use_int8_w8a8); + int64_t packed_N = get_row_size(N, use_int8_w8a8); + + // check weight shapes + CHECK_EQ(w2.size(0), E); + CHECK_EQ(w2.size(1), K); + CHECK_EQ(packed_w1.size(2), packed_K); + CHECK_EQ(packed_w2.size(2), packed_N); + + // check scales + check_moe_scales(use_int8_w8a8, use_fp8_w8a16, w1_scale, w2_scale, block_size, a1_scale, a2_scale); + + at::Tensor out_hidden_states = inplace ? hidden_states : at::empty_like(hidden_states); + + // NB: worst case is each expert holds a block with remainder of 1 + // 1. sorted_ids : [M * topk + E * (BLOCK_M - 1)] + // 2. expert_ids : [max_num_blocks] + // 3. total_cnts : [T + 1, E] + // 4. cumsums : [E + 1] + // 5. 
offsets : [max_num_blocks + 1] + // + int num_threads = at::get_num_threads(); + int64_t max_num_tokens_padded = M * topk + E * (BLOCK_M - 1); + int64_t max_num_blocks = div_up(max_num_tokens_padded, BLOCK_M); + auto buffer = at::empty( + {max_num_tokens_padded + max_num_blocks + (num_threads + 1) * E + (E + 1) + (max_num_blocks + 1)}, + topk_ids.options()); + + int32_t* __restrict__ sorted_ids = buffer.data_ptr(); + int32_t* __restrict__ expert_ids = sorted_ids + max_num_tokens_padded; + int32_t* __restrict__ total_cnts = expert_ids + max_num_blocks; + int32_t* __restrict__ cumsums = total_cnts + (num_threads + 1) * E; + int32_t* __restrict__ offsets = cumsums + (E + 1); + + // init sorted_ids with `numel` as the padding number + // init expert_ids with `num_experts` + int64_t numel = M * topk; + at::parallel_for(0, max_num_blocks, GRAIN_SIZE / BLOCK_M, [&](int64_t begin, int64_t end) { + int64_t m_start = begin * BLOCK_M; + int64_t m_size = std::min((end - begin) * BLOCK_M, max_num_tokens_padded - m_start); + fill_stub(sorted_ids + m_start, (int32_t)numel, m_size); + fill_stub(expert_ids + begin, (int32_t)E, end - begin); + }); + // zero total_cnts and cumsums + at::parallel_for(0, (num_threads + 1) * E + (E + 1), GRAIN_SIZE, [&](int64_t begin, int64_t end) { + fill_stub(total_cnts + begin, 0, end - begin); + }); + + // align experts index + int64_t num_tokens_post_pad = moe_align_block_size( + sorted_ids, expert_ids, topk_ids.data_ptr(), total_cnts, cumsums, offsets, E, numel, num_threads); + + // unlike triton kernel, we fuse silu with gemm1 so only need 2 intermediate_caches: + // 1. intermediate_cache1 : [M * topk, N] + // 2. intermediate_cache2 : [M * topk, K] + // 3. A_tmp : [T, BLOCK_M * K] + // 4. C_tmp : [T, 2 * BLOCK_M * BLOCK_N] + // + // for int8 w8a8: + // 5. Aq_tmp : [M, K] or [M * topk, N] + // 6. As_tmp : [M * topk] + // + // for fp8 w8a16: + // 7. intermediate_cache0 : [M * topk, 2N] + // 8. 
B_tmp : [T, BLOCK_N, std::max(K, N)] + // + int64_t buffer_size_nbytes = M * topk * N * 2 + M * topk * K * 2 + + num_threads * BLOCK_M * K * (use_int8_w8a8 ? 1 : 2) + + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float); + + if (use_int8_w8a8) { + buffer_size_nbytes += std::max(M * K, M * topk * N) + M * topk * sizeof(float); + } + if (use_fp8_w8a16) { + buffer_size_nbytes += M * topk * 2 * N * 2 + num_threads * BLOCK_N * std::max(K, N) * 2; + } + + auto buffer2 = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar)); + + AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "fused_experts_kernel_impl", [&] { + scalar_t* __restrict__ intermediate_cache1 = (scalar_t*)((void*)(buffer2.data_ptr())); + scalar_t* __restrict__ intermediate_cache2 = intermediate_cache1 + M * topk * N; + + if (use_int8_w8a8) { + uint8_t* __restrict__ A_tmp = (uint8_t*)((void*)(intermediate_cache2 + M * topk * K)); + float* __restrict__ C_tmp = (float*)((void*)(A_tmp + num_threads * BLOCK_M * K)); + uint8_t* __restrict__ Aq_tmp = (uint8_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N)); + float* __restrict__ As_tmp = (float*)((void*)(Aq_tmp + std::max(M * K, M * topk * N))); + + auto w1s = w1_scale.value(); + auto w2s = w2_scale.value(); + TORCH_CHECK(w1s.numel() == E * 2 * N); + TORCH_CHECK(w2s.numel() == E * K); + + fused_experts_int8_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache1, + intermediate_cache2, + A_tmp, + C_tmp, + Aq_tmp, + As_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + w1s.data_ptr(), + w2s.data_ptr(), + topk_weights.data_ptr(), + sorted_ids, + expert_ids, + offsets, + M, + N, + K, + E, + topk, + num_tokens_post_pad); + } else if (use_fp8_w8a16) { + // here we just ignore C_tmp as it is not used + scalar_t* __restrict__ A_tmp = (scalar_t*)((void*)(intermediate_cache2 + M * topk * K)); + float* __restrict__ C_tmp = (float*)((void*)(A_tmp + num_threads * BLOCK_M * K)); + scalar_t* __restrict__ 
intermediate_cache0 = (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N)); + scalar_t* __restrict__ B_tmp = (scalar_t*)((void*)(intermediate_cache0 + M * topk * 2 * N)); + + CHECK_MOE_SCALES_FP8(1, 2); + fused_experts_fp8_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache0, + intermediate_cache1, + intermediate_cache2, + A_tmp, + B_tmp, + C_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + w1s.data_ptr(), + w2s.data_ptr(), + block_size_N, + block_size_K, + topk_weights.data_ptr(), + sorted_ids, + expert_ids, + offsets, + M, + N, + K, + E, + topk, + num_tokens_post_pad); + } else { + scalar_t* __restrict__ A_tmp = intermediate_cache2 + M * topk * K; + float* __restrict__ C_tmp = (float*)((void*)(A_tmp + num_threads * BLOCK_M * K)); + + fused_experts_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache1, + intermediate_cache2, + A_tmp, + C_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + topk_weights.data_ptr(), + sorted_ids, + expert_ids, + offsets, + M, + N, + K, + E, + topk, + num_tokens_post_pad); + } + }); + return out_hidden_states; +} + +// shared expert kernel +// +// hidden_states: [M, K] +// w1: [2N, K] +// w2: [K, N] +// fused_experts_out +at::Tensor shared_expert_cpu( + at::Tensor& hidden_states, + at::Tensor& w1, + at::Tensor& w2, + at::Tensor& fused_experts_out, + double routed_scaling_factor, + bool inplace, + bool use_int8_w8a8, + bool use_fp8_w8a16, + std::optional& w1_scale, + std::optional& w2_scale, + std::optional> block_size, + std::optional& a1_scale, + std::optional& a2_scale, + bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::shared_expert_cpu", std::vector({hidden_states, w1, w2})); + + auto packed_w1 = is_vnni ? w1 : convert_weight_packed(w1); + auto packed_w2 = is_vnni ? 
w2 : convert_weight_packed(w2); + + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + const auto st = hidden_states.scalar_type(); + CHECK_INPUT(hidden_states); + CHECK_INPUT(fused_experts_out); + CHECK_INPUT(w1); + CHECK_INPUT(w2); + CHECK_DIM(2, hidden_states); + CHECK_DIM(2, w1); + CHECK_DIM(2, w2); + CHECK_EQ(hidden_states.sizes(), fused_experts_out.sizes()); + CHECK_EQ(hidden_states.scalar_type(), st); + + int64_t M = hidden_states.size(0); + int64_t K = hidden_states.size(1); + int64_t N = w1.size(0) / 2; + + // we use int32_t compensation for int8 w8a8 + int64_t packed_K = get_row_size(K, use_int8_w8a8); + int64_t packed_N = get_row_size(N, use_int8_w8a8); + + // check weight shapes + CHECK_EQ(w2.size(0), K); + CHECK_EQ(packed_w1.size(1), packed_K); + CHECK_EQ(packed_w2.size(1), packed_N); + + // check scales + check_moe_scales(use_int8_w8a8, use_fp8_w8a16, w1_scale, w2_scale, block_size, a1_scale, a2_scale); + + at::Tensor out_hidden_states = inplace ? hidden_states : at::empty_like(hidden_states); + + // unlike triton kernel, we fuse silu with gemm1 so only need 2 intermediate_caches: + // 1. intermediate_cache1 : [M, N] + // 2. C_tmp : [T, 2 * BLOCK_M * BLOCK_N] + // + // for int8 w8a8: + // 3. Aq_tmp : [M, K] or [M, N] + // 4. As_tmp : [M] + // + // for fp8 w8a16: + // 5. intermediate_cache0 : [M, 2N] + // 6. 
B_tmp: [T, BLOCK_M, max(K, N)] + // + int num_threads = at::get_num_threads(); + int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float); + + if (use_int8_w8a8) { + buffer_size_nbytes += std::max(M * K, M * N) + M * sizeof(float); + } + if (use_fp8_w8a16) { + buffer_size_nbytes += M * 2 * N * 2 + num_threads * BLOCK_M * std::max(K, N) * 2; + } + + auto buffer = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar)); + AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "share_experts_kernel_impl", [&] { + scalar_t* __restrict__ intermediate_cache1 = (scalar_t*)((void*)(buffer.data_ptr())); + float* __restrict__ C_tmp = (float*)((void*)(intermediate_cache1 + M * N)); + + if (use_int8_w8a8) { + uint8_t* __restrict__ Aq_tmp = (uint8_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N)); + float* __restrict__ As_tmp = (float*)((void*)(Aq_tmp + std::max(M * K, M * N))); + + auto w1s = w1_scale.value(); + auto w2s = w2_scale.value(); + TORCH_CHECK(w1s.numel() == 2 * N); + TORCH_CHECK(w2s.numel() == K); + + shared_expert_int8_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache1, + C_tmp, + Aq_tmp, + As_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + w1s.data_ptr(), + w2s.data_ptr(), + fused_experts_out.data_ptr(), + routed_scaling_factor, + M, + N, + K); + } else if (use_fp8_w8a16) { + scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N)); + scalar_t* __restrict__ B_tmp = (scalar_t*)((void*)(intermediate_cache0 + M * 2 * N)); + + CHECK_MOE_SCALES_FP8(0, 1); + shared_expert_fp8_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache0, + intermediate_cache1, + B_tmp, + C_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + w1s.data_ptr(), + w2s.data_ptr(), + block_size_N, + block_size_K, + fused_experts_out.data_ptr(), + routed_scaling_factor, + M, + N, + K); + } else { + 
shared_expert_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache1, + C_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + fused_experts_out.data_ptr(), + routed_scaling_factor, + M, + N, + K); + } + }); + return out_hidden_states; +} diff --git a/csrc/cpu/sgl-kernels/moe_fp8.cpp b/csrc/cpu/sgl-kernels/moe_fp8.cpp new file mode 100644 index 0000000000000..84a6af267740a --- /dev/null +++ b/csrc/cpu/sgl-kernels/moe_fp8.cpp @@ -0,0 +1,502 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#include "common.h" +#include "gemm.h" +#include "vec.h" + +// clang-format off + +namespace { + +template +inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) { + using Vec = at::vec::Vectorized; + // no remainder + #pragma GCC unroll 4 + for (int64_t d = 0; d < size; d += Vec::size()) { + Vec data = Vec::loadu(input + d); + data.store(out + d); + } +} + +template +inline void copy_mul_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, float weight, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec weight_vec = fVec(weight); + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + bVec x = bVec::loadu(input + d); + fVec x0, x1; + std::tie(x0, x1) = at::vec::convert_to_float(x); + x0 = x0 * weight_vec; + x1 = x1 * weight_vec; + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] * weight); + } +} + +// acc from [topk, K] to [K] +template +inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + if (topk == 1) { + // do copy for topk = 1 + copy_stub(out, 
input, K); + } else { + // do sum for topk != 1 + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= K - kVecSize; d += kVecSize) { + fVec sum_fvec0 = fVec(0.f); + fVec sum_fvec1 = fVec(0.f); + for (int t = 0; t < topk; ++t) { + bVec x_bvec = bVec::loadu(input + t * K + d); + fVec x_fvec0, x_fvec1; + std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); + + sum_fvec0 += x_fvec0; + sum_fvec1 += x_fvec1; + } + bVec out_bvec = convert_from_float_ext(sum_fvec0, sum_fvec1); + out_bvec.store(out + d); + } + for (; d < K; ++d) { + float sum_val = 0.f; + for (int t = 0; t < topk; ++t) { + sum_val += static_cast(input[t * K + d]); + } + out[d] = static_cast(sum_val); + } + } +} + +// out = input + input2 * scale +template +inline void add_mul_stub( + scalar_t* __restrict__ out, + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ input2, + float scale, + int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec s_vec = fVec(scale); + + int64_t d; +#pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + bVec x_bvec = bVec::loadu(input + d); + fVec x0, x1; + std::tie(x0, x1) = at::vec::convert_to_float(x_bvec); + + bVec y_bvec = bVec::loadu(input2 + d); + fVec y0, y1; + std::tie(y0, y1) = at::vec::convert_to_float(y_bvec); + + x0 = x0 + y0 * s_vec; + x1 = x1 + y1 * s_vec; + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] + float(input2[d]) * scale); + } +} + +template +inline void silu_and_mul_stub( + scalar_t* __restrict__ out, + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ input2, + int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + const fVec one = fVec(1.f); + + // no remainder +#pragma GCC unroll 4 + for (int64_t d = 0; d < size; d += bVec::size()) { + bVec x = bVec::loadu(input + d); + fVec x0, 
x1; + std::tie(x0, x1) = at::vec::convert_to_float(x); + bVec y = bVec::loadu(input2 + d); + fVec y0, y1; + std::tie(y0, y1) = at::vec::convert_to_float(y); + x0 = x0 / (one + x0.neg().exp_u20()); + x1 = x1 / (one + x1.neg().exp_u20()); + x0 = x0 * y0; + x1 = x1 * y1; + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } +} + +} // anonymous namespace + +template +void fused_experts_fp8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic0, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + scalar_t* __restrict__ A_tmp, + scalar_t* __restrict__ B_tmp, + float* __restrict__ C_tmp, + const scalar_t* __restrict__ input, + const at::Float8_e4m3fn* __restrict__ packed_w1, + const at::Float8_e4m3fn* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + int64_t block_size_N, + int64_t block_size_K, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad) { + + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 1: intermediate_cache0 = hidden_states @ w1 + const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M); + const int64_t NB = div_up(2 * N, BLOCK_N); + int64_t scale_size_N = div_up(2 * N, block_size_N); + int64_t scale_size_K = div_up(K, block_size_K); + int64_t blocks_n_per_group = block_size_N / BLOCK_N; + + const int64_t stride_e = 2 * N * K; + const int64_t stride_n = K; + + // here we only parallel on half of 2N to fuse silu_and_mul with gemm + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / 
NB; + int64_t nb = i % NB; + + int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N); + + // B shape [K, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const at::Float8_e4m3fn* __restrict__ B = packed_w1 + expert_id * stride_e + nb * BLOCK_N * stride_n; + const float* __restrict__ Bs = w1s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K; + + // 1.a load A + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + int64_t m_size = offsets[mb + 1] - offsets[mb]; + + const bool use_brgemm = can_use_brgemm(m_size); + is_brgemm_used = is_brgemm_used || use_brgemm; + + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m] / topk; + copy_stub(A + m * K, input + index * K, K); + } + + const int64_t offset = offsets[mb]; + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ ic0 + offset * 2 * N + nb * BLOCK_N, + /* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N), + /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, + /* scale */ Bs, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ 2 * N, + /* brg */ use_brgemm, + /* block_size_K */ block_size_K); + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 1.5: intermediate_cache1 = silu(intermediate_cache0) + at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + silu_and_mul_stub( + ic1 + m * N, + ic0 + m * 2 * N, + ic0 + m * 2 * N + N, + N); + } + }); + + // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 + // w2 : [E, K, N] as [E, OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(OC, BLOCK_N); + scale_size_N = div_up(K, block_size_N); + scale_size_K = div_up(N, block_size_K); + const int64_t stride_e2 = OC * IC; + const int64_t stride_oc = IC; + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t 
begin, int64_t end) { + int tid = at::get_thread_num(); + alignas(64) scalar_t C[BLOCK_M * BLOCK_K]; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + + int64_t m_size = offsets[mb + 1] - offsets[mb]; + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + const bool use_brgemm = can_use_brgemm(m_size); + is_brgemm_used = is_brgemm_used || use_brgemm; + + // A ptr from ic1 of [M * topk, N] in sorted order + // so as to avoid copy A to tmp buffer again + const scalar_t* __restrict__ A = ic1 + offsets[mb] * N; + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + + // B shape [IC, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const at::Float8_e4m3fn* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc; + const float* __restrict__ Bs = w2s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K; + + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ C, + /* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N), + /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, + /* scale */ Bs, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* brg */ use_brgemm, + /* block_size_K */ block_size_K); + + // 2.b copy from C to ic2 in original order + // and also mul topk_weights in float32 + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m]; + float weight = topk_weights[index]; + copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size); + } + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 3: out = intermediate_cache2.sum(dim=1) + // from [M, topk, K] to [M, K] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + sum_stub(output + m * K, ic2 + m * topk * K, topk, K); + } + }); +} + +#define INSTANTIATE_MOE_FP8_TEMPLATE(TYPE) \ + template void 
fused_experts_fp8_kernel_impl( \ + TYPE* __restrict__ output, \ + TYPE* __restrict__ ic0, \ + TYPE* __restrict__ ic1, \ + TYPE* __restrict__ ic2, \ + TYPE* __restrict__ A_tmp, \ + TYPE* __restrict__ B_tmp, \ + float* __restrict__ C_tmp, \ + const TYPE* __restrict__ input, \ + const at::Float8_e4m3fn* __restrict__ packed_w1, \ + const at::Float8_e4m3fn* __restrict__ packed_w2, \ + const float* __restrict__ w1s, \ + const float* __restrict__ w2s, \ + int64_t block_size_N, \ + int64_t block_size_K, \ + const float* __restrict__ topk_weights, \ + const int32_t* __restrict__ sorted_ids, \ + const int32_t* __restrict__ expert_ids, \ + const int32_t* __restrict__ offsets, \ + int64_t M, \ + int64_t N, \ + int64_t K, \ + int64_t E, \ + int64_t topk, \ + int64_t num_tokens_post_pad) + +INSTANTIATE_MOE_FP8_TEMPLATE(at::BFloat16); +INSTANTIATE_MOE_FP8_TEMPLATE(at::Half); + +template +void shared_expert_fp8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic0, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ B_tmp, + float* __restrict__ C_tmp, + const scalar_t* __restrict__ input, + const at::Float8_e4m3fn* __restrict__ packed_w1, + const at::Float8_e4m3fn* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + int64_t block_size_N, + int64_t block_size_K, + const scalar_t* __restrict__ fused_experts_out, + float routed_scaling_factor, + int64_t M, + int64_t N, + int64_t K) { + + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 1: intermediate_cache0 = hidden_states @ w1 + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(2 * N, BLOCK_N); + int64_t scale_size_K = div_up(K, block_size_K); + int64_t blocks_n_per_group = block_size_N / BLOCK_N; + + const bool use_brgemm = can_use_brgemm(M); + + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + int tid = at::get_thread_num(); + + for (int64_t i = begin; i < end; ++i) { + int64_t 
mb = i / NB; + int64_t nb = i % NB; + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N); + + tinygemm_kernel( + /* A */ input + mb * BLOCK_M * K, + /* B */ packed_w1 + nb * BLOCK_N * K, + /* C */ ic0 + mb * BLOCK_M * 2 * N + nb * BLOCK_N, + /* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N), + /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, + /* scale */ w1s + (nb / blocks_n_per_group) * scale_size_K, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ 2 * N, + /* brg */ use_brgemm, + /* block_size_K */ block_size_K); + } + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 1.5: intermediate_cache1 = silu(intermediate_cache0) + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + silu_and_mul_stub( + ic1 + m * N, + ic0 + m * 2 * N, + ic0 + m * 2 * N + N, + N); + } + }); + + // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 + // w2 : [K, N] as [OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(K, BLOCK_N); + scale_size_K = div_up(N, block_size_K); + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + int tid = at::get_thread_num(); + alignas(64) scalar_t C[BLOCK_M * BLOCK_K]; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + // 2.a gemm: C = A @ B + tinygemm_kernel( + /* A */ ic1 + mb * BLOCK_M * N, + /* B */ packed_w2 + nb * BLOCK_N * N, + /* C */ C, + /* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N), + /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, + /* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda 
*/ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* brg */ use_brgemm, + /* block_size_K */ block_size_K); + + // 2.b copy from C to output and add fused_experts_out + scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N; + const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N; + for (int64_t m = 0; m < m_size; ++m) { + add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size); + } + } + }); + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } +} + +#define INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(TYPE) \ + template void shared_expert_fp8_kernel_impl( \ + TYPE* __restrict__ output, \ + TYPE* __restrict__ ic0, \ + TYPE* __restrict__ ic1, \ + TYPE* __restrict__ B_tmp, \ + float* __restrict__ C_tmp, \ + const TYPE* __restrict__ input, \ + const at::Float8_e4m3fn* __restrict__ packed_w1, \ + const at::Float8_e4m3fn* __restrict__ packed_w2, \ + const float* __restrict__ w1s, \ + const float* __restrict__ w2s, \ + int64_t block_size_N, \ + int64_t block_size_K, \ + const TYPE* __restrict__ fused_experts_out, \ + float routed_scaling_factor, \ + int64_t M, \ + int64_t N, \ + int64_t K) + +INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::BFloat16); +INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::Half); diff --git a/csrc/cpu/sgl-kernels/moe_int8.cpp b/csrc/cpu/sgl-kernels/moe_int8.cpp new file mode 100644 index 0000000000000..89d0fb5d9f3b7 --- /dev/null +++ b/csrc/cpu/sgl-kernels/moe_int8.cpp @@ -0,0 +1,769 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#include "common.h" +#include "vec.h" +#include "gemm.h" + +// clang-format off + +namespace { + +template +inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) { + using Vec = at::vec::Vectorized; + // no remainder + #pragma GCC unroll 4 + for (int64_t d = 0; d < size; d += Vec::size()) { + Vec data = Vec::loadu(input + d); + 
data.store(out + d); + } +} + +template <> +inline void copy_stub(uint8_t* __restrict__ out, const uint8_t* __restrict__ input, int64_t size) { + // size might be 64x + 32 + std::memcpy(out, input, size * sizeof(uint8_t)); +} + +template +inline void copy_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, float weight, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec weight_vec = fVec(weight); + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d) * weight_vec; + fVec data1 = fVec::loadu(input + d + fVec::size()) * weight_vec; + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] * weight); + } +} + +// acc from [topk, K] to [K] +template +inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + if (topk == 1) { + // do copy for topk = 1 + copy_stub(out, input, K); + } else { + // do sum for topk != 1 + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= K - kVecSize; d += kVecSize) { + fVec sum_fvec0 = fVec(0.f); + fVec sum_fvec1 = fVec(0.f); + for (int t = 0; t < topk; ++t) { + bVec x_bvec = bVec::loadu(input + t * K + d); + fVec x_fvec0, x_fvec1; + std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); + + sum_fvec0 += x_fvec0; + sum_fvec1 += x_fvec1; + } + bVec out_bvec = convert_from_float_ext(sum_fvec0, sum_fvec1); + out_bvec.store(out + d); + } + for (; d < K; ++d) { + float sum_val = 0.f; + for (int t = 0; t < topk; ++t) { + sum_val += static_cast(input[t * K + d]); + } + out[d] = static_cast(sum_val); + } + } +} + +// out = input + input2 * scale +template +inline void add_mul_stub(scalar_t* __restrict__ out, 
const float* __restrict__ input, + const scalar_t* __restrict__ input2, float scale, int64_t size) { + + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec s_vec = fVec(scale); + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec x0 = fVec::loadu(input + d); + fVec x1 = fVec::loadu(input + d + fVec::size()); + + bVec y_bvec = bVec::loadu(input2 + d); + fVec y0, y1; + std::tie(y0, y1) = at::vec::convert_to_float(y_bvec); + + x0 = x0 + y0 * s_vec; + x1 = x1 + y1 * s_vec; + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] + float(input2[d]) * scale); + } +} + +/// gemm for w13 +template +struct tinygemm_kernel_vnni { + static inline void apply( + const uint8_t* __restrict__ A, const int8_t* __restrict__ B0, const int8_t* __restrict__ B1, scalar_t* __restrict__ C, + const float* __restrict__ As, const float* __restrict__ Bs0, const float* __restrict__ Bs1, + const int32_t* __restrict__ Bcomp0, const int32_t* __restrict__ Bcomp1, + int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_vnni { + static inline void apply( + const uint8_t* __restrict__ A, const int8_t* __restrict__ B0, const int8_t* __restrict__ B1, at::BFloat16* __restrict__ C, + const float* __restrict__ As, const float* __restrict__ Bs0, const float* __restrict__ Bs1, + const int32_t* __restrict__ Bcomp0, const int32_t* __restrict__ Bcomp1, + int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + static_assert(COLS % 2 == 0); + + __m512i va; + __m512i vb0[COLS]; + __m512i vb1[COLS]; + __m512i vc0[ROWS * COLS]; + __m512i vc1[ROWS * COLS]; + __m512i vcomp0[COLS]; + __m512i vcomp1[COLS]; + 
__m512 was; + __m512 vbs0[COLS]; + __m512 vbs1[COLS]; + + auto loadc = [&](auto i) { + vc0[i] = _mm512_set1_epi32(0); + vc1[i] = _mm512_set1_epi32(0); + }; + Unroll{}(loadc); + + const int64_t K4 = K >> 2; + const int64_t lda4 = lda >> 2; + const int64_t ldb4 = ldb; // ldb * 4 >> 2; + const int32_t* a_ptr = reinterpret_cast(A); + const int32_t* b0_ptr = reinterpret_cast(B0); + const int32_t* b1_ptr = reinterpret_cast(B1); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = _mm512_set1_epi32(a_ptr[row * lda4 + k]); + } + if constexpr (row == 0) { + vb0[col] = _mm512_loadu_si512(b0_ptr + k * ldb4 + col * 16); + vb1[col] = _mm512_loadu_si512(b1_ptr + k * ldb4 + col * 16); + } + vc0[i] = _mm512_dpbusd_epi32(vc0[i], va, vb0[col]); + vc1[i] = _mm512_dpbusd_epi32(vc1[i], va, vb1[col]); + }; + for (int64_t k = 0; k < K4; ++k) { + Unroll{}(compute, k); + } + + auto scalec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + // load a scale + if constexpr(col == 0) { + was = _mm512_set1_ps(As[row]); + } + // load b scale and vcomp + if constexpr (row == 0) { + vbs0[col] = _mm512_loadu_ps(Bs0 + col * 16); + vbs1[col] = _mm512_loadu_ps(Bs1 + col * 16); + vcomp0[col] = _mm512_loadu_si512(Bcomp0 + col * 16); + vcomp1[col] = _mm512_loadu_si512(Bcomp1 + col * 16); + } + __m512 c0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc0[i], vcomp0[col])); + __m512 c1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc1[i], vcomp1[col])); + vc0[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c0, was), vbs0[col])); + vc1[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c1, was), vbs1[col])); + }; + Unroll{}(scalec); + + using Vec = at::vec::Vectorized; + const Vec one = Vec(1.f); + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + // for COLS = 2, 4 use 512bit store + if constexpr (col % 2 == 0) { + Vec x0 = _mm512_castsi512_ps(vc0[row 
* COLS + col + 0]); + Vec x1 = _mm512_castsi512_ps(vc0[row * COLS + col + 1]); + Vec y0 = _mm512_castsi512_ps(vc1[row * COLS + col + 0]); + Vec y1 = _mm512_castsi512_ps(vc1[row * COLS + col + 1]); + // silu + x0 = x0 / (one + x0.neg().exp_u20()); + x1 = x1 / (one + x1.neg().exp_u20()); + // mul + x0 = x0 * y0; + x1 = x1 * y1; + + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), + (__m512i)(_mm512_cvtne2ps_pbh(__m512(x1), __m512(x0)))); + } + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_VNNI(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_vnni::apply( \ + A + mb_start * lda, B0 + nb_start * 4, B1 + nb_start * 4, \ + C + mb_start * ldc + nb_start, As + mb_start, \ + Bs0 + nb_start, Bs1 + nb_start, Bcomp0 + nb_start, Bcomp1 + nb_start,\ + K, lda, ldb, ldc); + +template +void tinygemm_kernel( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B0, + const int8_t* __restrict__ B1, + scalar_t* __restrict__ C, + const float* __restrict__ As, + const float* __restrict__ Bs0, + const float* __restrict__ Bs1, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + + const int32_t* Bcomp0 = reinterpret_cast(B0 + block_size_n() * K); + const int32_t* Bcomp1 = reinterpret_cast(B1 + block_size_n() * K); + + // pattern: 1-(2+2)-(8+8) + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 32; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + case 0x12: LAUNCH_TINYGEMM_KERNEL_VNNI(1, 32); break; + case 0x22: LAUNCH_TINYGEMM_KERNEL_VNNI(2, 32); break; + case 0x32: LAUNCH_TINYGEMM_KERNEL_VNNI(3, 32); break; + case 0x42: LAUNCH_TINYGEMM_KERNEL_VNNI(4, 32); break; + 
default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +/// gemm for w2 +template +struct tinygemm_kernel_vnni2 { + static inline void apply( + const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C, + const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp, + int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_vnni2 { + static inline void apply( + const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C, + const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp, + int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + static_assert(COLS % 2 == 0); + + __m512i va; + __m512i vb[COLS]; + __m512i vc[ROWS * COLS]; + __m512i vcomp[COLS]; + __m512 was; + __m512 vbs[COLS]; + + auto loadc = [&](auto i) { + vc[i] = _mm512_set1_epi32(0); + }; + Unroll{}(loadc); + + const int64_t K4 = K >> 2; + const int64_t lda4 = lda >> 2; + const int64_t ldb4 = ldb; // ldb * 4 >> 2; + const int32_t* a_ptr = reinterpret_cast(A); + const int32_t* b_ptr = reinterpret_cast(B); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = _mm512_set1_epi32(a_ptr[row * lda4 + k]); + } + if constexpr (row == 0) { + vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16); + } + vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]); + }; + for (int64_t k = 0; k < K4; ++k) { + Unroll{}(compute, k); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + // load a scale + if constexpr(col == 0) { + was = _mm512_set1_ps(As[row]); + } + // load b scale and vcomp per 2 vectors + // also 
load bias if any + if constexpr (row == 0) { + if constexpr (col % 2 == 0) { + vbs[col + 0] = _mm512_loadu_ps(Bs + col * 16); + vbs[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16); + vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16); + vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16); + } + } + __m512 x = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[i], vcomp[col])); + x = _mm512_mul_ps(_mm512_mul_ps(x, was), vbs[col]); + _mm512_storeu_ps(reinterpret_cast<__m512*>(C + row * ldc + col * 16), x); + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_VNNI2(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_vnni2::apply( \ + A + mb_start * lda, B + nb_start * 4, C + mb_start * ldc + nb_start, \ + As + mb_start, Bs + nb_start, Bcomp + nb_start, \ + K, lda, ldb, ldc); + +template +void tinygemm_kernel( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B, + float* __restrict__ C, + const float* __restrict__ As, + const float* __restrict__ Bs, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + + // B compensation + const int32_t* Bcomp = reinterpret_cast(B + block_size_n() * K); + + // pattern: 1-4-16 + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 64; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int64_t mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + case 0x12: LAUNCH_TINYGEMM_KERNEL_VNNI2(1, 32); break; + case 0x22: LAUNCH_TINYGEMM_KERNEL_VNNI2(2, 32); break; + case 0x32: LAUNCH_TINYGEMM_KERNEL_VNNI2(3, 32); break; + case 0x42: LAUNCH_TINYGEMM_KERNEL_VNNI2(4, 32); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +} // anonymous namespace + +template +void 
fused_experts_int8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + uint8_t* __restrict__ A_tmp, + float* __restrict__ C_tmp, + uint8_t* __restrict__ Aq_tmp, + float* __restrict__ As_tmp, + const scalar_t* __restrict__ input, + const int8_t* __restrict__ packed_w1, + const int8_t* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad) { + + // handle 2 tiles per block + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 0: quantize input to uint8, [M, K] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8( + Aq_tmp + m * K, + As_tmp[m], + input + m * K, + K); + } + }); + + // stage 1: intermediate_cache1 = silu(hidden_states @ w1) + const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + // strides for w1: [E, 2N, K] + TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N); + + // K and N are packed for int8 + const int64_t packed_K = get_row_size(K); + const int64_t packed_N = get_row_size(N); + + const int64_t stride_e = 2 * N * packed_K; + const int64_t stride_n = packed_K; + // here we only parallel on half of 2N to fuse silu_and_mul with gemm + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + uint8_t* __restrict__ A = A_tmp + tid * BLOCK_M * K; + + alignas(64) float As[BLOCK_M]; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB; + int64_t nb = i % NB; + + // nb0 from top half and nb1 from bottom half + int64_t nb0 = nb, nb1 = nb + 
NB; + int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); + + // B shape [K, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n; + const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n; + const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb0 * BLOCK_N; + const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb1 * BLOCK_N; + + // 1.a load A + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + int64_t m_size = offsets[mb + 1] - offsets[mb]; + + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m] / topk; + copy_stub(A + m * K, Aq_tmp + index * K, K); + As[m] = As_tmp[index]; + } + + // fused 1.b: silu_and_mul(A @ B0, A @ B1) + const int64_t offset = offsets[mb]; + tinygemm_kernel( + /* A */ A, + /* B0 */ B0, + /* B1 */ B1, + /* C */ ic1 + offset * N + nb * BLOCK_N, + /* As */ As, + /* Bs0 */ Bs0, + /* Bs1 */ Bs1, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ N); + } + }); + + // stage 1.5: quantize ic1 to uint8, [M * topk, N] + at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8( + Aq_tmp + m * N, + As_tmp[m], + ic1 + m * N, + N); + } + }); + + // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 + // w2 : [E, K, N] as [E, OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(OC, BLOCK_N); + const int64_t stride_e2 = OC * packed_N; + const int64_t stride_oc = packed_N; + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + // we won't be using C1 for gemm2 + float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb 
= i / NB2; + int64_t nb = i % NB2; + + int64_t m_size = offsets[mb + 1] - offsets[mb]; + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + // A ptr from ic1 of [M * topk, N] in sorted order + // so as to avoid copy A to tmp buffer again + const uint8_t* __restrict__ A = Aq_tmp + offsets[mb] * N; + const float* __restrict__ As = As_tmp + offsets[mb]; + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + + // B shape [IC, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const int8_t* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc; + const float* __restrict__ Bs = w2s + expert_id * K + nb * BLOCK_N; + + // 2.a gemm: C = A @ B + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ C, + /* As */ As, + /* Bs */ Bs, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N); + + // 2.b copy from C to ic2 in original order + // and also mul topk_weights in float32 + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m]; + float weight = topk_weights[index]; + copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size); + } + } + }); + + // stage 3: out = intermediate_cache2.sum(dim=1) + // from [M, topk, K] to [M, K] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + sum_stub(output + m * K, ic2 + m * topk * K, topk, K); + } + }); +} + +#define INSTANTIATE_MOE_INT8_TEMPLATE(TYPE) \ + template void fused_experts_int8_kernel_impl ( \ + TYPE* __restrict__ output, TYPE* __restrict__ ic1, \ + TYPE* __restrict__ ic2, uint8_t* __restrict__ A_tmp, \ + float* __restrict__ C_tmp, uint8_t* __restrict__ Aq_tmp, \ + float* __restrict__ As_tmp, const TYPE* __restrict__ input, \ + const int8_t* __restrict__ packed_w1, const int8_t* __restrict__ packed_w2, \ + const float* __restrict__ w1s, const float* __restrict__ w2s, \ + const float* __restrict__ topk_weights, const int32_t* __restrict__ sorted_ids, 
\ + const int32_t* __restrict__ expert_ids, const int32_t* __restrict__ offsets, \ + int64_t M, int64_t N, int64_t K, int64_t E, int64_t topk, int64_t num_tokens_post_pad) + +INSTANTIATE_MOE_INT8_TEMPLATE(at::BFloat16); +INSTANTIATE_MOE_INT8_TEMPLATE(at::Half); + +template +void shared_expert_int8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + float* __restrict__ C_tmp, + uint8_t* __restrict__ Aq_tmp, + float* __restrict__ As_tmp, + const scalar_t* __restrict__ input, + const int8_t* __restrict__ packed_w1, + const int8_t* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + const scalar_t* __restrict__ fused_experts_out, + float routed_scaling_factor, + int64_t M, + int64_t N, + int64_t K) { + + // handle 2 tiles per block + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 0: quantize input to uint8, [M, K] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8( + Aq_tmp + m * K, + As_tmp[m], + input + m * K, + K); + } + }); + + // stage 1: intermediate_cache1 = silu(hidden_states @ w1) + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N); + + // K and N are packed for int8 + const int64_t packed_K = get_row_size(K); + const int64_t packed_N = get_row_size(N); + const int64_t stride_n = packed_K; + + // here we only parallel on half of 2N to fuse silu_and_mul with gemm + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB; + int64_t nb = i % NB; + + // nb0 from top half and nb1 from bottom half + int64_t nb0 = nb, nb1 = nb + NB; + int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + + // A shape [m_size, K] + const uint8_t* A = 
Aq_tmp + mb * BLOCK_M * K; + const float* As = As_tmp + mb * BLOCK_M; + + // B shape [K, n_size] in vnni format + const int8_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n; + const int8_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n; + const float* __restrict__ Bs0 = w1s + nb0 * BLOCK_N; + const float* __restrict__ Bs1 = w1s + nb1 * BLOCK_N; + + // fused 1.b: silu_and_mul(A @ B0, A @ B1) + tinygemm_kernel( + /* A */ A, + /* B0 */ B0, + /* B1 */ B1, + /* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N, + /* As */ As, + /* Bs0 */ Bs0, + /* Bs1 */ Bs1, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ N); + } + }); + + // stage 1.5: quantize ic1 to uint8, [M * topk, N] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8( + Aq_tmp + m * N, + As_tmp[m], + ic1 + m * N, + N); + } + }); + + // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 + // w2 : [K, N] as [OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(OC, BLOCK_N); + const int64_t stride_oc = packed_N; + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + // we won't be using C1 for gemm2 + float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + // A shape [m_size, IC] + const uint8_t* __restrict__ A = Aq_tmp + mb * BLOCK_M * N; + const float* __restrict__ As = As_tmp + mb * BLOCK_M; + + // B shape [IC, n_size] in vnni format + const int8_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc; + const float* __restrict__ Bs = w2s + nb * BLOCK_N; + + // 2.a 
gemm: C = A @ B + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ C, + /* As */ As, + /* Bs */ Bs, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N); + + // 2.b copy from C to output and add fused_experts_out + scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N; + const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N; + for (int64_t m = 0; m < m_size; ++m) { + add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size); + } + } + }); +} + +#define INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(TYPE) \ + template void shared_expert_int8_kernel_impl ( \ + TYPE* __restrict__ output, TYPE* __restrict__ ic1, \ + float* __restrict__ C_tmp, uint8_t* __restrict__ Aq_tmp, \ + float* __restrict__ As_tmp, const TYPE* __restrict__ input, \ + const int8_t* __restrict__ packed_w1, const int8_t* __restrict__ packed_w2, \ + const float* __restrict__ w1s, const float* __restrict__ w2s, \ + const TYPE* __restrict__ fused_experts_out, float routed_scaling_factor, \ + int64_t M, int64_t N, int64_t K) + +INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::BFloat16); +INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::Half); diff --git a/csrc/cpu/sgl-kernels/vec.h b/csrc/cpu/sgl-kernels/vec.h new file mode 100644 index 0000000000000..87955cfb2922c --- /dev/null +++ b/csrc/cpu/sgl-kernels/vec.h @@ -0,0 +1,308 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#pragma once + +// clang-format off + +#if defined(__AVX512F__) && defined(__AVX512BF16__) && defined(__AMX_BF16__) +#define CPU_CAPABILITY_AVX512 +#endif + +#include +#include + +namespace { + +using namespace at::vec; + +template , int> = 0> +inline Vectorized convert_from_float_ext(const Vectorized& a, const Vectorized& b) { + return at::vec::convert_from_float(a, b); +} + +#if defined(CPU_CAPABILITY_AVX512) + +// `at::vec::convert_from_float<>` from PyTorch 
doesn't have avx512-bf16 intrinsics +// use native instruction for float32->bfloat16 conversion +template <> +inline Vectorized convert_from_float_ext(const Vectorized& a, const Vectorized& b) { + return (__m512i)(_mm512_cvtne2ps_pbh(__m512(b), __m512(a))); +} + +#define CVT_BF16_TO_FP32(a) \ + _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16)) + +#define CVT_FP16_TO_FP32(a) \ + _mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) + +// this doesn't handle NaN. +inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) { + const __m512i x = _mm512_cvtepu8_epi16(fp8_vec); + + const __m512i mant = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x07)), 4); + const __m512i raw_exp = _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x78)), 3); + const __m512i exp = _mm512_slli_epi16(_mm512_add_epi16(raw_exp, _mm512_set1_epi16(120)), 7); + const __m512i nonsign = _mm512_or_si512(exp, mant); + + const __m512i sign = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x80)), 8); + const __m512i combined = _mm512_or_si512(nonsign, sign); + + const __mmask32 is_nonzero = _mm512_cmpneq_epi16_mask(x, _mm512_setzero_si512()); + return (__m512bh)_mm512_maskz_mov_epi16(is_nonzero, combined); +} + +inline __m512bh cvt_e4m3_bf16_intrinsic_without_denorm(__m256i fp8_vec) { + // The following conversion is without denorm behavior, that is to say, + // Max subnorm : S.0000.111 = 0.875 ∗ 2**(−6) + // Min subnorm : S.0000.001 = 2**(−9) + // 0.0019 ~ 0.0137 cannot be converted correctly. 
+ __m512i x = _mm512_cvtepu8_epi16(fp8_vec); + auto mask = _mm512_cmpneq_epi16_mask( + _mm512_and_si512(x, _mm512_set1_epi16(127)), + _mm512_setzero_si512()); // mask = x & 0x7f + auto mask_nan = _mm512_cmpneq_epi16_mask( + _mm512_and_si512(x, _mm512_set1_epi16(127)), + _mm512_set1_epi16(127)); // mask_nan = x & 0x7f + auto mantissa = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4); // mantissa = (x & 7) << 4 + auto exponent = _mm512_add_epi16( + _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3), + _mm512_set1_epi16(120)); // exponent = (((x >> 3) & 15) + 120) + auto nonsign = _mm512_maskz_mov_epi16(mask, _mm512_or_si512(mantissa, _mm512_slli_epi16(exponent, 7))); + nonsign = _mm512_mask_mov_epi16(_mm512_set1_epi16(0x7fff), mask_nan, nonsign); // deal with Nan + return (__m512bh)(_mm512_or_si512( + nonsign, + _mm512_slli_epi16( + _mm512_and_si512(x, _mm512_set1_epi16(128)), + 8))); // add sign (x & 128) << 8 +} + +inline __m512bh cvt_e4m3_bf16_intrinsic_with_denorm(__m256i fp8_vec) { + __m512i x = _mm512_cvtepu8_epi16(fp8_vec); + __m512i lg2mant = _mm512_mask_mov_epi16( + _mm512_mask_mov_epi16( + _mm512_setzero_si512(), _mm512_test_epi16_mask(x, _mm512_set1_epi16(2)), _mm512_set1_epi16(1)), + _mm512_test_epi16_mask(x, _mm512_set1_epi16(4)), + _mm512_set1_epi16(2)); + return (__m512bh)(_mm512_or_si512( + _mm512_maskz_mov_epi16( + _mm512_cmpneq_epi16_mask(_mm512_and_si512(x, _mm512_set1_epi16(127)), _mm512_setzero_si512()), + _mm512_mask_blend_epi16( + _mm512_test_epi16_mask(x, _mm512_set1_epi16(120)), + _mm512_or_si512( + _mm512_and_si512( + _mm512_sllv_epi16( + _mm512_and_si512(x, _mm512_set1_epi16(3)), _mm512_sub_epi16(_mm512_set1_epi16(7), lg2mant)), + _mm512_set1_epi16(0x007f)), + _mm512_slli_epi16(_mm512_add_epi16(lg2mant, _mm512_set1_epi16(118)), 7)), + _mm512_or_si512( + _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4), + _mm512_slli_epi16( + _mm512_add_epi16( + _mm512_srli_epi16(_mm512_and_si512(x, 
_mm512_set1_epi16(120)), 3), _mm512_set1_epi16(120)), + 7)))), + _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(128)), 8))); +} + +inline __m512bh CVT_FP8_TO_BF16(__m256i a) { +#ifdef SGLANG_CPU_FP8_CVT_FTZ + return cvt_e4m3_bf16_intrinsic_no_nan(a); +#else + return cvt_e4m3_bf16_intrinsic_with_denorm(a); +#endif +} + +#endif + +// vector to scalar reduction +#if defined(CPU_CAPABILITY_AVX512) && 0 +inline float vec_reduce_sum(const Vectorized& a) { + return _mm512_reduce_add_ps(__m512(a)); +} + +inline float vec_reduce_max(const Vectorized& a) { + return _mm512_reduce_max_ps(__m512(a)); +} +#else +inline float vec_reduce_sum(const Vectorized& a) { + return vec_reduce_all([](Vectorized& x, Vectorized& y) { return x + y; }, a); +} + +inline float vec_reduce_max(const Vectorized& a) { + return vec_reduce_all([](Vectorized& x, Vectorized& y) { return maximum(x, y); }, a); +} +#endif + +// https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282 +template +inline void quantize_row_int8(uint8_t* __restrict__ Aq, float& As, + const scalar_t* __restrict__ A, int64_t K, float eps = 1e-7) { + + float amax = 0.f; // absolute max + for (int64_t k = 0; k < K; ++k) { + const float val = static_cast(A[k]); + amax = std::max(amax, std::abs(val)); + } + + amax = std::max(amax, eps); + const float scale = amax / 127; + const float inv_scale = 127 / amax; + + for (int64_t k = 0; k < K; ++k) { + const float val = static_cast(A[k]) * inv_scale; + Aq[k] = (uint8_t)(std::round(val)) + 128; + } + As = scale; +} + +#if defined(CPU_CAPABILITY_AVX512) +template <> +inline void quantize_row_int8(uint8_t* __restrict__ Aq, float& As, + const at::BFloat16* __restrict__ A, int64_t K, float eps) { + + const __m512 signBit = _mm512_set1_ps(-0.0f); + const __m512i off = _mm512_set1_epi32(128); + + // K is 32x, no remainder + float amax = 0.f; + __m512 vamax0 = _mm512_set1_ps(0.f); + __m512 vamax1 = 
_mm512_set1_ps(0.f); + for (int64_t k = 0; k < K; k += 32) { + __m512i va = _mm512_loadu_si512((void*)(A + k)); + __m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0)); + __m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1)); + vamax0 = _mm512_max_ps(vamax0, _mm512_andnot_ps(signBit, va0)); + vamax1 = _mm512_max_ps(vamax1, _mm512_andnot_ps(signBit, va1)); + } + amax = _mm512_reduce_max_ps(_mm512_max_ps(vamax0, vamax1)); + amax = std::max(amax, eps); + const float scale = amax / 127; + const float inv_scale = 127 / amax; + const __m512 vd = _mm512_set1_ps(inv_scale); + + for (int64_t k = 0; k < K; k += 32) { + __m512i va = _mm512_loadu_si512((void*)(A + k)); + __m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0)); + __m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1)); + va0 = _mm512_mul_ps(va0, vd); + va1 = _mm512_mul_ps(va1, vd); + va0 = _mm512_roundscale_ps(va0, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + va1 = _mm512_roundscale_ps(va1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m128i i0 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va0), off)); + __m128i i1 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va1), off)); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(Aq + k), _mm256_set_m128i(i1, i0)); + } + As = scale; +} +#endif + +// transpose utils +// taken from my PR in ggml: https://github.com/ggml-org/llama.cpp/pull/8998 +#if defined(CPU_CAPABILITY_AVX512) +inline void transpose_16x16_32bit(__m512i * v) { + __m512i v1[16]; + v1[0] = _mm512_unpacklo_epi32(v[0], v[1]); + v1[1] = _mm512_unpackhi_epi32(v[0], v[1]); + v1[2] = _mm512_unpacklo_epi32(v[2], v[3]); + v1[3] = _mm512_unpackhi_epi32(v[2], v[3]); + v1[4] = _mm512_unpacklo_epi32(v[4], v[5]); + v1[5] = _mm512_unpackhi_epi32(v[4], v[5]); + v1[6] = _mm512_unpacklo_epi32(v[6], v[7]); + v1[7] = _mm512_unpackhi_epi32(v[6], v[7]); + v1[8] = _mm512_unpacklo_epi32(v[8], v[9]); + v1[9] = _mm512_unpackhi_epi32(v[8], v[9]); + v1[10] 
= _mm512_unpacklo_epi32(v[10], v[11]); + v1[11] = _mm512_unpackhi_epi32(v[10], v[11]); + v1[12] = _mm512_unpacklo_epi32(v[12], v[13]); + v1[13] = _mm512_unpackhi_epi32(v[12], v[13]); + v1[14] = _mm512_unpacklo_epi32(v[14], v[15]); + v1[15] = _mm512_unpackhi_epi32(v[14], v[15]); + + v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]); + v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]); + v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]); + v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]); + v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]); + v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]); + v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]); + v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]); + v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]); + v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]); + v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]); + v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]); + v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]); + v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]); + v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]); + v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]); + + v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88); + v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88); + v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88); + v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88); + v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd); + v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd); + v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd); + v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd); + v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88); + v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88); + v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88); + v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88); + v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd); + v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd); + v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd); + v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd); + + v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88); + v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88); + v[2] = _mm512_shuffle_i32x4(v1[2], 
v1[10], 0x88); + v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88); + v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88); + v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88); + v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88); + v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88); + v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd); + v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd); + v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd); + v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd); + v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd); + v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd); + v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd); + v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd); +} + +// remove warning : ignoring attributes on template argument ‘__m512i’ [-Wignored-attributes] +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wignored-attributes" + +// transpose from [2, 32] to [32, 2] +inline std::tuple<__m512i, __m512i> transpose_2x32_16bit(__m512i r0, __m512i r1) { + // r0: {a0, a1, ..., a31} + // r1: {b0, b1, ..., b31} + // + // d0: {a0, b0, ..., a15, b15} + // d1: {a16, b16, ..., a31, b31} + // + __m512i d0 = _mm512_unpacklo_epi16(r0, r1); + __m512i d1 = _mm512_unpackhi_epi16(r0, r1); + r0 = _mm512_shuffle_i32x4(d0, d1, 0x88); + r1 = _mm512_shuffle_i32x4(d0, d1, 0xdd); + d0 = _mm512_shuffle_i32x4(r0, r1, 0x88); + d1 = _mm512_shuffle_i32x4(r0, r1, 0xdd); + return std::make_tuple(d0, d1); +} +#pragma GCC diagnostic pop + +#endif + +// TODO: debug print, remove me later +template +void print_array(scalar_t* ptr, int size) { + for (int d = 0; d < size; ++d) { + if (d % 16 == 0) { std::cout << std::endl; } + std::cout << ptr[d] << " "; + } + std::cout << std::endl; +} + +} // anonymous namespace diff --git a/csrc/cpu/shm.cpp b/csrc/cpu/shm.cpp index f55e96de251d0..9adb6f27ec411 100644 --- a/csrc/cpu/shm.cpp +++ b/csrc/cpu/shm.cpp @@ -7,9 +7,10 @@ namespace { #define MAX_SHM_RANK_NUM 8 -#define MAX_THREAD_NUM 12 -#define 
PER_THREAD_SHM_BUFFER_BYTES (4 * 1024 * 1024) -#define MIN_THREAD_PROCESS_SIZE (8 * 1024) +#define PER_THREAD_SHM_BUFFER_BYTES (2 * 1024 * 1024) +static_assert(PER_THREAD_SHM_BUFFER_BYTES % 2 == 0); +#define PER_THREAD_SHM_BUFFER_OFFSET (PER_THREAD_SHM_BUFFER_BYTES >> 1) +#define MIN_THREAD_PROCESS_SIZE (256) #define MAX_P2P_SEND_TENSOR_NUM 8 template @@ -32,10 +33,10 @@ struct KernelVecType { using scalar_vec_t = vec_op::FP16Vec16; }; -enum class ThreadSHMStat : char { THREAD_READY = 0, SHM_DATA_READY, DONE }; - struct ThreadSHMContext { - volatile ThreadSHMStat thread_stats[MAX_SHM_RANK_NUM]; + volatile char _curr_thread_stamp; + volatile char _ready_thread_stamp; + char _padding1[6]; int thread_id; int thread_num; int rank; @@ -44,14 +45,19 @@ struct ThreadSHMContext { int swizzled_ranks[MAX_SHM_RANK_NUM]; void* thread_shm_ptrs[MAX_SHM_RANK_NUM]; ThreadSHMContext* shm_contexts[MAX_SHM_RANK_NUM]; + size_t _thread_buffer_mask; + char _padding2[56]; ThreadSHMContext(const int thread_id, const int thread_num, const int rank, const int group_size, void* thread_shm_ptr) - : thread_id(thread_id), + : _curr_thread_stamp(1), + _ready_thread_stamp(0), + thread_id(thread_id), thread_num(thread_num), rank(rank), group_size(group_size), - _spinning_count(0) { + _spinning_count(0), + _thread_buffer_mask(0) { static_assert(sizeof(ThreadSHMContext) % 64 == 0); TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM); TORCH_CHECK((size_t)this % 64 == 0); @@ -60,7 +66,6 @@ struct ThreadSHMContext { shm_contexts[i] = nullptr; thread_shm_ptrs[i] = nullptr; swizzled_ranks[i] = (i + rank) % group_size; - thread_stats[i] = ThreadSHMStat::DONE; } set_context(rank, this, thread_shm_ptr); } @@ -77,59 +82,66 @@ struct ThreadSHMContext { template T* get_thread_shm_ptr(int rank) { - return reinterpret_cast(thread_shm_ptrs[rank]); + return reinterpret_cast( + reinterpret_cast(thread_shm_ptrs[rank]) + + (PER_THREAD_SHM_BUFFER_OFFSET & _thread_buffer_mask)); + } + + void next_buffer() { _thread_buffer_mask 
^= 0xFFFFFFFFFFFFFFFF; } + + char get_curr_stamp() const { return _curr_thread_stamp; } + + char get_ready_stamp() const { return _ready_thread_stamp; } + + void next_stamp() { + _mm_mfence(); + _curr_thread_stamp += 1; + } + + void commit_ready_stamp() { + _mm_mfence(); + _ready_thread_stamp = _curr_thread_stamp; } int get_swizzled_rank(int idx) { return swizzled_ranks[idx]; } - void wait_for_all(ThreadSHMStat prev_stat) { - for (int idx = 0; idx < group_size; ++idx) { + template + void wait_for_all(Cond&& cond) { + for (int idx = 1; idx < group_size; ++idx) { int rank = get_swizzled_rank(idx); - while (thread_stats[rank] == prev_stat) { - ++_spinning_count; - _mm_pause(); - } + wait_for_one(rank, std::forward(cond)); } - vec_op::mem_barrier(); } - void wait_for_one(int rank, ThreadSHMStat prev_stat) { - while (thread_stats[rank] == prev_stat) { + template + void wait_for_one(int rank, Cond&& cond) { + ThreadSHMContext* rank_ctx = shm_contexts[rank]; + for (;;) { + char local_curr_stamp = get_curr_stamp(); + char local_ready_stamp = get_ready_stamp(); + char rank_curr_stamp = rank_ctx->get_curr_stamp(); + char rank_ready_stamp = rank_ctx->get_ready_stamp(); + if (cond(local_curr_stamp, local_ready_stamp, rank_curr_stamp, + rank_ready_stamp)) { + break; + } ++_spinning_count; _mm_pause(); } - vec_op::mem_barrier(); } - void set_thread_stat(ThreadSHMStat stat) { - for (int idx = 0; idx < group_size; ++idx) { - int rank = get_swizzled_rank(idx); - shm_contexts[rank]->thread_stats[this->rank] = stat; - } + static bool check_no_buffer_conflict(char local_curr_stamp, + char local_ready_stamp, + char rank_curr_stamp, + char rank_ready_stamp) { + char temp = rank_curr_stamp + 2; + return local_curr_stamp != temp; } - void set_thread_stat(int target_rank, ThreadSHMStat stat) { - for (int idx = 0; idx < group_size; ++idx) { - int rank = get_swizzled_rank(idx); - shm_contexts[rank]->thread_stats[target_rank] = stat; - } - } - - // barrier for all ranks in the group, used for 
all2all ops - // DONE -> THREAD_READY -> SHM_DATA_READY -> DONE -> ... - void barrier(ThreadSHMStat next_stat) { - if (next_stat == ThreadSHMStat::THREAD_READY) { - set_thread_stat(ThreadSHMStat::THREAD_READY); - wait_for_all(ThreadSHMStat::DONE); - } else if (next_stat == ThreadSHMStat::SHM_DATA_READY) { - set_thread_stat(ThreadSHMStat::SHM_DATA_READY); - wait_for_all(ThreadSHMStat::THREAD_READY); - } else if (next_stat == ThreadSHMStat::DONE) { - set_thread_stat(ThreadSHMStat::DONE); - wait_for_all(ThreadSHMStat::SHM_DATA_READY); - } else { - TORCH_CHECK(false, "Invalid next_stat to barrier."); - } + static bool check_stamp_ready(char local_curr_stamp, char local_ready_stamp, + char rank_curr_stamp, char rank_ready_stamp) { + char temp = local_curr_stamp + 1; + return (local_curr_stamp == rank_ready_stamp) || (temp == rank_ready_stamp); } std::string to_string() const { @@ -164,7 +176,7 @@ class SHMManager { const int group_size) : _rank(rank), _group_size(group_size), - _thread_num(std::min(torch::get_num_threads(), MAX_THREAD_NUM)), + _thread_num(torch::get_num_threads()), _shm_names({""}), _shared_mem_ptrs({nullptr}), _shm_ctx(nullptr) { @@ -326,7 +338,8 @@ void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) { (total_units_num + thread_num - 1) / thread_num; int64_t per_unit_elem_num = MIN_THREAD_PROCESS_SIZE / sizeof(scalar_t); int64_t max_per_thread_iteration_elem_num = - PER_THREAD_SHM_BUFFER_BYTES / sizeof(scalar_t); + (PER_THREAD_SHM_BUFFER_BYTES >> 1) / + sizeof(scalar_t); // Note: double buffer int64_t per_thread_elem_num = per_unit_elem_num * per_thread_units_num; #pragma omp parallel for schedule(static, 1) @@ -336,10 +349,13 @@ void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) { int64_t curr_elem_num = std::min(max_per_thread_iteration_elem_num, end - offset); ThreadSHMContext* thread_ctx = ctx + i; + bool fast_mode = ((end - offset) <= max_per_thread_iteration_elem_num); while (curr_elem_num > 0) { - 
inner_func(thread_ctx, offset, curr_elem_num); + inner_func(thread_ctx, offset, curr_elem_num, fast_mode); + thread_ctx->next_stamp(); + thread_ctx->next_buffer(); offset += max_per_thread_iteration_elem_num; curr_elem_num = std::min(max_per_thread_iteration_elem_num, end - offset); } @@ -397,7 +413,7 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data, shm_cc_ops::shm_cc_loop( ctx, elem_num, [&](ThreadSHMContext* thread_ctx, int64_t data_offset, - int64_t data_elem_num) { + int64_t data_elem_num, bool fast_mode) { int rank = thread_ctx->rank; scalar_t* thread_shm_ptr = thread_ctx->get_thread_shm_ptr(rank); @@ -410,16 +426,17 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data, thread_ctx->get_swizzled_rank(idx + 1)); }); - thread_ctx->barrier(ThreadSHMStat::THREAD_READY); + if (!fast_mode) { + thread_ctx->wait_for_all(ThreadSHMContext::check_no_buffer_conflict); + } shm_cc_ops::memcpy_to_shm(thread_shm_ptr, thread_data_ptr, thread_data_elem_num); - - thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY); - + thread_ctx->commit_ready_stamp(); int64_t aligned_data_elem_num = (data_elem_num / vec_elem_num) * vec_elem_num; int64_t i = 0; + thread_ctx->wait_for_all(ThreadSHMContext::check_stamp_ready); #pragma GCC unroll 4 for (; i < aligned_data_elem_num; i += vec_elem_num) { vec_t local_data(thread_data_ptr + i); // load from cache @@ -447,8 +464,6 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data, reduced_data.save(thread_data_ptr + i, data_elem_num - aligned_data_elem_num); } - - thread_ctx->barrier(ThreadSHMStat::DONE); }); return; @@ -488,18 +503,18 @@ void shm_gather_impl(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num, shm_cc_ops::shm_cc_loop( ctx, elem_num, [&](ThreadSHMContext* thread_ctx, int64_t data_offset, - int64_t data_elem_num) { + int64_t data_elem_num, bool fast_mode) { int rank = thread_ctx->rank; scalar_t* thread_shm_ptr = thread_ctx->get_thread_shm_ptr(rank); - 
thread_ctx->barrier(ThreadSHMStat::THREAD_READY); - - shm_cc_ops::memcpy_to_shm(thread_shm_ptr, data + data_offset, - data_elem_num * sizeof(scalar_t)); - - thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY); + if (!fast_mode) { + thread_ctx->wait_for_all(ThreadSHMContext::check_no_buffer_conflict); + } + shm_cc_ops::memcpy(thread_shm_ptr, data + data_offset, + data_elem_num * sizeof(scalar_t)); + thread_ctx->commit_ready_stamp(); if (rank == dst) { shm_cc_ops::memcpy(outputs[rank] + data_offset, data + data_offset, data_elem_num * sizeof(scalar_t)); @@ -508,12 +523,12 @@ void shm_gather_impl(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num, scalar_t* src_ptr = thread_ctx->get_thread_shm_ptr(src_rank); // shm scalar_t* dst_ptr = outputs[src_rank] + data_offset; - shm_cc_ops::memcpy_from_shm(dst_ptr, src_ptr, - data_elem_num * sizeof(scalar_t)); + thread_ctx->wait_for_one(src_rank, + ThreadSHMContext::check_stamp_ready); + shm_cc_ops::memcpy(dst_ptr, src_ptr, + data_elem_num * sizeof(scalar_t)); } } - - thread_ctx->barrier(ThreadSHMStat::DONE); }); return; @@ -599,7 +614,7 @@ struct TensorListMeta { int8_t _padding[40]; }; -void shm_send_tensor_list_impl(ThreadSHMContext* ctx, +void shm_send_tensor_list_impl(ThreadSHMContext* ctx, int64_t dst, const std::vector& tensor_list) { CPU_KERNEL_GUARD_IN(shm_send_tensor_list_impl) std::vector tensor_list_with_metadata; @@ -620,12 +635,11 @@ void shm_send_tensor_list_impl(ThreadSHMContext* ctx, shm_cc_ops::shm_cc_loop( ctx, metadata->total_bytes, [&](ThreadSHMContext* thread_ctx, int64_t data_offset, - int64_t data_elem_num) { + int64_t data_elem_num, bool fast_mode) { int rank = thread_ctx->rank; - // Wait until the receiver set the stat to DONE - thread_ctx->wait_for_one(rank, ThreadSHMStat::SHM_DATA_READY); - int64_t curr_shm_offset = 0; + thread_ctx->wait_for_one(dst, + ThreadSHMContext::check_no_buffer_conflict); while (curr_shm_offset < data_elem_num) { MemPiece frag = metadata->get_data(data_offset + 
curr_shm_offset); frag.size = std::min(frag.size, data_elem_num - curr_shm_offset); @@ -634,8 +648,7 @@ void shm_send_tensor_list_impl(ThreadSHMContext* ctx, frag.ptr, frag.size); curr_shm_offset += frag.size; } - - thread_ctx->set_thread_stat(rank, ThreadSHMStat::SHM_DATA_READY); + thread_ctx->commit_ready_stamp(); }); } @@ -646,8 +659,7 @@ std::vector shm_recv_tensor_list_impl(ThreadSHMContext* ctx, torch::Tensor metadata_tensor = torch::empty({sizeof(TensorListMeta)}, options); - // Wait until the sender set the stat of the thread 0 to SHM_DATA_READY - ctx->wait_for_one(src, ThreadSHMStat::DONE); + ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready); shm_cc_ops::memcpy(metadata_tensor.data_ptr(), ctx->get_thread_shm_ptr(src), sizeof(TensorListMeta)); @@ -664,9 +676,8 @@ std::vector shm_recv_tensor_list_impl(ThreadSHMContext* ctx, shm_cc_ops::shm_cc_loop( ctx, metadata.total_bytes, [&](ThreadSHMContext* thread_ctx, int64_t data_offset, - int64_t data_elem_num) { - // Wait until the sender set the stat to SHM_DATA_READY - thread_ctx->wait_for_one(src, ThreadSHMStat::DONE); + int64_t data_elem_num, bool fast_mode) { + ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready); int64_t curr_shm_offset = 0; while (curr_shm_offset < data_elem_num) { MemPiece frag = metadata.get_data(data_offset + curr_shm_offset); @@ -677,8 +688,6 @@ std::vector shm_recv_tensor_list_impl(ThreadSHMContext* ctx, frag.size); curr_shm_offset += frag.size; } - - thread_ctx->set_thread_stat(src, ThreadSHMStat::DONE); }); std::vector tensor_list; @@ -756,7 +765,8 @@ void shm_send_tensor_list(int64_t handle, int64_t dst) { CPU_KERNEL_GUARD_IN(shm_send_tensor_list) shm_send_tensor_list_impl( - SHMManager::get_singleton_instance(handle)->get_shm_ctx(), tensor_list); + SHMManager::get_singleton_instance(handle)->get_shm_ctx(), dst, + tensor_list); CPU_KERNEL_GUARD_OUT(shm_send_tensor_list) } @@ -778,4 +788,4 @@ std::string join_shm_manager(int64_t handle, const std::string& name) { 
TORCH_CHECK(shm_manager); shm_manager->join(name); return shm_manager->get_shm_ctx()->to_string(); -} \ No newline at end of file +} diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 60304d229a8f5..ebfc81f858367 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -50,6 +50,27 @@ void shm_send_tensor_list(int64_t handle, std::vector shm_recv_tensor_list(int64_t handle, int64_t src); +at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, + const std::optional& bias, + bool is_vnni); + +at::Tensor convert_weight_packed(at::Tensor& weight); + +at::Tensor fused_experts_cpu( + at::Tensor& hidden_states, at::Tensor& w1, at::Tensor& w2, + at::Tensor& topk_weights, at::Tensor& topk_ids, bool inplace, + bool use_int8_w8a8, bool use_fp8_w8a16, + const std::optional& w1_scale, + const std::optional& w2_scale, + const std::optional> block_size, + const std::optional& a1_scale, + const std::optional& a2_scale, bool is_vnni); + +at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2, + at::Tensor& scales2, + const std::optional& bias, + at::ScalarType out_dtype, bool is_vnni); + TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops @@ -214,6 +235,28 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("shm_recv_tensor_list(int handle, int src) -> Tensor[](a)", &shm_recv_tensor_list); #endif + + // sgl-kernels +#if defined(__AVX512BF16__) && defined(__AVX512F__) && defined(__AVX512VNNI__) + ops.def( + "weight_packed_linear(Tensor(a0!) mat1, Tensor(a1!) mat2, Tensor(a2!)? " + "bias, bool is_vnni) -> Tensor"); + ops.impl("weight_packed_linear", torch::kCPU, &weight_packed_linear); + ops.def("convert_weight_packed(Tensor! weight) -> Tensor"); + ops.impl("convert_weight_packed", torch::kCPU, &convert_weight_packed); + ops.def( + "fused_experts_cpu(Tensor! 
hidden_states, Tensor w1, Tensor w2, Tensor " + "topk_weights, Tensor topk_ids, bool inplace, bool use_int8_w8a8, bool " + "use_fp8_w8a16, Tensor? w1_scale, Tensor? w2_scale, SymInt[]? " + "block_size, Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> " + "Tensor"); + ops.impl("fused_experts_cpu", torch::kCPU, &fused_experts_cpu); + ops.def( + "int8_scaled_mm_with_quant(Tensor mat1, Tensor mat2, Tensor scales2, " + "Tensor? bias, ScalarType out_dtype, bool is_vnni) -> Tensor"); + ops.impl("int8_scaled_mm_with_quant", torch::kCPU, + &int8_scaled_mm_with_quant); +#endif } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { diff --git a/docs/getting_started/installation/cpu.md b/docs/getting_started/installation/cpu.md index 370b854def0f3..5f2d0dbe27d34 100644 --- a/docs/getting_started/installation/cpu.md +++ b/docs/getting_started/installation/cpu.md @@ -118,6 +118,7 @@ vLLM CPU backend supports the following vLLM features: - `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads. For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node. By setting to `all`, the OpenMP threads of each rank uses all CPU cores available on the system. Default value is `auto`. - `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads for each rank. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. Default value is `0`. - `VLLM_CPU_MOE_PREPACK`: whether to use prepack for MoE layer. This will be passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). 
On unsupported CPUs, you might need to set this to `0` (False). +- `VLLM_CPU_SGL_KERNEL` (Experimental): whether to use small-batch optimized kernels for linear layer and MoE layer, especially for low-latency requirements like online serving. The kernels require AMX instruction set, BFloat16 weight type and weight shapes divisible by 32. Default is `0` (False). ## Performance tips diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py index f656f90c4bd37..7d7a62eec118a 100644 --- a/tests/models/language/generation/test_common.py +++ b/tests/models/language/generation/test_common.py @@ -78,7 +78,7 @@ AITER_MODEL_LIST = [ ), pytest.param( "Qwen/Qwen2.5-0.5B-Instruct", # qwen2 - marks=[pytest.mark.core_model], + marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), pytest.param( "Qwen/Qwen3-8B", # qwen (text-only) @@ -87,6 +87,7 @@ AITER_MODEL_LIST = [ pytest.param("bigcode/starcoder2-3b"), # starcoder2 pytest.param( "TitanML/tiny-mixtral", # mixtral + marks=[pytest.mark.core_model, pytest.mark.cpu_model], ) ]) @pytest.mark.parametrize("max_tokens", [32]) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 51900de1cc099..36a0395ccdc93 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1850,3 +1850,52 @@ def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor, torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, scale) return out + + +if hasattr(torch.ops._C, "weight_packed_linear"): + + @register_fake("_C::weight_packed_linear") + def weight_packed_linear_fake(mat1: torch.Tensor, mat2: torch.Tensor, + bias: Optional[torch.Tensor], + is_vnni: bool) -> torch.Tensor: + return torch.empty((mat1.size(0), mat2.size(0)), + dtype=mat1.dtype, + device=mat2.device) + + +if hasattr(torch.ops._C, "fused_experts_cpu"): + + @register_fake("_C::fused_experts_cpu") + def fused_experts_cpu_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: 
torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool, + use_int8_w8a8: bool, + use_fp8_w8a16: bool, + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + block_size: Optional[list[int]], + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + is_vnni: bool, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"): + + @register_fake("_C::int8_scaled_mm_with_quant") + def int8_scaled_mm_with_quant_fake( + mat1: torch.Tensor, + mat2: torch.Tensor, + scales2: torch.Tensor, + bias: Optional[torch.Tensor], + out_dtype: torch.dtype, + is_vnni: bool, + ) -> torch.Tensor: + M = mat1.size(0) + N = mat2.size(0) + return torch.empty((M, N), dtype=out_dtype) diff --git a/vllm/envs.py b/vllm/envs.py index a3f19c7ee5c70..c73dbb0a8446f 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -46,6 +46,7 @@ if TYPE_CHECKING: VLLM_CPU_OMP_THREADS_BIND: str = "" VLLM_CPU_NUM_OF_RESERVED_CPU: int = 0 VLLM_CPU_MOE_PREPACK: bool = True + VLLM_CPU_SGL_KERNEL: bool = False VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache") VLLM_XLA_CHECK_RECOMPILATION: bool = False VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024 @@ -447,6 +448,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_CPU_MOE_PREPACK": lambda: bool(int(os.getenv("VLLM_CPU_MOE_PREPACK", "1"))), + # (CPU backend only) whether to use SGL kernels, optimized for small batch. + "VLLM_CPU_SGL_KERNEL": + lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))), + # If the env var is set, then all workers will execute as separate # processes from the engine, and we use the same mechanism to trigger # execution on all workers. 
diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py new file mode 100644 index 0000000000000..68ce6bcccb5d4 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py @@ -0,0 +1,214 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Callable, Optional + +import torch + +from vllm import envs + + +class IPEXFusedMOE: + + def __init__(self, layer: torch.nn.Module) -> None: + import intel_extension_for_pytorch as ipex + layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( + layer.w13_weight, + layer.w2_weight, + use_prepack=envs.VLLM_CPU_MOE_PREPACK, + ) + + def __call__( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + assert activation == "silu", f"{activation} is not supported." 
+ assert not apply_router_weight_on_input + return layer.ipex_fusion( + x, + use_grouped_topk, + top_k, + router_logits, + renormalize, + topk_group, + num_expert_group, + custom_routing_function, + scoring_func, + e_score_correction_bias, + ) + + +class SGLFusedMOE: + + def __init__(self, layer: torch.nn.Module) -> None: + pass + + @staticmethod + def _grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor]: + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + + gating_output = gating_output.float() + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + num_token = scores.shape[0] + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. 
We use + # biased scores for expert selection but original scores for + # routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + group_scores = (scores.view(num_token, num_expert_group, + -1).topk(2, dim=-1)[0].sum(dim=-1)) + else: + group_scores = scores.view(num_token, num_expert_group, + -1).max(dim=-1).values # [n, n_group] + group_idx = torch.topk(group_scores, + k=topk_group, + dim=-1, + sorted=False)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = group_mask.unsqueeze(-1).expand( + num_token, num_expert_group, + scores.shape[-1] // num_expert_group).reshape(num_token, + -1) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), + float("-inf")) # [n, e] + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, + k=topk, + dim=-1, + sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, + keepdim=True) + + return topk_weights, topk_ids.to(torch.int32) + + @staticmethod + def _select_experts( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # DeekSeekv2 uses grouped_top_k + if use_grouped_topk: + assert topk_group is not None + assert num_expert_group is not None + topk_weights, topk_ids = SGLFusedMOE._grouped_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + 
renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + elif custom_routing_function is None: + assert scoring_func == "softmax" + topk_weights = torch.nn.functional.softmax(router_logits, + dim=1, + dtype=torch.float32) + topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1) + if renormalize: + topk_weights /= topk_weights.sum(dim=-1, keepdim=True) + topk_ids = topk_ids.to(torch.int32) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize) + + return topk_weights, topk_ids + + def __call__( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + assert activation == "silu", f"{activation} is not supported." 
+ assert not apply_router_weight_on_input + topk_weights, topk_ids = SGLFusedMOE._select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + + torch.ops._C.fused_experts_cpu( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + True, + False, + False, + None, + None, + None, + None, + None, + True, + ) + return x diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index e6f555d315d8e..d6ead084af99c 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -550,12 +550,23 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): if current_platform.is_cpu(): if current_platform.get_cpu_architecture() == CpuArchEnum.X86: - import intel_extension_for_pytorch as ipex - layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( - layer.w13_weight, - layer.w2_weight, - use_prepack=envs.VLLM_CPU_MOE_PREPACK, - ) + from vllm.model_executor.layers.fused_moe import cpu_fused_moe + dtype = layer.w13_weight.dtype + if (envs.VLLM_CPU_SGL_KERNEL + and torch._C._cpu._is_amx_tile_supported() + and dtype == torch.bfloat16): + packed_w13_weight = torch.ops._C.convert_weight_packed( + layer.w13_weight) + assert packed_w13_weight.size() == layer.w13_weight.size() + layer.w13_weight.copy_(packed_w13_weight) + del packed_w13_weight + packed_w2_weight = torch.ops._C.convert_weight_packed( + layer.w2_weight) + assert packed_w2_weight.size() == layer.w2_weight.size() + layer.w2_weight.copy_(packed_w2_weight) + layer.cpu_fused_moe = cpu_fused_moe.SGLFusedMOE(layer) + else: + layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer) else: raise NotImplementedError("CPU MOE only 
supports x86 arch.") @@ -673,13 +684,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, - activation: str = "silu", apply_router_weight_on_input: bool = False, + activation: str = "silu", **kwargs, ): - assert activation == "silu", f"{activation} is not supported." - assert apply_router_weight_on_input is False - return layer.ipex_fusion( + return layer.cpu_fused_moe( + layer, x, use_grouped_topk, top_k, @@ -687,9 +697,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): renormalize, topk_group, num_expert_group, + global_num_experts, + expert_map, custom_routing_function, scoring_func, e_score_correction_bias, + apply_router_weight_on_input, + activation, ) def forward_hpu( @@ -764,7 +778,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map=expert_map, renormalize=renormalize) - forward_native = forward_tpu if current_platform.is_tpu() else forward_cuda + if current_platform.is_tpu(): + forward_native = forward_tpu + elif current_platform.is_cpu(): + forward_native = forward_cpu + else: + forward_native = forward_cuda def determine_expert_map( diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 588aa8deb1832..a05ae0edbd775 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -9,6 +9,7 @@ import torch import torch.nn as nn from torch.nn.parameter import Parameter, UninitializedParameter +from vllm import envs from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, @@ -27,6 +28,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter, RowvLLMParameter) # yapf: enable from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform logger = init_logger(__name__) 
@@ -195,12 +197,33 @@ class UnquantizedLinearMethod(LinearMethodBase): layer.register_parameter("weight", weight) set_weight_attrs(weight, extra_weight_attrs) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if current_platform.is_cpu() and envs.VLLM_CPU_SGL_KERNEL: + N, K = layer.weight.size() + dtype = layer.weight.dtype + if (torch._C._cpu._is_amx_tile_supported() + and dtype == torch.bfloat16 and N % 32 == 0 + and K % 32 == 0): + packed_weight = torch.ops._C.convert_weight_packed( + layer.weight) + assert packed_weight.size() == layer.weight.size() + layer.weight.copy_(packed_weight) + if layer.bias is not None: + layer.bias = Parameter(layer.bias.to(torch.float32), + requires_grad=False) + layer.use_cpu_sgl = True + else: + logger.warning( + "CPU SGL kernels require Intel AMX support," + " bfloat16 weight, IC and OC are divisible by 32.") + layer.use_cpu_sgl = False + def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return dispatch_unquantized_gemm()(x, layer.weight, bias) + return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) class LinearBase(torch.nn.Module): diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 41b5253dca048..ad4ba9c0b827a 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -63,7 +63,15 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, return logits -def rocm_unquantized_gemm(x: torch.Tensor, +def default_unquantized_gemm(layer: torch.nn.Module, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None): + return torch.nn.functional.linear(x, weight, bias) + + +def rocm_unquantized_gemm(layer: torch.nn.Module, + x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None): from vllm.platforms.rocm import on_gfx9 @@ -89,7 +97,20 @@ def rocm_unquantized_gemm(x: torch.Tensor, return 
torch.nn.functional.linear(x, weight, bias) +def cpu_unquantized_gemm(layer: torch.nn.Module, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None): + if getattr(layer, "use_cpu_sgl", False): + return torch.ops._C.weight_packed_linear(x, weight, bias, True) + else: + return torch.nn.functional.linear(x, weight, bias) + + def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]: if current_platform.is_rocm(): return rocm_unquantized_gemm - return torch.nn.functional.linear + elif current_platform.is_cpu(): + return cpu_unquantized_gemm + else: + return default_unquantized_gemm diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 9ff3a7a7327d9..f35f969781bd1 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -43,7 +43,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return dispatch_unquantized_gemm()(x, layer.weight, bias) + return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor: diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 106bce162003f..dccd60f4463aa 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -194,6 +194,8 @@ class CpuPlatform(Platform): "epilogue_fusion": True, }) + if compilation_config.use_inductor: + compilation_config.custom_ops = ["none"] if vllm_config.lora_config is not None: compilation_config.level = CompilationLevel.NO_COMPILATION From 08d81f1014d174d4dd96518914c7ed9767c67a3f Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Tue, 1 Jul 2025 03:29:08 -0400 Subject: [PATCH 025/195] [Bugfix] Fix deepep tests (#20288) Signed-off-by: Varun Sundar Rabindranath Co-authored-by: Varun Sundar Rabindranath --- 
tests/kernels/moe/test_deepep_deepgemm_moe.py | 2 +- tests/kernels/moe/test_deepep_moe.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 475427f439289..008406c3f1593 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -30,7 +30,7 @@ if has_deep_ep(): from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 DeepEPLLPrepareAndFinalize) - from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a + from .utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a if has_deep_gemm(): import deep_gemm diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py index 80a36dc39712a..94947c809e3a3 100644 --- a/tests/kernels/moe/test_deepep_moe.py +++ b/tests/kernels/moe/test_deepep_moe.py @@ -31,7 +31,7 @@ if has_deep_ep(): from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 DeepEPLLPrepareAndFinalize) - from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a + from .utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a requires_deep_ep = pytest.mark.skipif( not has_deep_ep(), From b1c1fe35a599cfd3c0404702c65c2381b025bc6a Mon Sep 17 00:00:00 2001 From: Kebe Date: Tue, 1 Jul 2025 15:33:22 +0800 Subject: [PATCH 026/195] [Misc] remove redundant char (#20287) Signed-off-by: Kebe --- benchmarks/benchmark_serving.py | 2 +- vllm/benchmarks/serve.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 886a51e1cbd9a..9b235266dff1a 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -551,7 +551,7 @@ async def benchmark( "total_input_tokens": metrics.total_input, "total_output_tokens": metrics.total_output, "request_throughput": metrics.request_throughput, - 
"request_goodput:": metrics.request_goodput if goodput_config_dict else None, + "request_goodput": metrics.request_goodput if goodput_config_dict else None, "output_throughput": metrics.output_throughput, "total_token_throughput": metrics.total_token_throughput, "input_lens": [output.prompt_len for output in outputs], diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index 419284cca042e..8b16fea9e3d3c 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -498,7 +498,7 @@ async def benchmark( "total_input_tokens": metrics.total_input, "total_output_tokens": metrics.total_output, "request_throughput": metrics.request_throughput, - "request_goodput:": + "request_goodput": metrics.request_goodput if goodput_config_dict else None, "output_throughput": metrics.output_throughput, "total_token_throughput": metrics.total_token_throughput, From 96453cfa831340788ef72c42bc2a1a2b4496a27f Mon Sep 17 00:00:00 2001 From: TY-AMD Date: Tue, 1 Jul 2025 16:12:19 +0800 Subject: [PATCH 027/195] [BugFix][V1][ROCm] Triton MLA uses V0 backend on V1 engine (#19067) Signed-off-by: Tianyuan Wu --- .../attention/test_attention_selector.py | 6 +- .../attention/test_rocm_attention_selector.py | 6 +- vllm/platforms/rocm.py | 10 +++- vllm/v1/attention/backends/mla/common.py | 9 ++- vllm/v1/attention/backends/mla/triton_mla.py | 57 +++++++++++++++++++ 5 files changed, 78 insertions(+), 10 deletions(-) diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index f3e64155703c2..a8ed749ba13b5 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -106,10 +106,8 @@ def test_env( block_size, False, use_mla=use_mla) - if use_v1 and name != "TRITON_MLA": - assert backend.get_name() == f"{name}_VLLM_V1" - else: - assert backend.get_name() == name + expected = f"{name}_VLLM_V1" if use_v1 else name + assert backend.get_name() == expected else: 
with pytest.raises(ValueError) as exc_info: get_attn_backend(16, diff --git a/tests/kernels/attention/test_rocm_attention_selector.py b/tests/kernels/attention/test_rocm_attention_selector.py index ed58880cc9e6c..34311b9ccd767 100644 --- a/tests/kernels/attention/test_rocm_attention_selector.py +++ b/tests/kernels/attention/test_rocm_attention_selector.py @@ -35,7 +35,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA") backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, False, True) - assert backend.get_name() == "TRITON_MLA" + assert (backend.get_name() == "TRITON_MLA" + or backend.get_name() == "TRITON_MLA_VLLM_V1") # If attention backend is None # If use_mla is true @@ -43,7 +44,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): m.setenv(STR_BACKEND_ENV_VAR, None) backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, False, True) - assert backend.get_name() == "TRITON_MLA" + assert (backend.get_name() == "TRITON_MLA" + or backend.get_name() == "TRITON_MLA_VLLM_V1") # change the attention backend to AITER MLA m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA") diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 08d471d5a983c..ee53a76ceb6db 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -186,8 +186,14 @@ class RocmPlatform(Platform): if selected_backend == _Backend.TRITON_MLA: if block_size != 1: - logger.info("Using Triton MLA backend.") - return "vllm.attention.backends.triton_mla.TritonMLABackend" # noqa: E501 + if use_v1: + logger.info_once( + "Using Triton MLA backend on V1 engine.") + return ("vllm.v1.attention.backends.mla." 
+ "triton_mla.TritonMLABackend") + else: + logger.info("Using Triton MLA backend.") + return "vllm.attention.backends.triton_mla.TritonMLABackend" # noqa: E501 else: raise ValueError( f" The selected backend, {selected_backend.name}," diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 1878ae74dbc6f..d45ec04472a69 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -640,7 +640,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): self.qk_head_dim = qk_head_dim self.v_head_dim = v_head_dim self.kv_b_proj = kv_b_proj - self.vllm_flash_attn_version = get_flash_attn_version() # Handle the differences between the flash_attn_varlen from flash_attn # and the one from vllm_flash_attn. The former is used on RoCM and the @@ -672,11 +671,17 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): maybe_padded_v = torch.nn.functional.pad( v, [0, q.shape[-1] - v.shape[-1]], value=0) + if is_vllm_fa: + kwargs["return_softmax_lse"] = return_softmax_lse + else: + # ROCm leverages the upstream flash_attn, which takes a parameter + # called "return_attn_probs" instead of return_softmax_lse + kwargs["return_attn_probs"] = return_softmax_lse + attn_out = self.flash_attn_varlen_func( q=q, k=k, v=maybe_padded_v, - return_softmax_lse=return_softmax_lse, softmax_scale=softmax_scale, **kwargs, ) diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index e26d7909184b5..99938f22f108c 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -5,10 +5,14 @@ from typing import Any, Optional import torch +from vllm import envs from vllm.attention.backends.abstract import (AttentionType, is_quantized_kv_cache) from vllm.attention.ops.triton_decode_attention import decode_attention_fwd +from vllm.attention.ops.triton_flash_attention import triton_attention from vllm.logger import 
init_logger +from vllm.platforms import current_platform +from vllm.triton_utils import HAS_TRITON from vllm.v1.attention.backends.mla.common import (MLACommonBackend, MLACommonImpl, MLACommonMetadata) @@ -68,6 +72,59 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): raise NotImplementedError( "TritonMLA V1 with FP8 KV cache not yet supported") + self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN + self.triton_fa_func = triton_attention if HAS_TRITON else None + + def _flash_attn_varlen_diff_headdims_rocm(self, + q, + k, + v, + softmax_scale=None, + **kwargs): + assert self.triton_fa_func is not None + + # Triton Attention requires a padded V + padded_v = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], + value=0) + # The output of triton_attention is a tuple of + # [output_tensor, encoded_softmax] where encoded_softmax is always None + output_tensor, _ = self.triton_fa_func( + q, + k, + padded_v, + None, # output + kwargs["cu_seqlens_q"], + kwargs["cu_seqlens_k"], + kwargs["max_seqlen_q"], + kwargs["max_seqlen_k"], + kwargs["causal"], + softmax_scale, + None, # bias + ) + + return output_tensor + + def _flash_attn_varlen_diff_headdims(self, + q, + k, + v, + return_softmax_lse=False, + softmax_scale=None, + **kwargs): + if current_platform.is_rocm() \ + and self.use_triton_flash_attn \ + and not return_softmax_lse: + return self._flash_attn_varlen_diff_headdims_rocm( + q, k, v, softmax_scale=softmax_scale, **kwargs) + else: + return super()._flash_attn_varlen_diff_headdims( + q, + k, + v, + return_softmax_lse=return_softmax_lse, + softmax_scale=softmax_scale, + **kwargs) + def _forward_decode( self, q_nope: torch.Tensor, From 787b13389e2c0b114074f0a0d715eeb6c0a2b0c5 Mon Sep 17 00:00:00 2001 From: Reid <61492567+reidliu41@users.noreply.github.com> Date: Tue, 1 Jul 2025 16:18:09 +0800 Subject: [PATCH 028/195] [doc] fix the incorrect logo in dark mode (#20289) Signed-off-by: reidliu41 --- docs/README.md | 3 ++- 
docs/mkdocs/stylesheets/extra.css | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/docs/README.md b/docs/README.md index 9fb3137b31928..e1d1046951a59 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,7 +1,8 @@ # Welcome to vLLM
- ![](./assets/logos/vllm-logo-text-light.png){ align="center" alt="vLLM" class="no-scaled-link" width="60%" } + ![](./assets/logos/vllm-logo-text-light.png){ align="center" alt="vLLM Light" class="logo-light" width="60%" } + ![](./assets/logos/vllm-logo-text-dark.png){ align="center" alt="vLLM Dark" class="logo-dark" width="60%" }

diff --git a/docs/mkdocs/stylesheets/extra.css b/docs/mkdocs/stylesheets/extra.css index 248711f491b9d..892013c1cddfa 100644 --- a/docs/mkdocs/stylesheets/extra.css +++ b/docs/mkdocs/stylesheets/extra.css @@ -134,3 +134,12 @@ body[data-md-color-scheme="slate"] .md-nav__item--section > label.md-nav__link . opacity: 0.9; transform: translateY(2px); } + +/* For logo css */ +[data-md-color-scheme="default"] .logo-dark { + display: none; +} + +[data-md-color-scheme="slate"] .logo-light { + display: none; +} From c05596f1a350f3d993c467959ed02492141c2527 Mon Sep 17 00:00:00 2001 From: Lionel Villard Date: Tue, 1 Jul 2025 05:10:28 -0400 Subject: [PATCH 029/195] [Perf] Validate @config in pre-commit instead of dynamically (#20200) Signed-off-by: Lionel Villard --- .pre-commit-config.yaml | 7 ++ tests/test_config.py | 35 +----- tests/tools/__init__.py | 0 tests/tools/test_config_validator.py | 49 +++++++++ tools/validate_config.py | 158 +++++++++++++++++++++++++++ vllm/config.py | 28 +---- 6 files changed, 220 insertions(+), 57 deletions(-) create mode 100644 tests/tools/__init__.py create mode 100644 tests/tools/test_config_validator.py create mode 100644 tools/validate_config.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 15ef5defff69e..d962252eb3dd8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -160,6 +160,13 @@ repos: types: [python] pass_filenames: false additional_dependencies: [pathspec, regex] + - id: validate-config + name: Validate configuration has default values and that each field has a docstring + entry: python tools/validate_config.py + language: python + types: [python] + pass_filenames: true + files: vllm/config.py|tests/test_config.py # Keep `suggestion` last - id: suggestion name: Suggestion diff --git a/tests/test_config.py b/tests/test_config.py index 5d5c4453d30d2..cb7654c26afc8 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -2,49 +2,16 @@ # SPDX-FileCopyrightText: Copyright contributors 
to the vLLM project from dataclasses import MISSING, Field, asdict, dataclass, field -from typing import Literal, Union import pytest from vllm.compilation.backends import VllmBackend from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig, - config, get_field) + get_field) from vllm.model_executor.layers.pooler import PoolingType from vllm.platforms import current_platform -class _TestConfig1: - pass - - -@dataclass -class _TestConfig2: - a: int - """docstring""" - - -@dataclass -class _TestConfig3: - a: int = 1 - - -@dataclass -class _TestConfig4: - a: Union[Literal[1], Literal[2]] = 1 - """docstring""" - - -@pytest.mark.parametrize(("test_config", "expected_error"), [ - (_TestConfig1, "must be a dataclass"), - (_TestConfig2, "must have a default"), - (_TestConfig3, "must have a docstring"), - (_TestConfig4, "must use a single Literal"), -]) -def test_config(test_config, expected_error): - with pytest.raises(Exception, match=expected_error): - config(test_config) - - def test_compile_config_repr_succeeds(): # setup: VllmBackend mutates the config object config = VllmConfig() diff --git a/tests/tools/__init__.py b/tests/tools/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/tools/test_config_validator.py b/tests/tools/test_config_validator.py new file mode 100644 index 0000000000000..b0475894a114e --- /dev/null +++ b/tests/tools/test_config_validator.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import ast + +import pytest + +from tools.validate_config import validate_ast + +_TestConfig1 = ''' +@config +class _TestConfig1: + pass +''' + +_TestConfig2 = ''' +@config +@dataclass +class _TestConfig2: + a: int + """docstring""" +''' + +_TestConfig3 = ''' +@config +@dataclass +class _TestConfig3: + a: int = 1 +''' + +_TestConfig4 = ''' +@config +@dataclass +class _TestConfig4: + a: Union[Literal[1], Literal[2]] = 1 + 
"""docstring""" +''' + + +@pytest.mark.parametrize(("test_config", "expected_error"), [ + (_TestConfig1, "must be a dataclass"), + (_TestConfig2, "must have a default"), + (_TestConfig3, "must have a docstring"), + (_TestConfig4, "must use a single Literal"), +]) +def test_config(test_config, expected_error): + tree = ast.parse(test_config) + with pytest.raises(Exception, match=expected_error): + validate_ast(tree) diff --git a/tools/validate_config.py b/tools/validate_config.py new file mode 100644 index 0000000000000..8b1e955c653d7 --- /dev/null +++ b/tools/validate_config.py @@ -0,0 +1,158 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Ensures all fields in a config dataclass have default values +and that each field has a docstring. +""" + +import ast +import inspect +import sys + + +def get_attr_docs(cls_node: ast.ClassDef) -> dict[str, str]: + """ + Get any docstrings placed after attribute assignments in a class body. + + Adapted from https://davidism.com/attribute-docstrings/ + https://davidism.com/mit-license/ + """ + + def pairwise(iterable): + """ + Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise + + Can be removed when Python 3.9 support is dropped. + """ + iterator = iter(iterable) + a = next(iterator, None) + + for b in iterator: + yield a, b + a = b + + out = {} + + # Consider each pair of nodes. + for a, b in pairwise(cls_node.body): + # Must be an assignment then a constant string. + if (not isinstance(a, (ast.Assign, ast.AnnAssign)) + or not isinstance(b, ast.Expr) + or not isinstance(b.value, ast.Constant) + or not isinstance(b.value.value, str)): + continue + + doc = inspect.cleandoc(b.value.value) + + # An assignment can have multiple targets (a = b = v), but an + # annotated assignment only has one target. + targets = a.targets if isinstance(a, ast.Assign) else [a.target] + + for target in targets: + # Must be assigning to a plain name. 
+ if not isinstance(target, ast.Name): + continue + + out[target.id] = doc + + return out + + +class ConfigValidator(ast.NodeVisitor): + + def __init__(self): + ... + + def visit_ClassDef(self, node): + # Validate class with both @config and @dataclass decorators + decorators = [ + id for d in node.decorator_list if (isinstance(d, ast.Name) and ( + (id := d.id) == 'config' or id == 'dataclass')) or + (isinstance(d, ast.Call) and (isinstance(d.func, ast.Name) and + (id := d.func.id) == 'dataclass')) + ] + + if set(decorators) == {'config', 'dataclass'}: + validate_class(node) + elif set(decorators) == {'config'}: + fail( + f"Class {node.name} with config decorator must be a dataclass.", + node) + + self.generic_visit(node) + + +def validate_class(class_node: ast.ClassDef): + attr_docs = get_attr_docs(class_node) + + for stmt in class_node.body: + # A field is defined as a class variable that has a type annotation. + if isinstance(stmt, ast.AnnAssign): + # Skip ClassVar + # see https://docs.python.org/3/library/dataclasses.html#class-variables + if isinstance(stmt.annotation, ast.Subscript) and isinstance( + stmt.annotation.value, + ast.Name) and stmt.annotation.value.id == "ClassVar": + continue + + if isinstance(stmt.target, ast.Name): + field_name = stmt.target.id + if stmt.value is None: + fail( + f"Field '{field_name}' in {class_node.name} must have " + "a default value.", stmt) + + if field_name not in attr_docs: + fail( + f"Field '{field_name}' in {class_node.name} must have " + "a docstring.", stmt) + + if isinstance(stmt.annotation, ast.Subscript) and \ + isinstance(stmt.annotation.value, ast.Name) \ + and stmt.annotation.value.id == "Union" and \ + isinstance(stmt.annotation.slice, ast.Tuple): + args = stmt.annotation.slice.elts + literal_args = [ + arg for arg in args + if isinstance(arg, ast.Subscript) and isinstance( + arg.value, ast.Name) and arg.value.id == "Literal" + ] + if len(literal_args) > 1: + fail( + f"Field '{field_name}' in {class_node.name} 
must " + "use a single " + "Literal type. Please use 'Literal[Literal1, " + "Literal2]' instead of 'Union[Literal1, Literal2]'" + ".", stmt) + + +def validate_ast(tree: ast.stmt): + ConfigValidator().visit(tree) + + +def validate_file(file_path: str): + try: + print(f"validating {file_path} config dataclasses ", end="") + with open(file_path, encoding="utf-8") as f: + source = f.read() + + tree = ast.parse(source, filename=file_path) + validate_ast(tree) + except ValueError as e: + print(e) + SystemExit(2) + else: + print("✅") + + +def fail(message: str, node: ast.stmt): + raise ValueError(f"❌ line({node.lineno}): {message}") + + +def main(): + for filename in sys.argv[1:]: + validate_file(filename) + + +if __name__ == "__main__": + main() diff --git a/vllm/config.py b/vllm/config.py index 46a5bf34f66e4..6412e6e293b45 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -18,7 +18,7 @@ from functools import cached_property from importlib.util import find_spec from pathlib import Path from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional, - Protocol, TypeVar, Union, cast, get_args, get_origin) + Protocol, TypeVar, Union, cast, get_args) import regex as re import torch @@ -193,28 +193,10 @@ def config(cls: ConfigT) -> ConfigT: (i.e. `ConfigT(**json.loads(cli_arg))`). However, if a particular `ConfigT` requires custom construction from CLI (i.e. `CompilationConfig`), it can have a `from_cli` method, which will be called instead. + + Config validation is performed by the tools/validate_config.py + script, which is invoked during the pre-commit checks. """ - if not is_dataclass(cls): - raise TypeError("The decorated class must be a dataclass.") - attr_docs = get_attr_docs(cls) - for f in fields(cls): - if f.init and f.default is MISSING and f.default_factory is MISSING: - raise ValueError( - f"Field '{f.name}' in {cls.__name__} must have a default value." 
- ) - - if f.name not in attr_docs: - raise ValueError( - f"Field '{f.name}' in {cls.__name__} must have a docstring.") - - if get_origin(f.type) is Union: - args = get_args(f.type) - literal_args = [arg for arg in args if get_origin(arg) is Literal] - if len(literal_args) > 1: - raise ValueError( - f"Field '{f.name}' in {cls.__name__} must use a single " - "Literal type. Please use 'Literal[Literal1, Literal2]' " - "instead of 'Union[Literal1, Literal2]'.") return cls @@ -1798,7 +1780,7 @@ class ParallelConfig: eplb_step_interval: int = 3000 """ Interval for rearranging experts in expert parallelism. - + Note that if this is greater than the EPLB window size, only the metrics of the last `eplb_window_size` steps will be used for rearranging experts. """ From 9025a9a7050253678431b2c20e6dd0be55f0dcc2 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 1 Jul 2025 06:20:34 -0400 Subject: [PATCH 030/195] [Quant] [Bugfix] Fix quantization config matching with `hf_to_vllm_mapper` (#20046) --- .../test_register_quantization_config.py | 1 + vllm/lora/models.py | 2 +- vllm/lora/worker_manager.py | 5 +--- .../layers/quantization/base_config.py | 13 +++++++++++ .../layers/quantization/bitblas.py | 1 + .../compressed_tensors/compressed_tensors.py | 17 +++++++++++++- .../model_executor/layers/quantization/fp8.py | 10 +++++++- .../layers/quantization/gptq_bitblas.py | 1 + .../layers/quantization/marlin.py | 2 ++ .../layers/quantization/modelopt.py | 1 + .../layers/quantization/torchao.py | 1 + vllm/model_executor/model_loader/utils.py | 22 ++++++++++-------- vllm/model_executor/models/interfaces.py | 23 ++++++++++++++++--- vllm/model_executor/models/qwen2_5_vl.py | 14 +++++------ vllm/model_executor/models/transformers.py | 1 + vllm/model_executor/models/utils.py | 15 +++++++++++- vllm/model_executor/utils.py | 7 ++++-- 17 files changed, 107 insertions(+), 29 deletions(-) diff --git a/tests/quantization/test_register_quantization_config.py 
b/tests/quantization/test_register_quantization_config.py index 42081a8c68cdc..6c541fdbeeae2 100644 --- a/tests/quantization/test_register_quantization_config.py +++ b/tests/quantization/test_register_quantization_config.py @@ -53,6 +53,7 @@ class CustomQuantConfig(QuantizationConfig): def __init__(self, num_bits: int = 8) -> None: """Initialize the quantization config.""" + super().__init__() self.num_bits = num_bits def get_name(self) -> QuantizationMethods: diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 262e6799583ae..9e1ed3a771798 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -805,7 +805,7 @@ def create_lora_manager( lora_manager_cls: type[LoRAModelManager] = LoRAModelManager, **kwargs) -> LoRAModelManager: """Create a LoRA adapter for a given model.""" - if not hasattr(model, "packed_modules_mapping"): + if not isinstance(model, SupportsLoRA): raise ValueError(f"Model {type(model)} is not supported for LoRA.") lora_manager = lora_manager_cls( model=model, diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 7da44569f4086..7a4af74cbeb12 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -111,10 +111,7 @@ class WorkerLoRAManager(AbstractWorkerManager): # For some models like Qwen2VL, we need to use hf_to_vllm_mapper # to ensure correct loading of lora weights. 
model = self._adapter_manager.model - hf_to_vllm_mapper = None - if (hasattr(model, "hf_to_vllm_mapper") - and model.hf_to_vllm_mapper is not None): - hf_to_vllm_mapper = model.hf_to_vllm_mapper + hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None) lora = self._lora_model_cls.from_local_checkpoint( lora_path, diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index 78c5c75c06515..4a43351260e9f 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -10,6 +10,7 @@ from torch import nn if TYPE_CHECKING: from vllm.model_executor.layers.quantization import QuantizationMethods + from vllm.model_executor.models.utils import WeightsMapper else: QuantizationMethods = str @@ -149,3 +150,15 @@ class QuantizationConfig(ABC): def get_cache_scale(self, name: str) -> Optional[str]: return None + + def apply_vllm_mapper( # noqa: B027 + self, hf_to_vllm_mapper: "WeightsMapper"): + """ + Interface for models to update module names referenced in + quantization configs in order to reflect the vllm model structure + + :param hf_to_vllm_mapper: maps from hf model structure (the assumed + structure of the qconfig) to vllm model structure + """ + # TODO (@kylesayrs): add implementations for all subclasses + pass diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py index 9e5ce39ec8f2e..aa8eee88a9f9e 100644 --- a/vllm/model_executor/layers/quantization/bitblas.py +++ b/vllm/model_executor/layers/quantization/bitblas.py @@ -63,6 +63,7 @@ class BitBLASConfig(QuantizationConfig): # (since we have only one group per output channel) desc_act = False + super().__init__() self.weight_bits = weight_bits self.group_size = group_size self.desc_act = desc_act diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py 
b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 4f87b2a44f0ac..e7f65d13181d8 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from contextlib import suppress -from typing import Any, Literal, Optional, cast +from typing import TYPE_CHECKING, Any, Literal, Optional, cast import torch from compressed_tensors.config import (CompressionFormat, @@ -37,6 +37,9 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import cutlass_fp4_supported) from vllm.platforms import current_platform +if TYPE_CHECKING: + from vllm.model_executor.models.utils import WeightsMapper + logger = init_logger(__name__) __all__ = ["CompressedTensorsLinearMethod"] @@ -80,6 +83,18 @@ class CompressedTensorsConfig(QuantizationConfig): def get_name(self) -> QuantizationMethods: return "compressed-tensors" + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): + self.target_scheme_map = hf_to_vllm_mapper.apply_dict( + self.target_scheme_map) + self.ignore = hf_to_vllm_mapper.apply_list(self.ignore) + self.sparsity_scheme_map = hf_to_vllm_mapper.apply_dict( + self.sparsity_scheme_map) + self.sparsity_ignore_list = hf_to_vllm_mapper.apply_list( + self.sparsity_ignore_list) + if self.kv_cache_scheme is not None: + self.kv_cache_scheme = hf_to_vllm_mapper.apply_dict( + self.kv_cache_scheme) + def get_quant_method( self, layer: torch.nn.Module, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 93472207fbb86..60df679a74bda 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools -from typing import 
Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch import torch.nn.functional as F @@ -39,6 +39,9 @@ from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.utils import has_deep_gemm +if TYPE_CHECKING: + from vllm.model_executor.models.utils import WeightsMapper + ACTIVATION_SCHEMES = ["static", "dynamic"] logger = init_logger(__name__) @@ -100,6 +103,11 @@ class Fp8Config(QuantizationConfig): def get_config_filenames(cls) -> list[str]: return [] + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): + if self.ignored_layers is not None: + self.ignored_layers = hf_to_vllm_mapper.apply_list( + self.ignored_layers) + @classmethod def from_config(cls, config: dict[str, Any]) -> "Fp8Config": quant_method = cls.get_from_keys(config, ["quant_method"]) diff --git a/vllm/model_executor/layers/quantization/gptq_bitblas.py b/vllm/model_executor/layers/quantization/gptq_bitblas.py index 78e0f59fa4bee..caeb266d0b933 100644 --- a/vllm/model_executor/layers/quantization/gptq_bitblas.py +++ b/vllm/model_executor/layers/quantization/gptq_bitblas.py @@ -81,6 +81,7 @@ class GPTQBitBLASConfig(QuantizationConfig): # (since we have only one group per output channel) desc_act = False + super().__init__() self.weight_bits = weight_bits self.group_size = group_size self.desc_act = desc_act diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 62667db26b669..18d1c13373df9 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -32,6 +32,8 @@ class MarlinConfig(QuantizationConfig): group_size: int, lm_head_quantized: bool, ) -> None: + super().__init__() + # Group size for the quantization. 
self.group_size = group_size self.lm_head_quantized = lm_head_quantized diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index e35db5b31dba7..a10911b84afc4 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -181,6 +181,7 @@ class ModelOptNvFp4Config(QuantizationConfig): exclude_modules: list[str], group_size: int = 16, ) -> None: + super().__init__() self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized if is_checkpoint_nvfp4_serialized: logger.warning( diff --git a/vllm/model_executor/layers/quantization/torchao.py b/vllm/model_executor/layers/quantization/torchao.py index a4e0356c02689..63b2ab6bab063 100644 --- a/vllm/model_executor/layers/quantization/torchao.py +++ b/vllm/model_executor/layers/quantization/torchao.py @@ -55,6 +55,7 @@ class TorchAOConfig(QuantizationConfig): os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1" logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1") """ + super().__init__() self.torchao_config = torchao_config self.skip_modules = skip_modules or [] diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 79e6fa7b16dc7..159e7b1e6b01a 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -24,6 +24,7 @@ from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.adapters import (as_classification_model, as_embedding_model, as_reward_model) +from vllm.model_executor.models.interfaces import SupportsQuant from vllm.utils import is_pin_memory_available logger = init_logger(__name__) @@ -294,13 +295,16 @@ def configure_quant_config(quant_config: QuantizationConfig, Note that model attributes are passed by reference to quant_config, enabling them to be updated by model_class.__new__ (ex. 
chatglm, qwen) + + Once the `SupportsQuant` mixin has been added to all models, this + function can be removed """ - packed_mapping = getattr(model_class, "packed_modules_mapping", None) - if packed_mapping is not None: - # pass packed_modules_mapping by reference to quant_config - quant_config.packed_modules_mapping = packed_mapping - else: - logger.warning( - "The model class %s has not defined `packed_modules_mapping`, " - "this may lead to incorrect mapping of quantized or ignored " - "modules", model_class.__name__) + if not issubclass(model_class, SupportsQuant): + hf_to_vllm_mapper = getattr(model_class, "hf_to_vllm_mapper", None) + packed_mapping = getattr(model_class, "packed_modules_mapping", None) + + # pass mappings by reference to quant_config + if hf_to_vllm_mapper is not None: + quant_config.apply_vllm_mapper(hf_to_vllm_mapper) + if packed_mapping is not None: + quant_config.packed_modules_mapping = packed_mapping diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index ad59fe79edcb1..d234632ef1b75 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -18,6 +18,7 @@ from .interfaces_base import is_pooling_model if TYPE_CHECKING: from vllm.attention import AttentionMetadata + from vllm.model_executor.models.utils import WeightsMapper from vllm.sequence import IntermediateTensors logger = init_logger(__name__) @@ -566,20 +567,36 @@ def has_step_pooler(model: Union[type[object], object]) -> bool: class SupportsQuant: """The interface required for all models that support quantization.""" - packed_modules_mapping: ClassVar[dict[str, list[str]]] = {} + hf_to_vllm_mapper: ClassVar[Optional["WeightsMapper"]] = None + packed_modules_mapping: ClassVar[Optional[dict[str, list[str]]]] = None quant_config: Optional[QuantizationConfig] = None def __new__(cls, *args, **kwargs) -> Self: instance = super().__new__(cls) + + # find config passed in arguments quant_config = 
cls._find_quant_config(*args, **kwargs) if quant_config is not None: + + # attach config to model for general use instance.quant_config = quant_config - instance.quant_config.packed_modules_mapping.update( - cls.packed_modules_mapping) + + # apply model mappings to config for proper config-model matching + # NOTE: `TransformersForCausalLM` is not supported due to how this + # class defines `hf_to_vllm_mapper` as a post-init `@property`. + # After this is fixed, get `instance.hf_to_vllm_mapper` directly + if getattr(instance, "hf_to_vllm_mapper", None) is not None: + instance.quant_config.apply_vllm_mapper( + instance.hf_to_vllm_mapper) + if getattr(instance, "packed_modules_mapping", None) is not None: + instance.quant_config.packed_modules_mapping.update( + instance.packed_modules_mapping) + return instance @staticmethod def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]: + """Find quant config passed through model constructor args""" from vllm.config import VllmConfig # avoid circular import args_values = list(args) + list(kwargs.values()) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index ff53a2775e3d4..1b64b61a1e5cf 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -61,7 +61,7 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) + SupportsMultiModal, SupportsPP, SupportsQuant) from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo, apply_rotary_pos_emb_vision) @@ -821,7 +821,8 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor): info=Qwen2_5_VLProcessingInfo, dummy_inputs=Qwen2_5_VLDummyInputsBuilder) class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, - 
SupportsLoRA, SupportsPP): + SupportsLoRA, SupportsPP, + SupportsQuant): # To ensure correct weight loading and mapping. hf_to_vllm_mapper = WeightsMapper( @@ -837,7 +838,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config @@ -846,7 +846,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, self.visual = Qwen2_5_VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=self._maybe_ignore_quant_config(quant_config), + quant_config=self._maybe_ignore_quant_config(self.quant_config), prefix=maybe_prefix(prefix, "visual"), ) @@ -859,12 +859,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + def _maybe_ignore_quant_config(self, config: Optional[QuantizationConfig]): # GPTQ configs do not have a list of ignored modules, however AutoGPTQ # seems to avoid vision encoder sections for some models. 
- if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): + if isinstance(config, (GPTQConfig, GPTQMarlinConfig)): return None - return quant_config + return config def _validate_and_reshape_mm_tensor(self, mm_input: object, name: str) -> torch.Tensor: diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 2f78d9d4cc065..04ee3a454f9d8 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -467,6 +467,7 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA, # FIXME(Isotr0py): Don't use any weights mapper for Transformers backend, # this makes thing complicated. We need to remove this mapper after refactor # `TransformersModel` in the future. + # NOTE: `SupportsQuant` can be updated after property decorator is removed @property def hf_to_vllm_mapper(self): prefix_mapper = { diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index aa88f42101605..62deb68035b92 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -4,7 +4,7 @@ import itertools from collections.abc import Iterable, Mapping from dataclasses import dataclass, field -from typing import Callable, Literal, Optional, Protocol, Union, overload +from typing import Any, Callable, Literal, Optional, Protocol, Union, overload import torch import torch.nn as nn @@ -64,6 +64,19 @@ class WeightsMapper: return ((out_name, data) for name, data in weights if (out_name := self._map_name(name)) is not None) + def apply_list(self, values: list[str]) -> list[str]: + return [ + out_name for name in values + if (out_name := self._map_name(name)) is not None + ] + + def apply_dict(self, values: dict[str, Any]) -> dict[str, Any]: + return { + out_name: value + for name, value in values.items() + if (out_name := self._map_name(name)) is not None + } + class AutoWeightsLoader: """ diff --git a/vllm/model_executor/utils.py 
b/vllm/model_executor/utils.py index cbaa34bfc30b2..2b20ca2a3ba3f 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -58,7 +58,8 @@ def _make_synced_weight_loader(original_weight_loader): def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]: - parent_map = copy.deepcopy(getattr(model, "packed_modules_mapping", {})) + parent_map = getattr(model, "packed_modules_mapping", None) + parent_map = copy.deepcopy(parent_map) if parent_map is not None else {} # don't infer mapping if the model has defined it explicitly. if parent_map: @@ -66,7 +67,9 @@ def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]: # We only check main components instead of whole model submodules for child in model.children(): - child_map = getattr(child, "packed_modules_mapping", {}) + child_map = getattr(child, "packed_modules_mapping", None) + child_map = copy.deepcopy(child_map) if child_map is not None else {} + if any((k in parent_map and parent_map[k] != v) for k, v in child_map.items()): raise ValueError( From 650d5dbd04e92f5043a11e4a4d86d4f39ee1b694 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?= Date: Tue, 1 Jul 2025 13:40:14 +0200 Subject: [PATCH 031/195] [Misc] Minor refactor of NIXL background handshake (#20068) Signed-off-by: NickLucche --- .../kv_connector/v1/nixl_connector.py | 60 ++++++++++--------- 1 file changed, 32 insertions(+), 28 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 7a077dce7706c..56ae1acf8571f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -515,6 +515,33 @@ class NixlConnectorWorker: # Remote rank -> agent name. 
return {p_remote_rank: handshake(path, p_remote_rank)} + def _background_nixl_handshake(self, req_id: str, + remote_engine_id: EngineId, meta: ReqMeta): + # Do NIXL handshake in background and add to _ready_requests when done. + fut = self._handshake_futures.get(remote_engine_id) + if fut is None: + fut = self._handshake_initiation_executor.submit( + self._nixl_handshake, meta.remote_host, meta.remote_port, + meta.tp_size) + self._handshake_futures[remote_engine_id] = fut + + def done_callback(f: Future[dict[int, str]], eid=remote_engine_id): + with self._handshake_lock: + del self._handshake_futures[eid] + try: + self._remote_agents[eid] = f.result() + except Exception: + logger.exception("Handshake with %s failed", eid) + + fut.add_done_callback(done_callback) + + # TODO: handle failure state of future in the + # callback, we want to fail the request in this case. + def request_ready(_f: Future[Any], entry=(req_id, meta)): + self._ready_requests.put(entry) + + fut.add_done_callback(request_ready) + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" @@ -902,37 +929,14 @@ class NixlConnectorWorker: remote_engine_id, len(meta.local_block_ids), len(meta.remote_block_ids)) if remote_engine_id not in self._remote_agents: - # Being optimistic to assume engine is usually ready, apply - # lock only when the optimistic check fails. + # Initiate handshake with remote engine to exchange metadata. 
with self._handshake_lock: if remote_engine_id not in self._remote_agents: - fut = self._handshake_futures.get(remote_engine_id) - if fut is None: - fut = self._handshake_initiation_executor.submit( - self._nixl_handshake, meta.remote_host, - meta.remote_port, meta.tp_size) - self._handshake_futures[remote_engine_id] = fut - - def done_callback(f: Future[dict[int, str]], - eid=remote_engine_id): - with self._handshake_lock: - del self._handshake_futures[eid] - try: - self._remote_agents[eid] = f.result() - except Exception: - logger.exception( - "Handshake with %s failed", eid) - - fut.add_done_callback(done_callback) - - # TODO: handle failure state of future in the - # callback, we want to fail the request in this case. - def request_ready(_f: Future[Any], - entry=(req_id, meta)): - self._ready_requests.put(entry) - - fut.add_done_callback(request_ready) + self._background_nixl_handshake( + req_id, remote_engine_id, meta) continue + + # Handshake already completed, start async read xfer. self._read_blocks_for_req(req_id, meta) # Start transfers for requests whose handshakes have now finished. 
From ed70f3c64f684750edea087e286cbf264e7cc3f3 Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Tue, 1 Jul 2025 20:48:26 +0800 Subject: [PATCH 032/195] Add GLM4.1V model (Draft) (#19331) Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com> Signed-off-by: Isotr0py Co-authored-by: Isotr0py --- docs/models/supported_models.md | 3 +- examples/offline_inference/vision_language.py | 40 +- tests/entrypoints/openai/test_video.py | 2 +- .../multimodal/generation/test_common.py | 28 + .../generation/vlm_utils/custom_inputs.py | 20 + .../generation/vlm_utils/model_utils.py | 24 + .../multimodal/processing/test_common.py | 24 + tests/models/registry.py | 1 + tests/multimodal/test_utils.py | 4 +- vllm/assets/video.py | 26 +- vllm/entrypoints/chat_utils.py | 4 + .../model_executor/layers/rotary_embedding.py | 119 ++ vllm/model_executor/models/glm4_1v.py | 1589 +++++++++++++++++ vllm/model_executor/models/registry.py | 1 + vllm/multimodal/inputs.py | 8 +- vllm/multimodal/parse.py | 42 +- vllm/multimodal/video.py | 27 +- 17 files changed, 1946 insertions(+), 16 deletions(-) create mode 100644 vllm/model_executor/models/glm4_1v.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 0248700292ae2..db650b37a38db 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -553,6 +553,7 @@ Specified using `--task generate`. | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b` etc. | | ✅︎ | ✅︎ | | `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ | | `GLM4VForCausalLM`^ | GLM-4V | T + I | `THUDM/glm-4v-9b`, `THUDM/cogagent-9b-20241220` etc. | ✅︎ | ✅︎ | ✅︎ | +| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + IE+ + VE+ | `THUDM/GLM-4.1V-9B-Thinking`, etc. 
| ✅︎ | ✅︎ | ✅︎ | | `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ | | `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎\* | | `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3` etc. | ✅︎ | | ✅︎ | @@ -583,7 +584,7 @@ Specified using `--task generate`. | `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ | | `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ | | `TarsierForConditionalGeneration` | Tarsier | T + IE+ | `omni-search/Tarsier-7b`,`omni-search/Tarsier-34b` | | ✅︎ | ✅︎ | -| `Tarsier2ForConditionalGeneration`^ | Tarsier2 | T + IE+ + VE+ | `omni-research/Tarsier2-Recap-7b`,`omni-research/Tarsier2-7b-0115` | | ✅︎ | ✅︎ | +| `Tarsier2ForConditionalGeneration`^ | Tarsier2 | T + IE+ + VE+ | `omni-research/Tarsier2-Recap-7b`,`omni-research/Tarsier2-7b-0115` | | ✅︎ | ✅︎ | ^ You need to set the architecture name via `--hf-overrides` to match the one in vLLM.     
• For example, to use DeepSeek-VL2 series models: diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 57b042ed013b1..b9e8bef26eb24 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -248,6 +248,42 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData: ) +# GLM-4.1V +def run_glm4_1v(questions: list[str], modality: str) -> ModelRequestData: + model_name = "THUDM/GLM-4.1V-9B-Thinking" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=2, + mm_processor_kwargs={ + "size": {"shortest_edge": 12544, "longest_edge": 47040000}, + "fps": 1, + }, + limit_mm_per_prompt={modality: 1}, + enforce_eager=True, + ) + + if modality == "image": + placeholder = "<|begin_of_image|><|image|><|end_of_image|>" + elif modality == "video": + placeholder = "<|begin_of_video|><|video|><|end_of_video|>" + + prompts = [ + ( + "[gMASK]<|system|>\nYou are a helpful assistant.<|user|>\n" + f"{placeholder}" + f"{question}<|assistant|>assistant\n" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # H2OVL-Mississippi def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1114,6 +1150,7 @@ model_example_map = { "fuyu": run_fuyu, "gemma3": run_gemma3, "glm4v": run_glm4v, + "glm4_1v": run_glm4_1v, "h2ovl_chat": run_h2ovl, "idefics3": run_idefics3, "internvl_chat": run_internvl, @@ -1172,10 +1209,11 @@ def get_multi_modal_input(args): if args.modality == "video": # Input video and question video = VideoAsset(name="baby_reading", num_frames=args.num_frames).np_ndarrays + metadata = VideoAsset(name="baby_reading", num_frames=args.num_frames).metadata vid_questions = ["Why is this video funny?"] return { - "data": video, + "data": [(video, metadata)] if args.model_type == "glm4_1v" else video, "questions": 
vid_questions, } diff --git a/tests/entrypoints/openai/test_video.py b/tests/entrypoints/openai/test_video.py index 990ea3579291d..b68e08556ee96 100644 --- a/tests/entrypoints/openai/test_video.py +++ b/tests/entrypoints/openai/test_video.py @@ -50,7 +50,7 @@ async def client(server): @pytest.fixture(scope="session") def base64_encoded_video() -> dict[str, str]: return { - video_url: encode_video_base64(fetch_video(video_url)) + video_url: encode_video_base64(fetch_video(video_url)[0]) for video_url in TEST_VIDEO_URLS } diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index 9d63339737ce6..6ecf6db56cb39 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -309,6 +309,34 @@ VLM_TEST_SETTINGS = { num_logprobs=10, marks=[large_gpu_mark(min_gb=32)], ), + "glm4_1v": VLMTestInfo( + models=["THUDM/GLM-4.1V-9B-Thinking"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|assistant|>", # noqa: E501 + img_idx_to_prompt=lambda idx: "<|begin_of_image|><|image|><|end_of_image|>", # noqa: E501 + video_idx_to_prompt=lambda idx: "<|begin_of_video|><|video|><|end_of_video|>", # noqa: E501 + max_model_len=2048, + max_num_seqs=2, + get_stop_token_ids=lambda tok: [151329, 151336, 151338], + num_logprobs=10, + image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], + auto_cls=AutoModelForImageTextToText, + ), + "glm4_1v-video": VLMTestInfo( + models=["THUDM/GLM-4.1V-9B-Thinking"], + # GLM4.1V requires video metadata to be included in the input + test_type=VLMTestType.CUSTOM_INPUTS, + max_model_len=4096, + max_num_seqs=2, + auto_cls=AutoModelForImageTextToText, + patch_hf_runner=model_utils.glm4_1v_patch_hf_runner, + custom_test_opts=[CustomTestOptions( + inputs=custom_inputs.video_with_metadata_glm4_1v(), + limit_mm_per_prompt={"video": 1}, + )], + # This is needed to run on a 
machine with 24GB VRAM + vllm_runner_kwargs={"gpu_memory_utilization": 0.95}, + ), "h2ovl": VLMTestInfo( models = [ "h2oai/h2ovl-mississippi-800m", diff --git a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py index aa5835243e042..c53243b42e384 100644 --- a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py +++ b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py @@ -129,3 +129,23 @@ def windows_attention_image_qwen2_5_vl(): wrapped_sf = ImageSizeWrapper(type=SizeType.SIZE_FACTOR, data=[0.5]) return build_single_image_inputs([image], [prompt], wrapped_sf) + + +def video_with_metadata_glm4_1v(): + video_array = VIDEO_ASSETS[0].np_ndarrays + metadata = VIDEO_ASSETS[0].metadata + question = "Describe the video." + video_prompt = "<|begin_of_video|><|video|><|end_of_video|>" + formatted_prompt = f"<|user|>\n{video_prompt}{question}<|assistant|>\n" + + scales = [0.1, 0.2, 0.25] + video_input = [[(rescale_video_size(video_array, scale), metadata)] + for scale in scales] + prompts = [formatted_prompt] * len(video_input) + + return [ + PromptWithMultiModalInput( + prompts=prompts, + video_data=video_input, + ) + ] diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py index af4c72f44b676..c1a2aa0dcafbb 100644 --- a/tests/models/multimodal/generation/vlm_utils/model_utils.py +++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py @@ -16,9 +16,11 @@ import torch from PIL.Image import Image from transformers import (AutoConfig, AutoTokenizer, BatchFeature, GenerationConfig, GenerationMixin) +from transformers.video_utils import VideoMetadata from vllm.sequence import SampleLogprobs from vllm.transformers_utils.tokenizer import patch_padding_side +from vllm.utils import is_list_of from .....conftest import HfRunner, ImageAsset, ImageTestAssets from .types import RunnerOutput @@ -373,6 
+375,28 @@ def glm4v_patch_hf_runner(hf_model: HfRunner) -> HfRunner: return hf_model +def glm4_1v_patch_hf_runner(hf_model: HfRunner) -> HfRunner: + """Patches and returns an instance of the HfRunner to use for GLM4.1V.""" + hf_processor = hf_model.processor + + def processor(*args, videos=None, **kwargs): + if videos is not None and is_list_of(videos, tuple): + # If videos is a list of tuples, we assume each tuple contains + # (video_array, metadata) as in the case of GLM4.1V. + video_metadata = [[VideoMetadata(**video[1])] for video in videos] + videos = [[video[0]] for video in videos] + else: + video_metadata = None + + return hf_processor(*args, + videos=videos, + video_metadata=video_metadata, + **kwargs) + + hf_model.processor = processor + return hf_model + + def h2ovl_patch_hf_runner(hf_model: HfRunner) -> HfRunner: """Patches and returns an instance of the HfRunner to use for H2OVL.""" diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 1ba60178c13db..0f33225eda2da 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -24,6 +24,22 @@ from ....multimodal.utils import random_audio, random_image, random_video from ...registry import HF_EXAMPLE_MODELS +def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict: + """ + Patch the multimodal data for GLM4.1V model. 
+ """ + # Ensure video metadata is included + if "video" in mm_data: + video = mm_data["video"] + mm_data["video"] = (video, { + "total_num_frames": len(video), + "fps": len(video), + "duration": 1, + "video_backend": "opencv" + }) + return mm_data + + def _test_processing_correctness( model_id: str, hit_rate: float, @@ -154,6 +170,11 @@ _IGNORE_MM_KEYS = { "ultravox": {"audio_features"}, } +MM_DATA_PATCHES = { + # GLM4.1V requires video metadata to be included in the input + "glm4v": glm4_1v_patch_mm_data, +} + def _test_processing_correctness_one( model_config: ModelConfig, @@ -166,6 +187,8 @@ def _test_processing_correctness_one( ): model_type = model_config.hf_config.model_type ignore_mm_keys = _IGNORE_MM_KEYS.get(model_type, set[str]()) + if model_type in MM_DATA_PATCHES: + mm_data = MM_DATA_PATCHES[model_type](mm_data) if isinstance(prompt, str): text_prompt = prompt @@ -245,6 +268,7 @@ def _test_processing_correctness_one( "adept/fuyu-8b", "google/gemma-3-4b-it", "THUDM/glm-4v-9b", + "THUDM/GLM-4.1V-9B-Thinking", "ibm-granite/granite-speech-3.3-2b", "h2oai/h2ovl-mississippi-800m", "OpenGVLab/InternVL2-1B", diff --git a/tests/models/registry.py b/tests/models/registry.py index e56dd19bec670..affe2e88b2b94 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -338,6 +338,7 @@ _MULTIMODAL_EXAMPLE_MODELS = { "GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b", trust_remote_code=True, hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501 + "Glm4vForConditionalGeneration": _HfExamplesInfo("THUDM/GLM-4.1V-9B-Thinking", min_transformers_version="4.53"), # noqa: E501 "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m", extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501 max_transformers_version="4.48", # noqa: E501 diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index 5ac0a90f50473..a48542cec3f87 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -172,7 
+172,9 @@ async def test_fetch_video_http(video_url: str, num_frames: int): video_sync = connector.fetch_video(video_url, num_frames=num_frames) video_async = await connector.fetch_video_async(video_url, num_frames=num_frames) - assert np.array_equal(video_sync, video_async) + # Check that the video frames are equal and metadata are same + assert np.array_equal(video_sync[0], video_async[0]) + assert video_sync[1] == video_async[1] # Used for the next two tests related to `merge_and_sort_multimodal_metadata`. diff --git a/vllm/assets/video.py b/vllm/assets/video.py index 01834aeeb6c12..16412121cf0a8 100644 --- a/vllm/assets/video.py +++ b/vllm/assets/video.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from functools import lru_cache -from typing import ClassVar, Literal, Optional +from typing import Any, ClassVar, Literal, Optional import cv2 import numpy as np @@ -77,6 +77,24 @@ def video_to_pil_images_list(path: str, ] +def video_get_metadata(path: str) -> dict[str, Any]: + cap = cv2.VideoCapture(path) + if not cap.isOpened(): + raise ValueError(f"Could not open video file {path}") + + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + fps = cap.get(cv2.CAP_PROP_FPS) + duration = total_frames / fps if fps > 0 else 0 + + metadata = { + "total_num_frames": total_frames, + "fps": fps, + "duration": duration, + "video_backend": "opencv" + } + return metadata + + VideoAssetName = Literal["baby_reading"] @@ -105,6 +123,12 @@ class VideoAsset: ret = video_to_ndarrays(video_path, self.num_frames) return ret + @property + def metadata(self) -> dict[str, Any]: + video_path = download_video_asset(self.filename) + ret = video_get_metadata(video_path) + return ret + def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray: """ Read audio data from the video asset, used in Qwen2.5-Omni examples. 
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 35ee52ab4601d..45f1894d022b3 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -515,6 +515,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): if modality in ("image", "image_embeds"): if model_type == "chatglm": return "<|begin_of_image|><|endoftext|><|end_of_image|>" + if model_type == "glm4v": + return "<|begin_of_image|><|image|><|end_of_image|>" if model_type in ("phi3_v", "phi4mm"): return f"<|image_{current_count}|>" if model_type in ("minicpmo", "minicpmv"): @@ -563,6 +565,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): elif modality == "video": if model_type == "internvl_chat": return "