From 38217877aa70041c0115ee367b75197af9cbc5ad Mon Sep 17 00:00:00 2001 From: "rongfu.leng" Date: Wed, 20 Aug 2025 21:34:49 +0800 Subject: [PATCH 01/57] [Fix] fix offline env use local mode path (#22526) Signed-off-by: rongfu.leng --- .../offline_mode/test_offline_mode.py | 35 +++++++++++++++++++ vllm/engine/arg_utils.py | 10 +++++- vllm/transformers_utils/config.py | 23 ++++++++++-- 3 files changed, 65 insertions(+), 3 deletions(-) diff --git a/tests/entrypoints/offline_mode/test_offline_mode.py b/tests/entrypoints/offline_mode/test_offline_mode.py index a606eeab5887e..dd8d63ad319ac 100644 --- a/tests/entrypoints/offline_mode/test_offline_mode.py +++ b/tests/entrypoints/offline_mode/test_offline_mode.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for HF_HUB_OFFLINE mode""" +import dataclasses import importlib import sys @@ -9,6 +10,7 @@ import urllib3 from vllm import LLM from vllm.distributed import cleanup_dist_env_and_memory +from vllm.engine.arg_utils import EngineArgs MODEL_CONFIGS = [ { @@ -108,3 +110,36 @@ def _re_import_modules(): # Error this test if reloading a module failed if reload_exception is not None: raise reload_exception + + +@pytest.mark.skip_global_cleanup +@pytest.mark.usefixtures("cache_models") +def test_model_from_huggingface_offline(monkeypatch: pytest.MonkeyPatch): + # Set HF to offline mode and ensure we can still construct an LLM + with monkeypatch.context() as m: + try: + m.setenv("HF_HUB_OFFLINE", "1") + m.setenv("VLLM_NO_USAGE_STATS", "1") + + def disable_connect(*args, **kwargs): + raise RuntimeError("No http calls allowed") + + m.setattr( + urllib3.connection.HTTPConnection, + "connect", + disable_connect, + ) + m.setattr( + urllib3.connection.HTTPSConnection, + "connect", + disable_connect, + ) + # Need to re-import huggingface_hub + # and friends to setup offline mode + _re_import_modules() + engine_args = EngineArgs(model="facebook/opt-125m") + LLM(**dataclasses.asdict(engine_args)) + finally: + # Reset the environment after the test + # NB: Assuming tests are run in online mode + _re_import_modules() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 679905aed9ec8..48d9cd08af030 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -15,6 +15,7 @@ from typing import (TYPE_CHECKING, Annotated, Any, Callable, Dict, List, Literal, Optional, Type, TypeVar, Union, cast, get_args, get_origin) +import huggingface_hub import regex as re import torch from pydantic import TypeAdapter, ValidationError @@ -39,7 +40,7 @@ from vllm.plugins import load_general_plugins from vllm.ray.lazy_utils import is_ray_initialized from vllm.reasoning import ReasoningParserManager from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 -from vllm.transformers_utils.config import is_interleaved +from vllm.transformers_utils.config import get_model_path, is_interleaved from vllm.transformers_utils.utils import check_gguf_file from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser, GiB_bytes, get_ip, is_in_ray_actor) @@ -457,6 +458,13 @@ class EngineArgs: # Setup plugins from vllm.plugins import load_general_plugins load_general_plugins() + # when use hf offline,replace model id to local model path + if huggingface_hub.constants.HF_HUB_OFFLINE: + model_id = self.model + self.model = get_model_path(self.model, self.revision) + logger.info( + "HF_HUB_OFFLINE is True, replace model_id [%s] " \ + "to model_path [%s]",model_id, self.model) @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index d8c964fb2a4a4..fe345bd8f0a2e 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -14,7 +14,7 @@ from huggingface_hub import get_safetensors_metadata, hf_hub_download from huggingface_hub import list_repo_files as hf_list_repo_files from huggingface_hub import try_to_load_from_cache from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, - HFValidationError, LocalEntryNotFoundError, + LocalEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError) from transformers import GenerationConfig, PretrainedConfig @@ -335,6 +335,7 @@ def maybe_override_with_speculators_target_model( gguf_model_repo = Path(model).parent else: gguf_model_repo = None + kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE config_dict, _ = PretrainedConfig.get_config_dict( model if gguf_model_repo is None else gguf_model_repo, revision=revision, @@ -400,6 +401,7 @@ def get_config( raise ValueError(error_message) from e if config_format == ConfigFormat.HF: + kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE config_dict, _ = PretrainedConfig.get_config_dict( model, revision=revision, @@ -532,7 +534,7 @@ def try_get_local_file(model: Union[str, Path], revision=revision) if isinstance(cached_filepath, str): return Path(cached_filepath) - except HFValidationError: + except ValueError: ... return None @@ -908,3 +910,20 @@ def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int: exc_info=e) return max_position_embeddings + + +def get_model_path(model: Union[str, Path], revision: Optional[str] = None): + if os.path.exists(model): + return model + assert huggingface_hub.constants.HF_HUB_OFFLINE + common_kwargs = { + "local_files_only": huggingface_hub.constants.HF_HUB_OFFLINE, + "revision": revision, + } + + if envs.VLLM_USE_MODELSCOPE: + from modelscope.hub.snapshot_download import snapshot_download + return snapshot_download(model_id=model, **common_kwargs) + + from huggingface_hub import snapshot_download + return snapshot_download(repo_id=model, **common_kwargs) From 44492358439f612b3934ccd902dbd90fcfa19866 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 20 Aug 2025 22:19:30 +0800 Subject: [PATCH 02/57] [Bugfix] Ensure correctness of HCXVision processing (#23254) Signed-off-by: DarkLight1337 --- .../multimodal/processing/test_common.py | 2 +- .../models/hyperclovax_vision.py | 116 ++++++++---------- 2 files changed, 55 insertions(+), 63 deletions(-) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index d5b1de834a618..02aecfad8281d 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -102,7 +102,7 @@ def _test_processing_correctness( partial(random_video, rng, min_frames=2, - max_frames=8, + max_frames=16, min_wh=128, max_wh=256), "audio": diff --git a/vllm/model_executor/models/hyperclovax_vision.py b/vllm/model_executor/models/hyperclovax_vision.py index d3ddc47ea932f..f8b30d8d98e5c 100644 --- a/vllm/model_executor/models/hyperclovax_vision.py +++ b/vllm/model_executor/models/hyperclovax_vision.py @@ -53,6 +53,21 @@ IMAGE_TOKEN: str = "<|dummy3|>" VIDEO_TOKEN: str = "<|_unuse_missing_100270|>" +# Based on combine_frames_into_images in +# https://huggingface.co/naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B/blob/main/processing_hyperclovax.py +def get_num_combined_frames( + num_frames: int, + max_grid_shape: tuple[int, int] = (3, 3), +) -> int: + max_num_grids = max_grid_shape[0] * max_grid_shape[1] + + # Calculate the number of canvases needed. + num_canvases = num_frames // max_num_grids + leftover_frames = num_frames % max_num_grids + + return num_canvases + (leftover_frames > 0) + + class HCXVisionMultimodalPixelInputs(TypedDict): type: Literal["pixel_values"] pixel_values_images: list[torch.Tensor] @@ -172,23 +187,20 @@ class HCXVisionMultiModalProcessor( def replace_multimodal_token( token_ids: torch.Tensor, target_token: int, - repeats: list, + repeats: list[int], ): - output = list() + output = list[int]() _repeats_idx = 0 for token_id in token_ids: if token_id == target_token: - output += [ - token_id.item(), - ] * repeats[_repeats_idx] + output += [token_id.item()] * repeats[_repeats_idx] _repeats_idx += 1 else: - output += [ - token_id.item(), - ] + output += [token_id.item()] + return torch.tensor(output, device=token_ids.device) - for video_idx, video_arr in enumerate(mm_data.get("videos", list())): + for video_idx, video_arr in enumerate(mm_data.get("videos", [])): if video_arr.dtype == np.uint8: continue mm_data["videos"][video_idx] = video_arr.astype(np.uint8) @@ -205,88 +217,68 @@ class HCXVisionMultiModalProcessor( if len(mm_data) > 0: # batchify input as a single item images = mm_data.get("images", None) - num_images = 0 - if images is not None: - num_images = len(images) - images = [ - images, - ] # batchify + batched_images = None if images is None else [images] - videos = mm_data.get("videos", - None) # list of video in single conversation - num_videos = 0 - if videos is not None: - num_videos = len(videos) - videos = [ - videos, - ] # batchify + # list of video in single conversation + videos = mm_data.get("videos", None) + batched_videos = None if videos is None else [videos] _processed_outputs = self.info.ctx.call_hf_processor( hf_processor=self.info.get_hf_processor(**mm_kwargs), data=dict( text=None, - images=images, - videos=videos, + images=batched_images, + videos=batched_videos, ), ) # mm-only for k, v in _processed_outputs.items(): - if len(v) < 1: - continue - elif k.endswith("_images"): - # list of list of 4D tensor -> list of 4D tensor + if isinstance(v, list) and len(v) > 0: + assert len(v) == 1 _processed_outputs[k] = v[0] - elif k.endswith("_videos"): - # list of list of 4D tensor -> list of 4D tensor - v = v[0] - if k == "pixel_values_videos": - v = torch.cat(v, dim=0) - _c, _w, _h = v.shape[-3:] - v = v.reshape(num_videos, -1, _c, _w, _h) - v = list(torch.unbind(v, dim=0)) - _processed_outputs[k] = v - if num_images > 0: + if images: tokenizer = self.info.get_tokenizer() + image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) processed_outputs["input_ids"] = torch.stack([ replace_multimodal_token( token_ids=_input_ids, - target_token=tokenizer.convert_tokens_to_ids( - IMAGE_TOKEN), + target_token=image_token_id, repeats=_processed_outputs[ "vision_query_lengths_images"], ) for _input_ids in processed_outputs["input_ids"] ], dim=0) - if num_videos > 0: - tokenizer = self.info.get_tokenizer() - processed_outputs["input_ids"] = torch.stack([ - replace_multimodal_token( - token_ids=_input_ids, - target_token=tokenizer.convert_tokens_to_ids( - VIDEO_TOKEN), - repeats=_processed_outputs[ - "vision_query_lengths_videos"], - ) for _input_ids in processed_outputs["input_ids"] - ], - dim=0) - - _ratios = [ - len(_pixel_values) for _pixel_values in - _processed_outputs["pixel_values_videos"] - ] + if videos: _num_per_videos = [ - int(_e / sum(_ratios) * - len(_processed_outputs["vision_query_lengths_videos"])) - for _e in _ratios + get_num_combined_frames(len(video)) for video in videos + ] + _processed_outputs["pixel_values_videos"] = [ + _processed_outputs["pixel_values_videos"] + [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])] + for _i in range(len(videos)) ] _processed_outputs["vision_query_lengths_videos"] = [ _processed_outputs["vision_query_lengths_videos"] [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])] - for _i in range(0, num_videos) + for _i in range(len(videos)) ] + tokenizer = self.info.get_tokenizer() + video_token_id = tokenizer.convert_tokens_to_ids(VIDEO_TOKEN) + processed_outputs["input_ids"] = torch.stack([ + replace_multimodal_token( + token_ids=_input_ids, + target_token=video_token_id, + repeats=[ + sum(lens) for lens in + _processed_outputs["vision_query_lengths_videos"] + ], + ) for _input_ids in processed_outputs["input_ids"] + ], + dim=0) + processed_outputs.update(_processed_outputs) return processed_outputs From b17109beeafbf9577c319ab61530810943a7fc4b Mon Sep 17 00:00:00 2001 From: shixianc <49539556+shixianc@users.noreply.github.com> Date: Wed, 20 Aug 2025 07:35:26 -0700 Subject: [PATCH 03/57] [Kernel] CUTLASS MoE FP8: Integrate cuda moe permute/unpermute (#23045) Signed-off-by: Shixian Cui --- .../kernels/benchmark_grouped_gemm_cutlass.py | 35 +++- csrc/moe/moe_permute_unpermute_op.cu | 33 ++-- csrc/ops.h | 5 + .../cutlass_w8a8/moe/get_group_starts.cuh | 6 +- .../quantization/cutlass_w8a8/moe/moe_data.cu | 65 +++++-- .../cutlass_w8a8/scaled_mm_entry.cu | 24 +++ csrc/torch_bindings.cpp | 13 ++ tests/kernels/moe/test_cutlass_moe.py | 18 +- .../kernels/moe/test_moe_permute_unpermute.py | 6 +- tests/kernels/moe/test_pplx_cutlass_moe.py | 22 ++- .../quantization/test_cutlass_scaled_mm.py | 2 +- vllm/_custom_ops.py | 22 +++ .../layers/fused_moe/cutlass_moe.py | 179 +++++++++++------- .../layers/fused_moe/moe_permute_unpermute.py | 29 ++- .../compressed_tensors_moe.py | 31 +++ 15 files changed, 369 insertions(+), 121 deletions(-) diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index 1d4e730f99ae9..a6b42406b5cb0 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -80,6 +80,11 @@ def bench_run( a, score, topk, renormalize=False ) + ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64) + def run_triton_moe( a: torch.Tensor, w1: torch.Tensor, @@ -111,6 +116,10 @@ def bench_run( w2: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, per_act_token: bool, @@ -125,6 +134,10 @@ def bench_run( topk_ids, w1_scale, w2_scale, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, per_act_token, a1_scale=None, ) @@ -136,6 +149,10 @@ def bench_run( w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, ): @@ -150,6 +167,10 @@ def bench_run( topk_ids, w1_scale, w2_scale, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, per_act_token, a1_scale=None, ) @@ -194,6 +215,10 @@ def bench_run( w2_q, w1_scale, w2_scale, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, topk_weights, topk_ids, ) @@ -231,6 +256,10 @@ def bench_run( "w1_scale": w1_scale, "w2_scale": w2_scale, "per_act_token": per_act_token, + "ab_strides1": ab_strides1, + "ab_strides2": ab_strides2, + "c_strides1": c_strides1, + "c_strides2": c_strides2, # cuda graph params "cutlass_graph": cutlass_graph, "triton_graph": triton_graph, @@ -289,6 +318,10 @@ def bench_run( w2_q, w1_scale, w2_scale, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, topk_weights, topk_ids, per_act_token, @@ -297,7 +330,7 @@ def bench_run( results.append( benchmark.Timer( - stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501 + stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu index 2922352a3f7cc..ca0c873f49d9f 100644 --- a/csrc/moe/moe_permute_unpermute_op.cu +++ b/csrc/moe/moe_permute_unpermute_op.cu @@ -45,8 +45,6 @@ void moe_permute( auto copy_topk_ids = topk_ids.clone(); // copy topk_ids for preprocess auto permuted_experts_id = torch::empty_like(topk_ids); auto sorted_row_idx = torch::empty_like(inv_permuted_idx); - auto align_expert_first_token_offset = - torch::zeros_like(expert_first_token_offset); CubKeyValueSorter sorter{}; int64_t* valid_num_ptr = nullptr; @@ -85,12 +83,14 @@ void moe_permute( }); // get m_indices and update expert_first_token_offset with align block - getMIndices(get_ptr(expert_first_token_offset), - get_ptr(align_expert_first_token_offset), - get_ptr(m_indices), n_local_expert, align_block_size_value, - stream); + // this is only required for DeepGemm and not required for CUTLASS group gemm if (align_block_size.has_value()) { - // update align_expert_first_token_offset + auto align_expert_first_token_offset = + torch::zeros_like(expert_first_token_offset); + getMIndices(get_ptr(expert_first_token_offset), + get_ptr(align_expert_first_token_offset), + get_ptr(m_indices), n_local_expert, align_block_size_value, + stream); expert_first_token_offset.copy_(align_expert_first_token_offset); } } @@ -195,19 +195,14 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights, torch::Tensor& expert_first_token_offset, torch::Tensor& src_row_id2dst_row_id_map, torch::Tensor& m_indices) { - TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0"); + TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0"); } -void moe_unpermute(const torch::Tensor& input, - const torch::Tensor& topk_weights, torch::Tensor& topk_ids, - const torch::Tensor& token_expert_indices, - const std::optional& expert_map, - int64_t n_expert, int64_t n_local_expert, int64_t topk, - const std::optional& align_block_size, - torch::Tensor& permuted_input, - torch::Tensor& expert_first_token_offset, - torch::Tensor& src_row_id2dst_row_id_map, - torch::Tensor& m_indices) { +void moe_unpermute( + const torch::Tensor& permuted_hidden_states, + const torch::Tensor& topk_weights, const torch::Tensor& inv_permuted_idx, + const std::optional& expert_first_token_offset, int64_t topk, + torch::Tensor& hidden_states) { TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0"); } @@ -224,4 +219,4 @@ bool moe_permute_unpermute_supported() { TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { m.impl("moe_permute", &moe_permute); m.impl("moe_unpermute", &moe_unpermute); -} +} \ No newline at end of file diff --git a/csrc/ops.h b/csrc/ops.h index 64bcec6ca1527..86fe848e2fd5a 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -229,6 +229,11 @@ void get_cutlass_moe_mm_data( const int64_t num_experts, const int64_t n, const int64_t k, const std::optional& blockscale_offsets); +void get_cutlass_moe_mm_problem_sizes( + const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, + const int64_t k, const std::optional& blockscale_offsets); + void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, diff --git a/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh b/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh index 6c6e89790847f..15bb2c300543c 100644 --- a/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh +++ b/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh @@ -10,7 +10,7 @@ template __global__ void get_group_gemm_starts( - int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets, + int64_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets, ElementC** out_offsets, ElementAccumulator** a_scales_offsets, ElementAccumulator** b_scales_offsets, ElementAB* a_base_as_int, ElementAB* b_base_as_int, ElementC* out_base_as_int, @@ -34,7 +34,7 @@ __global__ void get_group_gemm_starts( else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ get_group_gemm_starts \ <<<1, num_experts, 0, stream>>>( \ - static_cast(expert_offsets.data_ptr()), \ + static_cast(expert_offsets.data_ptr()), \ static_cast(a_ptrs.data_ptr()), \ static_cast(b_ptrs.data_ptr()), \ static_cast(out_ptrs.data_ptr()), \ @@ -61,6 +61,8 @@ void run_get_group_gemm_starts( TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + // expect int64_t to avoid overflow during offset calculations + TORCH_CHECK(expert_offsets.dtype() == torch::kInt64); int num_experts = static_cast(expert_offsets.size(0)); bool per_act_token = a_scales.numel() != 1; diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu index 100f485084444..49cafcc32adc6 100644 --- a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu +++ b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu @@ -104,6 +104,53 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids, } } +namespace { +inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids, + torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, + torch::Tensor& atomic_buffer, + int64_t num_experts, int64_t n, + int64_t k, cudaStream_t stream, + const bool swap_ab) { + int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); + + const int32_t* topk_ptr = static_cast(topk_ids.data_ptr()); + int32_t* ps1_ptr = static_cast(problem_sizes1.data_ptr()); + int32_t* ps2_ptr = static_cast(problem_sizes2.data_ptr()); + int32_t* atomic_ptr = static_cast(atomic_buffer.data_ptr()); + + if (swap_ab) { + compute_problem_sizes<<>>( + topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr, + static_cast(topk_ids.numel()), static_cast(n), + static_cast(k)); + } else { + compute_problem_sizes<<>>( + topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr, + static_cast(topk_ids.numel()), static_cast(n), + static_cast(k)); + } +} +} // namespace + +void get_cutlass_moe_mm_problem_sizes_caller( + const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, + const int64_t k, const std::optional& blockscale_offsets) { + auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index()); + auto options_int32 = + torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device()); + torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); + + // Swap-AB should be disabled for FP4 path + bool may_swap_ab = (!blockscale_offsets.has_value()) && + (topk_ids.numel() <= SWAP_AB_THRESHOLD); + + launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2, + atomic_buffer, num_experts, n, k, stream, + may_swap_ab); +} + void get_cutlass_moe_mm_data_caller( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, @@ -121,21 +168,9 @@ void get_cutlass_moe_mm_data_caller( bool may_swap_ab = (!blockscale_offsets.has_value()) && (topk_ids.numel() <= SWAP_AB_THRESHOLD); - if (may_swap_ab) { - compute_problem_sizes<<>>( - static_cast(topk_ids.data_ptr()), - static_cast(problem_sizes1.data_ptr()), - static_cast(problem_sizes2.data_ptr()), - static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, - k); - } else { - compute_problem_sizes<<>>( - static_cast(topk_ids.data_ptr()), - static_cast(problem_sizes1.data_ptr()), - static_cast(problem_sizes2.data_ptr()), - static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, - k); - } + launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2, + atomic_buffer, num_experts, n, k, stream, + may_swap_ab); if (blockscale_offsets.has_value()) { // fp4 path diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 106bacb4883cb..84843ee6e0949 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -76,6 +76,11 @@ void get_cutlass_moe_mm_data_caller( const int64_t num_experts, const int64_t n, const int64_t k, const std::optional& blockscale_offsets); +void get_cutlass_moe_mm_problem_sizes_caller( + const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, + const int64_t k, const std::optional& blockscale_offsets); + void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, @@ -293,6 +298,25 @@ void get_cutlass_moe_mm_data( version_num, ". Required capability: 90 or 100"); } +void get_cutlass_moe_mm_problem_sizes( + const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, + const int64_t k, const std::optional& blockscale_offsets) { + int32_t version_num = get_sm_version_num(); +#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ + (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) + get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1, + problem_sizes2, num_experts, n, k, + blockscale_offsets); + return; +#endif + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm " + "kernel for CUDA device capability: ", + version_num, ". Required capability: 90 or 100"); +} + void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 7079671c2eb16..3a0ff6eaa7904 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -440,6 +440,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { {stride_tag}); ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data); + // A function that computes problem sizes for each expert's multiplication + // used by the two mms called from fused MoE operation. It takes topk_ids as + // an input, and computes problem_sizes1 and problem_sizes2 only. + ops.def( + "get_cutlass_moe_mm_problem_sizes(Tensor topk_ids, " + " Tensor! problem_sizes1, " + " Tensor! problem_sizes2, " + " int num_experts, int n, int k, " + " Tensor? blockscale_offsets) -> ()", + {stride_tag}); + ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA, + &get_cutlass_moe_mm_problem_sizes); + // A function that computes data required to run fused MoE with w8a8 grouped // GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs // as an input, and computes expert_offsets (token start indices of each diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index 81fb3ec1de188..c84f66383b902 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -207,6 +207,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit, 'topk_ids': topk_ids, 'w1_scale': moe_tensors.w1_scale, 'w2_scale': moe_tensors.w2_scale, + 'ab_strides1': moe_tensors.ab_strides1, + 'ab_strides2': moe_tensors.ab_strides2, + 'c_strides1': moe_tensors.c_strides1, + 'c_strides2': moe_tensors.c_strides2, 'per_act_token': per_act_token, 'a1_scale': None #moe_tensors.a_scale } @@ -424,8 +428,8 @@ def test_run_cutlass_moe_fp8( topk_ids[0][1] = 1 workspace13_shape = (m * topk, max(2 * n, k)) - workspace2_shape = (m * topk, n) - output_shape = (m * topk, k) + workspace2_shape = (m * topk, max(n, k)) + output_shape = (m, k) workspace13 = torch.empty(prod(workspace13_shape), device="cuda", @@ -440,6 +444,11 @@ def test_run_cutlass_moe_fp8( expert_map[start:end] = list(range(num_local_experts)) expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda") + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + activation = lambda o, i: torch.ops._C.silu_and_mul(o, i) a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale, torch.float8_e4m3fn, @@ -448,8 +457,9 @@ def test_run_cutlass_moe_fp8( func = lambda output: run_cutlass_moe_fp8( output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation, global_num_experts, expert_map, mt.w1_scale, mt.w2_scale, - a1q_scale, None, workspace13, workspace2, None, mt.a.dtype, - per_act_token, per_out_channel, False) + a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2, + workspace13, workspace2, None, mt.a.dtype, per_act_token, + per_out_channel, False, topk_weights) workspace13.random_() output_random_workspace = torch.empty(output_shape, diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py index 6ca01f9271bba..d71664d94b9c8 100644 --- a/tests/kernels/moe/test_moe_permute_unpermute.py +++ b/tests/kernels/moe/test_moe_permute_unpermute.py @@ -238,7 +238,11 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int, atol=0, rtol=0) # check mindice - torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0) + # current kernel usage assumes deepgemm requires align_block_size + # when it's not provided then we don't compute m_indices (for cutlass) + if align_block_size is not None: + torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0) + # check permuted_hidden_states, only valid token torch.testing.assert_close(gold_permuted_hidden_states[valid_row_idx], permuted_hidden_states[valid_row_idx], diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index f98937ee6c527..98908f2714707 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -76,6 +76,7 @@ def pplx_cutlass_moe( assert torch.cuda.current_device() == pgi.local_rank num_tokens, hidden_dim = a.shape + intermediate_dim = w2.shape[2] num_experts = w1.shape[0] block_size = hidden_dim # TODO support more cases device = pgi.device @@ -124,8 +125,27 @@ def pplx_cutlass_moe( num_local_experts=num_local_experts, num_dispatchers=num_dispatchers) + ab_strides1 = torch.full((num_local_experts, ), + hidden_dim, + device="cuda", + dtype=torch.int64) + ab_strides2 = torch.full((num_local_experts, ), + intermediate_dim, + device="cuda", + dtype=torch.int64) + c_strides1 = torch.full((num_local_experts, ), + 2 * intermediate_dim, + device="cuda", + dtype=torch.int64) + c_strides2 = torch.full((num_local_experts, ), + hidden_dim, + device="cuda", + dtype=torch.int64) + experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers, - out_dtype, per_act_token, per_out_ch) + out_dtype, per_act_token, per_out_ch, + ab_strides1, ab_strides2, c_strides1, + c_strides2) fused_cutlass_experts = FusedMoEModularKernel( prepare_finalize, diff --git a/tests/kernels/quantization/test_cutlass_scaled_mm.py b/tests/kernels/quantization/test_cutlass_scaled_mm.py index 8730eeaaa761c..a15decdf6f827 100644 --- a/tests/kernels/quantization/test_cutlass_scaled_mm.py +++ b/tests/kernels/quantization/test_cutlass_scaled_mm.py @@ -535,7 +535,7 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool, expert_offsets = torch.zeros((num_experts + 1), device=device, - dtype=torch.int32) + dtype=torch.int64) problem_sizes = torch.zeros((num_experts, 3), device=device, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 0d556053f8981..39da08847b2e7 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -844,6 +844,28 @@ def get_cutlass_moe_mm_data(topk_ids: torch.Tensor, blockscale_offsets) +def get_cutlass_moe_mm_problem_sizes( + topk_ids: torch.Tensor, + problem_sizes1: torch.Tensor, + problem_sizes2: torch.Tensor, + num_experts: int, + n: int, + k: int, + blockscale_offsets: Optional[torch.Tensor] = None): + """ + Compute only the per-expert problem sizes needed by the two grouped matrix + multiplications used in CUTLASS-based fused MoE. + + The function takes in topk_ids (tokenโ†’expert mapping) and computes: + - problem_sizes1, problem_sizes2: Mร—Nร—K sizes of each expert's + multiplication for the two grouped MMs + used in the fused MoE operation. + """ + return torch.ops._C.get_cutlass_moe_mm_problem_sizes( + topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k, + blockscale_offsets) + + def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor): """ Shuffle and expand the input tensor according to the dst2src_map and store the result in output_tensor. diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 0a02b558d09e5..95d23ec0346c1 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -9,12 +9,13 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( + moe_permute, moe_unpermute) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP) -from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm, - _fp8_quantize, +from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, _resize_cache) from vllm.scalar_type import scalar_types @@ -34,6 +35,10 @@ def run_cutlass_moe_fp8( w2_scale: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, workspace13: torch.Tensor, workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], @@ -41,6 +46,7 @@ def run_cutlass_moe_fp8( per_act_token: bool, per_out_ch: bool, use_batched_format: bool, + topk_weights: Optional[torch.Tensor], ): a1q = hidden_states @@ -99,6 +105,22 @@ def run_cutlass_moe_fp8( topk = local_topk_ids.size(1) local_E = w1.size(0) + if use_batched_format: + mm1_out = _resize_cache(workspace13, (local_E * padded_M, N * 2)) + act_out = _resize_cache(workspace2, (local_E * padded_M, N)) + quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn), + (local_E * padded_M, N)) + mm2_out = _resize_cache(workspace2, (local_E * padded_M, K)) + else: + a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), + (M * topk, K)) + mm1_out = _resize_cache(workspace13, (M * topk, N * 2)) + act_out = _resize_cache(workspace2, (M * topk, N)) + # original workspace are based on input hidden_states dtype (bf16) + quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn), + (M * topk, N)) + mm2_out = _resize_cache(workspace2, (M * topk, K)) + if use_batched_format: assert expert_num_tokens is not None @@ -120,11 +142,10 @@ def run_cutlass_moe_fp8( w2_scale = w2_scale.reshape(w2_scale.size(0), -1) a1q = a1q.reshape(-1, a1q.size(2)) a1q_scale = a1q_scale.reshape(-1, a1q_scale.size(2)).contiguous() - + # c3x get_group_gemm_starts expects int64 to avoid overflow + # during offset calculations + expert_offsets = expert_offsets.to(torch.int64) else: - expert_offsets = torch.empty((global_num_experts + 1), - dtype=torch.int32, - device=device) problem_sizes1 = torch.empty((global_num_experts, 3), dtype=torch.int32, device=device) @@ -132,84 +153,57 @@ def run_cutlass_moe_fp8( dtype=torch.int32, device=device) - # With expert_map each Rank processes only a subset of experts. As - # a result not all of a_map and c2 tensors are filled. We fill it - # zeros for correctness. - if expert_map is not None: - a_map = torch.zeros((local_topk_ids.numel()), - dtype=torch.int32, - device=device) - else: - a_map = torch.empty((local_topk_ids.numel()), - dtype=torch.int32, - device=device) - - c_map = torch.empty((local_topk_ids.numel()), - dtype=torch.int32, - device=device) - - ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets, - problem_sizes1, problem_sizes2, a_map, - c_map, global_num_experts, N, K) - - a1q = _fp8_perm(a1q, a_map) - a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale + num_expert = global_num_experts if expert_map is None \ + else expert_map.size(0) + # permuted a1q reuses workspace2 + a1q, a1q_scale, expert_offsets, inv_perm, _ = moe_permute( + a1q, + a1q_scale, + topk_ids, + num_expert, + local_E, + expert_map, + permuted_hidden_states=a1q_perm) expert_offsets = expert_offsets[:-1] - ab_strides1 = torch.full((w1.size(0), ), - K, - device=device, - dtype=torch.int64) - c_strides1 = torch.full((w1.size(0), ), - 2 * N, - device=device, - dtype=torch.int64) - ab_strides2 = torch.full((w1.size(0), ), - N, - device=device, - dtype=torch.int64) - c_strides2 = torch.full((w1.size(0), ), - K, - device=device, - dtype=torch.int64) - - if use_batched_format: - c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2)) - c2 = _resize_cache(workspace2, (local_E * padded_M, N)) - c3 = _resize_cache(workspace13, (local_E * padded_M, K)) - else: - c1 = _resize_cache(workspace13, (M * topk, N * 2)) - c2 = _resize_cache(workspace2, (M * topk, N)) - c3 = _resize_cache(workspace13, (M * topk, K)) + ops.get_cutlass_moe_mm_problem_sizes(local_topk_ids, problem_sizes1, + problem_sizes2, + global_num_experts, N, K) if not per_act_token and (expert_map is not None or use_batched_format): # this is necessary to avoid imprecise scale calculation caused by # random data in the unused workspace. The workspace is unused when # this rank handles only partial tokens, or when it is batched . - c1.fill_(0) + mm1_out.fill_(0) - ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets, + ops.cutlass_moe_mm(mm1_out, a1q, w1, a1q_scale, w1_scale, expert_offsets, problem_sizes1, ab_strides1, ab_strides1, c_strides1, per_act_token, per_out_ch) - activation_callable(c2, c1) + activation_callable(act_out, mm1_out) a2q, a2q_scale = ops.scaled_fp8_quant( - c2, a2_scale, use_per_token_if_dynamic=per_act_token) + act_out, + a2_scale, + use_per_token_if_dynamic=per_act_token, + output=quant_out) if expert_map is not None: - c3.fill_(0) + mm2_out.fill_(0) - ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, expert_offsets, + ops.cutlass_moe_mm(mm2_out, a2q, w2, a2q_scale, w2_scale, expert_offsets, problem_sizes2, ab_strides2, ab_strides2, c_strides2, per_act_token, per_out_ch) if use_batched_format: - output.copy_(c3.reshape(local_E, padded_M, K), non_blocking=True) + output.copy_(mm2_out.reshape(local_E, padded_M, K), non_blocking=True) else: - # We can't do this inplace because output may point to the same tensor - # as c3. - output.copy_(c3[c_map].view(M * topk, K), non_blocking=True) + # for non-chunking mode the output is resized from workspace13 + # so we need to make sure mm2_out uses workspace2. + moe_unpermute(out=output, + permuted_hidden_states=mm2_out, + topk_weights=topk_weights, + inv_permuted_idx=inv_perm) class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): @@ -219,6 +213,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): out_dtype: Optional[torch.dtype], per_act_token_quant: bool, per_out_ch_quant: bool, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, block_shape: Optional[list[int]] = None, ): super().__init__( @@ -229,6 +227,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): block_shape=block_shape, )) self.out_dtype = out_dtype + self.ab_strides1 = ab_strides1 + self.ab_strides2 = ab_strides2 + self.c_strides1 = c_strides1 + self.c_strides2 = c_strides2 def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: # Let PrepareAndFinalize::finalize() decide the impl. @@ -272,10 +274,11 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): run_cutlass_moe_fp8( output, hidden_states, w1, w2, topk_ids, activation_callable, global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale, - a2_scale, workspace13, workspace2, expert_num_tokens, + a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1, + self.c_strides2, workspace13, workspace2, expert_num_tokens, self.out_dtype if self.out_dtype is not None else in_dtype, self.per_act_token_quant, self.per_out_ch_quant, - use_batched_format) + use_batched_format, topk_weights) class CutlassExpertsFp8(CutlassExpertsFp8Base): @@ -285,12 +288,20 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base): out_dtype: Optional[torch.dtype], per_act_token_quant: bool, per_out_ch_quant: bool, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, block_shape: Optional[list[int]] = None, ): super().__init__( out_dtype, per_act_token_quant, per_out_ch_quant, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, block_shape, ) @@ -307,6 +318,10 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base): def supports_expert_map(self) -> bool: return True + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # topk weights and reduction are fused in moe_unpermute cuda kernel + return TopKWeightAndReduceNoOP() + def workspace_shapes( self, a: torch.Tensor, @@ -320,8 +335,8 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base): expert_tokens_meta: Optional[mk.ExpertTokensMetadata], ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: workspace1 = (M * topk, max(N, K)) - workspace2 = (M * topk, N // 2) - output = (M * topk, K) + workspace2 = (M * topk, max(N // 2, K)) + output = (M, K) return (workspace1, workspace2, output, self.out_dtype if self.out_dtype is not None else a.dtype) @@ -335,12 +350,20 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): out_dtype: Optional[torch.dtype], per_act_token_quant: bool, per_out_ch_quant: bool, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, block_shape: Optional[list[int]] = None, ): super().__init__( out_dtype, per_act_token_quant, per_out_ch_quant, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, block_shape, ) assert max_experts_per_worker > 0 @@ -378,7 +401,8 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): assert num_dp is not None workspace1 = (self.max_experts_per_worker, padded_M * num_dp, max(N, K)) - workspace2 = (self.max_experts_per_worker, padded_M * num_dp, (N // 2)) + workspace2 = (self.max_experts_per_worker, padded_M * num_dp, + max(N // 2, K)) output = (self.max_experts_per_worker, padded_M, K) return (workspace1, workspace2, output, self.out_dtype if self.out_dtype is not None else a.dtype) @@ -392,6 +416,10 @@ def cutlass_moe_fp8( topk_ids: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, per_act_token: Optional[bool] = None, activation: str = "silu", a1_scale: Optional[torch.Tensor] = None, @@ -419,6 +447,17 @@ def cutlass_moe_fp8( Shape: [num_experts] or [num_experts, 2N] - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. Shape: [num_experts] or [num_experts, K] + - ab_strides1 (torch.Tensor): The input/weight strides for the first gemm. + Shape: [num_experts] + - ab_strides2 (torch.Tensor): The input/weight strides for the second gemm. + Shape: [num_experts] + - c_strides1 (torch.Tensor): The output strides for the first gemm. + Shape: [num_experts] + - c_strides2 (torch.Tensor): The output strides for the second gemm. + Shape: [num_experts] + - per_act_token (Optional[bool]): Whether the scale is per-token or + per-tensor. + - activation (str): The activation function to use. - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. Shape: scalar or [M] - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to @@ -450,6 +489,10 @@ def cutlass_moe_fp8( out_dtype=a.dtype, per_act_token_quant=per_act_token, per_out_ch_quant=per_out_ch, + ab_strides1=ab_strides1, + ab_strides2=ab_strides2, + c_strides1=c_strides1, + c_strides2=c_strides2, ), ) diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py index d9059f50b4459..16a155e718478 100644 --- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -82,7 +82,8 @@ def moe_permute( n_local_expert: int = -1, expert_map: Optional[torch.Tensor] = None, align_block_size: Optional[int] = None, - fill_invalid_expert: int = -1 + fill_invalid_expert: int = -1, + permuted_hidden_states: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: """ @@ -95,14 +96,17 @@ def moe_permute( - n_expert (int): The number of expert. - n_local_expert (int): The number of expert in current EP rank. - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert + from the global expert space to the local expert space of the expert parallel shard. - align_block_size (Optional[int]): align group gemm block size for deepgemm - fill_invalid_expert(int): fill expert id in m_indices for invalid expert to workaround DeepGemm unsupported -1 in m_indices + - permuted_hidden_states (Optional[torch.Tensor]): Optional output tensor. + If None, the output tensor will be created in this function. Returns: - permuted_hidden_states (torch.Tensor): permuted activation. - - a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states + - a1q_scale (Optional[torch.Tensor]): permuted quant scale for hidden_states + if original scale not per-tensor scaling - expert_first_token_offset (torch.Tensor): offset of the first token of each expert for standard grouped gemm. if enable 'align_block_size' expert_first_token_offset will align up to 'align_block_size'. @@ -122,11 +126,16 @@ def moe_permute( 1) // align_block_size * align_block_size if n_local_expert == -1: n_local_expert = n_expert - permuted_hidden_states = torch.empty( - (permuted_row_size, n_hidden), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) + if permuted_hidden_states is None: + permuted_hidden_states = torch.empty( + (permuted_row_size, n_hidden), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + assert permuted_hidden_states.size() == (permuted_row_size, n_hidden), ( + f"Expected permuted hidden states to be {(permuted_row_size, n_hidden)}" + f" but got {permuted_hidden_states.size()}") + token_expert_indices = torch.arange(0, n_token * topk, dtype=torch.int32, @@ -153,7 +162,8 @@ def moe_permute( align_block_size, permuted_hidden_states, expert_first_token_offset, inv_permuted_idx, permuted_idx, m_indices) - if a1q_scale is not None: + + if a1q_scale is not None and a1q_scale.dim() > 1: a1q_scale = a1q_scale[permuted_idx.clamp(max=n_token * topk - 1) // topk] return (permuted_hidden_states, a1q_scale, expert_first_token_offset, @@ -185,6 +195,7 @@ def moe_unpermute( n_hidden = permuted_hidden_states.size(-1) assert (n_hidden * permuted_hidden_states.element_size() ) % 16 == 0, "unpermue kernel need hidden dim align to 16B" + torch.ops._moe_C.moe_unpermute(permuted_hidden_states, topk_weights, inv_permuted_idx, expert_first_token_offset, topk, out) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 8ca8249e694ea..7bc35cd81ac3f 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -669,6 +669,25 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): from vllm.model_executor.layers.fused_moe import fused_experts self.fused_experts_func = fused_experts + if self.use_cutlass: + device = layer.w13_weight.device + # ab_strides1 and c_strides2 are the same + self.ab_strides1_c_strides2 = torch.full( + (layer.local_num_experts, ), + layer.hidden_size, + device=device, + dtype=torch.int64) + self.ab_strides2 = torch.full( + (layer.local_num_experts, ), + layer.intermediate_size_per_partition, + device=device, + dtype=torch.int64) + self.c_strides1 = torch.full( + (layer.local_num_experts, ), + 2 * layer.intermediate_size_per_partition, + device=device, + dtype=torch.int64) + def select_gemm_impl( self, prepare_finalize: FusedMoEPrepareAndFinalize, @@ -693,6 +712,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): moe.in_dtype, self.input_quant.strategy == QuantizationStrategy.TOKEN, self.weight_quant.strategy == QuantizationStrategy.CHANNEL, + ab_strides1=self.ab_strides1_c_strides2, + ab_strides2=self.ab_strides2, + c_strides1=self.c_strides1, + c_strides2=self.ab_strides1_c_strides2, ) else: logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__) @@ -700,6 +723,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): moe.in_dtype, self.input_quant.strategy == QuantizationStrategy.TOKEN, self.weight_quant.strategy == QuantizationStrategy.CHANNEL, + ab_strides1=self.ab_strides1_c_strides2, + ab_strides2=self.ab_strides2, + c_strides1=self.c_strides1, + c_strides2=self.ab_strides1_c_strides2, ) self.disable_expert_map = (num_dispatchers > 1 @@ -822,6 +849,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): expert_map=None if self.disable_expert_map else expert_map, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, + ab_strides1=self.ab_strides1_c_strides2, + ab_strides2=self.ab_strides2, + c_strides1=self.c_strides1, + c_strides2=self.ab_strides1_c_strides2, a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, ) From 5efd6905bc8469a30664de83bdafaad56aa92903 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 20 Aug 2025 23:42:28 +0800 Subject: [PATCH 04/57] [CLI][Doc] Formalize `--mm-encoder-tp-mode` (#23190) Signed-off-by: DarkLight1337 --- docs/configuration/optimization.md | 45 ++++++++++++++++++++++++ vllm/config/__init__.py | 34 +++++++++++++++++- vllm/config/parallel.py | 4 --- vllm/engine/arg_utils.py | 35 +++++++++++------- vllm/model_executor/models/mllama4.py | 4 +-- vllm/model_executor/models/qwen2_5_vl.py | 3 +- vllm/model_executor/models/step3_vl.py | 3 +- 7 files changed, 104 insertions(+), 24 deletions(-) diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index c7f50497d6ffa..db9dfb313fb87 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -129,6 +129,51 @@ Data parallelism replicates the entire model across multiple GPU sets and proces Data parallelism can be combined with the other parallelism strategies and is set by `data_parallel_size=N`. Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size. +### Batch-level DP for Multi-Modal Encoders + +By default, TP is used to shard the weights of multi-modal encoders just like for language decoders, +in order to reduce the memory and compute load on each GPU. + +However, since the size of multi-modal encoders is very small compared to language decoders, +there is relatively little gain from TP. On the other hand, TP incurs significant communication +overhead because of all-reduce being performed after every layer. + +Given this, it may be advantageous to instead shard the batched input data using TP, essentially +performing batch-level DP. This has been shown to improve the throughput by around 10% for +`tensor_parallel_size=8`. For vision encoders that use hardware-unoptimized Conv3D operations, +batch-level DP can provide another 40% increase to throughput compared to regular TP. + +Nevertheless, since the weights of the multi-modal encoder are replicated across each TP rank, +there will be a minor increase in memory consumption and may cause OOM if you can barely fit the model already. + +You can enable batch-level DP by setting `mm_encoder_tp_mode="data"`, for example: + +```python +from vllm import LLM + +llm = LLM( + model="Qwen/Qwen2.5-VL-72B-Instruct", + # Create two EngineCore instances, one per DP rank + data_parallel_size=2, + # Within each EngineCore instance: + # The vision encoder uses TP=4 (not DP=2) to shard the input data + # The language decoder uses TP=4 to shard the weights as usual + tensor_parallel_size=4, + mm_encoder_tp_mode="data", +) +``` + +!! important + Batch-level DP is not to be confused with API request-level DP + (which is instead controlled by `data_parallel_size`). + +The availablilty of batch-level DP is based on model implementation. +Currently, the following models support `mm_encoder_tp_mode="data"`: + +- Llama4 () +- Qwen2.5-VL () +- Step3 () + ## Input Processing ### Parallel Processing diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 801fa97fe5daf..5b5d477ef066b 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -258,6 +258,7 @@ TokenizerMode = Literal["auto", "slow", "mistral", "custom"] ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] LogprobsMode = Literal["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"] +MMEncoderTPMode = Literal["weights", "data"] @config @@ -438,6 +439,19 @@ class ModelConfig: `mm_processor_cache_gb * (api_server_count + data_parallel_size)`. Set to `0` to disable this cache completely (not recommended).""" + mm_encoder_tp_mode: MMEncoderTPMode = "weights" + """Indicates how to optimize multi-modal encoder inference using + tensor parallelism (TP). + + - `"weights"`: Within the same vLLM engine, split the weights of + each layer across TP ranks. (default TP behavior) + - `"data"`: Within the same vLLM engine, split the batched input data + across TP ranks to process the data in parallel, while hosting + the full weights on each TP rank. + This batch-level DP is not to be confused with API request-level + DP (which is controlled by `--data-parallel-size`). + This is only supported on a per-model basis and falls back to + `"weights"` if the encoder does not support DP.""" override_neuron_config: dict[str, Any] = field(default_factory=dict) """Initialize non-default neuron config or override default neuron config that are specific to Neuron devices, this argument will be used to @@ -856,8 +870,10 @@ class ModelConfig: media_io_kwargs=self.media_io_kwargs, mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_cache_gb=self.mm_processor_cache_gb, + mm_encoder_tp_mode=self.mm_encoder_tp_mode, interleave_mm_strings=self.interleave_mm_strings, - skip_mm_profiling=self.skip_mm_profiling) + skip_mm_profiling=self.skip_mm_profiling, + ) return None @@ -2547,6 +2563,22 @@ class MultiModalConfig: Set to `0` to disable this cache completely (not recommended). """ + mm_encoder_tp_mode: MMEncoderTPMode = "weights" + """ + Indicates how to optimize multi-modal encoder inference using + tensor parallelism (TP). + + - `"weights"`: Within the same vLLM engine, split the weights of + each layer across TP ranks. (default TP behavior) + - `"data"`: Within the same vLLM engine, split the batched input data + across TP ranks to process the data in parallel, while hosting + the full weights on each TP rank. + This batch-level DP is not to be confused with API request-level + DP (which is controlled by `--data-parallel-size`). + This is only supported on a per-model basis and falls back to + `"weights"` if the encoder does not support DP. + """ + interleave_mm_strings: bool = False """ Enable fully interleaved support for multimodal prompts. diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index bac1e63800d7b..7a9e68f0ea332 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -137,10 +137,6 @@ class ParallelConfig: rank: int = 0 """Global rank in distributed setup.""" - enable_multimodal_encoder_data_parallel: bool = False - """ Use data parallelism instead of tensor parallelism for vision encoder. - Only support LLama4 for now""" - @property def world_size_across_dp(self) -> int: """world_size_across_dp is TPxPPxDP, it is the size of the world diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 48d9cd08af030..6869c3f23f315 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -28,12 +28,12 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, DeviceConfig, DistributedExecutorBackend, GuidedDecodingBackend, HfOverrides, KVEventsConfig, KVTransferConfig, LoadConfig, LogprobsMode, - LoRAConfig, MambaDType, ModelConfig, ModelDType, - ModelImpl, MultiModalConfig, ObservabilityConfig, - ParallelConfig, PoolerConfig, PrefixCachingHashAlgo, - RunnerOption, SchedulerConfig, SchedulerPolicy, - SpeculativeConfig, TaskOption, TokenizerMode, - VllmConfig, get_attr_docs, get_field) + LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig, + ModelDType, ModelImpl, MultiModalConfig, + ObservabilityConfig, ParallelConfig, PoolerConfig, + PrefixCachingHashAlgo, RunnerOption, SchedulerConfig, + SchedulerPolicy, SpeculativeConfig, TaskOption, + TokenizerMode, VllmConfig, get_attr_docs, get_field) from vllm.logger import init_logger from vllm.platforms import CpuArchEnum, current_platform from vllm.plugins import load_general_plugins @@ -352,6 +352,7 @@ class EngineArgs: MultiModalConfig.mm_processor_kwargs disable_mm_preprocessor_cache: bool = False # DEPRECATED mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb + mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling # LoRA fields enable_lora: bool = False @@ -434,16 +435,14 @@ class EngineArgs: use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load pt_load_map_location: str = LoadConfig.pt_load_map_location - enable_multimodal_encoder_data_parallel: bool = \ - ParallelConfig.enable_multimodal_encoder_data_parallel + # DEPRECATED + enable_multimodal_encoder_data_parallel: bool = False logits_processors: Optional[list[Union[ str, type[LogitsProcessor]]]] = ModelConfig.logits_processors """Custom logitproc types""" async_scheduling: bool = SchedulerConfig.async_scheduling - # DEPRECATED - enable_prompt_adapter: bool = False kv_sharing_fast_prefill: bool = \ CacheConfig.kv_sharing_fast_prefill @@ -685,7 +684,8 @@ class EngineArgs: **parallel_kwargs["worker_extension_cls"]) parallel_group.add_argument( "--enable-multimodal-encoder-data-parallel", - **parallel_kwargs["enable_multimodal_encoder_data_parallel"]) + action="store_true", + deprecated=True) # KV cache arguments cache_kwargs = get_kwargs(CacheConfig) @@ -735,6 +735,8 @@ class EngineArgs: multimodal_group.add_argument("--disable-mm-preprocessor-cache", action="store_true", deprecated=True) + multimodal_group.add_argument( + "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]) multimodal_group.add_argument( "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]) @@ -909,6 +911,14 @@ class EngineArgs: self.mm_processor_cache_gb = envs.VLLM_MM_INPUT_CACHE_GIB + if self.enable_multimodal_encoder_data_parallel: + logger.warning( + "--enable-multimodal-encoder-data-parallel` is deprecated " + "and will be removed in v0.13. " + "Please use `--mm-encoder-tp-mode data` instead.") + + self.mm_encoder_tp_mode = "data" + return ModelConfig( model=self.model, hf_config_path=self.hf_config_path, @@ -947,6 +957,7 @@ class EngineArgs: config_format=self.config_format, mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_cache_gb=self.mm_processor_cache_gb, + mm_encoder_tp_mode=self.mm_encoder_tp_mode, override_neuron_config=self.override_neuron_config, override_pooler_config=self.override_pooler_config, logits_processor_pattern=self.logits_processor_pattern, @@ -1258,8 +1269,6 @@ class EngineArgs: distributed_executor_backend=self.distributed_executor_backend, worker_cls=self.worker_cls, worker_extension_cls=self.worker_extension_cls, - enable_multimodal_encoder_data_parallel=self. - enable_multimodal_encoder_data_parallel, ) if model_config.is_multimodal_model: diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 35103eac8fb56..595bdd17cf2c2 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -728,8 +728,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config - self.use_data_parallel = (vllm_config.parallel_config. - enable_multimodal_encoder_data_parallel) + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + self.config = config self.quant_config = quant_config self.multimodal_config = multimodal_config diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 34eec10296b50..811ecffcc1e49 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -877,8 +877,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config multimodal_config = vllm_config.model_config.multimodal_config - self.use_data_parallel = (vllm_config.parallel_config. - enable_multimodal_encoder_data_parallel) + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.config = config self.multimodal_config = multimodal_config diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index 5d41a9e569f53..f8877b584b198 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -882,8 +882,7 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, self.config = config self.multimodal_config = multimodal_config - self.use_data_parallel = (vllm_config.parallel_config. - enable_multimodal_encoder_data_parallel) + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" if multimodal_config.get_limit_per_prompt("image"): self.vision_model = Step3VisionTransformer( From d6d13bd49ed7fda56ac6a1b0aa53621490c975ac Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 20 Aug 2025 09:05:29 -0700 Subject: [PATCH 05/57] [Misc] Add max_seq_len to CommonAttentionMetadata (#23216) Signed-off-by: Woosuk Kwon --- tests/v1/attention/utils.py | 2 ++ tests/v1/spec_decode/test_tree_attention.py | 2 ++ vllm/v1/attention/backends/flash_attn.py | 2 +- vllm/v1/attention/backends/flashinfer.py | 2 +- vllm/v1/attention/backends/flex_attention.py | 2 +- vllm/v1/attention/backends/rocm_aiter_fa.py | 2 +- vllm/v1/attention/backends/tree_attn.py | 2 +- vllm/v1/attention/backends/triton_attn.py | 2 +- vllm/v1/attention/backends/utils.py | 6 ++++++ vllm/v1/attention/backends/xformers.py | 2 +- vllm/v1/spec_decode/eagle.py | 1 + vllm/v1/worker/gpu_model_runner.py | 4 ++++ 12 files changed, 22 insertions(+), 7 deletions(-) diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index a4e38eb32f6a1..e547e71e0cdb7 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -58,6 +58,7 @@ def create_common_attn_metadata( dtype=torch.int32, device=device) seq_lens_cpu = seq_lens.cpu() + max_seq_len = int(seq_lens_cpu.max()) # Create computed tokens (context length for each sequence) context_lens = [ @@ -101,6 +102,7 @@ def create_common_attn_metadata( num_reqs=batch_spec.batch_size, num_actual_tokens=num_tokens, max_query_len=max_query_len, + max_seq_len=max_seq_len, block_table_tensor=block_table_tensor, slot_mapping=slot_mapping, causal=True, diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py index 456ce712d36e4..6317817408661 100644 --- a/tests/v1/spec_decode/test_tree_attention.py +++ b/tests/v1/spec_decode/test_tree_attention.py @@ -50,6 +50,7 @@ def forward_attention( dtype=torch.int32, ) context_lens = seq_lens - query_lens + max_seq_len = int(seq_lens.max()) max_query_len = q_len num_actual_tokens = query_start_loc[-1] @@ -81,6 +82,7 @@ def forward_attention( num_reqs=batch_size, num_actual_tokens=num_actual_tokens, max_query_len=max_query_len, + max_seq_len=max_seq_len, block_table_tensor=block_table, slot_mapping=slot_mapping, ) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index ab7a71a399b34..eed3cba9a2ca7 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -233,7 +233,7 @@ class FlashAttentionMetadataBuilder( num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + max_seq_len = common_attn_metadata.max_seq_len query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens seq_lens_cpu = common_attn_metadata.seq_lens_cpu diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 53fafbc4af91d..8a25088848a44 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -463,7 +463,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): page_size = self.page_size max_q_len = common_attn_metadata.max_query_len - max_seq_len = common_attn_metadata.seq_lens_cpu.max().item() + max_seq_len = common_attn_metadata.max_seq_len seq_lens = common_attn_metadata.seq_lens seq_lens_cpu = common_attn_metadata.seq_lens_cpu block_table_tensor = common_attn_metadata.block_table_tensor diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index e599411b2d7e8..abca981035d9e 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -305,7 +305,7 @@ class FlexAttentionMetadataBuilder( num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + max_seq_len = common_attn_metadata.max_seq_len query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens block_table_tensor = common_attn_metadata.block_table_tensor diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 36b5853bfdcbb..b9ff113573a12 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -270,7 +270,7 @@ class AiterFlashAttentionMetadataBuilder( num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + max_seq_len = common_attn_metadata.max_seq_len query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens block_table_tensor = common_attn_metadata.block_table_tensor diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index 5d10e9e26082d..2a0c52377cc7f 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -205,7 +205,7 @@ class TreeAttentionMetadataBuilder( q_start_loc = common_attn_metadata.query_start_loc max_query_len = common_attn_metadata.max_query_len kv_seqlens = common_attn_metadata.seq_lens - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + max_seq_len = common_attn_metadata.max_seq_len block_table = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 48a9af3decac0..c69dd8415f922 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -90,7 +90,7 @@ class TritonAttentionMetadataBuilder( num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + max_seq_len = common_attn_metadata.max_seq_len query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens block_table_tensor = common_attn_metadata.block_table_tensor diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 94dd3d2629ebc..57c4d436c5b6b 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -58,6 +58,8 @@ class CommonAttentionMetadata: """Total number of tokens in batch""" max_query_len: int """Longest query in batch""" + max_seq_len: int + """Longest context length in batch""" block_table_tensor: torch.Tensor slot_mapping: torch.Tensor @@ -107,6 +109,7 @@ def _make_metadata_with_slice( seq_lens = attn_metadata.seq_lens[request_slice] seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice] + max_seq_len = int(seq_lens_cpu.max()) num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[ request_slice] @@ -128,6 +131,7 @@ def _make_metadata_with_slice( num_reqs=num_requests, num_actual_tokens=num_actual_tokens, max_query_len=max_query_len, + max_seq_len=max_seq_len, block_table_tensor=block_table_tensor, slot_mapping=slot_mapping, ) @@ -520,6 +524,7 @@ def make_local_attention_virtual_batches( query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local) seq_lens_cpu = torch.from_numpy(seqlens_k_local) + max_seq_len = int(seq_lens_cpu.max()) return CommonAttentionMetadata( query_start_loc_cpu=query_start_loc_cpu, @@ -531,6 +536,7 @@ def make_local_attention_virtual_batches( num_reqs=len(seq_lens_cpu), num_actual_tokens=common_attn_metadata.num_actual_tokens, max_query_len=seqlens_q_local.max(), + max_seq_len=max_seq_len, block_table_tensor=block_table_local, slot_mapping=common_attn_metadata.slot_mapping, causal=True, diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index fe732c6017702..b305bc1539081 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -231,7 +231,7 @@ class XFormersAttentionMetadataBuilder( q_seqlens = torch.diff(q_start_loc) max_query_len = common_attn_metadata.max_query_len kv_seqlens = common_attn_metadata.seq_lens - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + max_seq_len = common_attn_metadata.max_seq_len block_table = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 8cd2ad12cfa30..cc2b2a139d5e9 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -582,6 +582,7 @@ class EagleProposer: num_reqs=common_attn_metadata.num_reqs, num_actual_tokens=total_num_tokens, max_query_len=new_query_len_per_req.max().item(), + max_seq_len=new_seq_lens_cpu.max().item(), block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping[token_indices], causal=True, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e0bab3367cafe..d9770226b14ee 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -774,6 +774,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.seq_lens_np[num_reqs:].fill(0) self.seq_lens.copy_(self.seq_lens_cpu, non_blocking=True) seq_lens = self.seq_lens[:num_reqs] + max_seq_len = self.seq_lens_np[:num_reqs].max().item() # Copy the tensors to the GPU. self.input_ids[:total_num_scheduled_tokens].copy_( @@ -886,6 +887,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_reqs=num_reqs, num_actual_tokens=total_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens, + max_seq_len=max_seq_len, block_table_tensor=blk_table_tensor, slot_mapping=slot_mapping, causal=True, @@ -2338,6 +2340,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_reqs=num_reqs, num_actual_tokens=num_tokens, max_query_len=max_query_len, + max_seq_len=self.max_model_len, block_table_tensor=self.input_batch.block_table[ kv_cache_group_id].get_device_tensor()[:num_reqs], slot_mapping=self.input_batch. @@ -3343,6 +3346,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_reqs=num_reqs, num_actual_tokens=total_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens, + max_seq_len=self.seq_lens_cpu[:num_reqs].max().item(), block_table_tensor=dummy_block_table, slot_mapping=dummy_slot_mapping, causal=False, From 3b11b26b5069718a6bde11b9041681bc17369f96 Mon Sep 17 00:00:00 2001 From: JartX Date: Wed, 20 Aug 2025 18:08:29 +0200 Subject: [PATCH 06/57] [FIXBUG ] Allow disabling rocm_aiter_fa backend for ROCm GPUs not compatible with AITER (#22795) Signed-off-by: JartX Signed-off-by: tjtanaa Co-authored-by: tjtanaa --- vllm/v1/spec_decode/eagle.py | 80 ++++++++++++++++++++---------------- 1 file changed, 45 insertions(+), 35 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index cc2b2a139d5e9..0a0e9fed725cb 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -2,7 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import ast from dataclasses import replace -from typing import Optional +from importlib.util import find_spec +from typing import Optional, Protocol import numpy as np import torch @@ -20,8 +21,6 @@ from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata -from vllm.v1.attention.backends.rocm_aiter_fa import ( - AiterFlashAttentionMetadata) from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata, TreeAttentionMetadataBuilder) from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata @@ -34,6 +33,17 @@ logger = init_logger(__name__) PADDING_SLOT_ID = -1 +class EagleAttentionMetadata(Protocol): + # Required attributes + num_actual_tokens: int + max_query_len: int + query_start_loc: torch.Tensor + max_seq_len: int + seq_lens: torch.Tensor + block_table: torch.Tensor + slot_mapping: torch.Tensor + + class EagleProposer: def __init__( @@ -97,6 +107,20 @@ class EagleProposer: dtype=self.dtype, device=device) + # Determine allowed attention backends once during initialization. + self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...] + if current_platform.is_rocm(): + rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata] + # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend + if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"): + from vllm.v1.attention.backends.rocm_aiter_fa import ( + AiterFlashAttentionMetadata) + rocm_types.append(AiterFlashAttentionMetadata) + self.allowed_attn_types = tuple(rocm_types) + else: + self.allowed_attn_types = (FlashAttentionMetadata, + TreeAttentionMetadata) + # Parse the speculative token tree. spec_token_tree = self.speculative_config.speculative_token_tree self.tree_choices: list[tuple[int, @@ -165,7 +189,7 @@ class EagleProposer: for layer_name in self.attn_layer_names: per_layer_attn_metadata[layer_name] = attn_metadata if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: + num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) else: num_input_tokens = num_tokens @@ -225,25 +249,13 @@ class EagleProposer: # TODO: Currently, MTP module released by deepseek only has # one layer. Adapt this code to support multiple layers once # there's a multi-layer MTP module. - - # On ROCm, both AiterFlashAttention and TritonAttention - # support multi-token eagle spec decode. - if current_platform.is_rocm(): - assert isinstance( - attn_metadata, - (TritonAttentionMetadata, AiterFlashAttentionMetadata, - FlashAttentionMetadata)) - else: - # Currently, only FlashAttention supports multi-token eagle spec - # decode. This is because the code below makes assumptions about - # attn_metadata attributes available. - assert isinstance(attn_metadata, FlashAttentionMetadata) + assert isinstance(attn_metadata, self.allowed_attn_types) # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] if self.use_cuda_graph and \ - batch_size <= self.cudagraph_batch_sizes[-1]: + batch_size <= self.cudagraph_batch_sizes[-1]: input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) else: input_batch_size = batch_size @@ -449,7 +461,7 @@ class EagleProposer: num_tokens, -1) if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: + num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph( num_tokens) else: @@ -508,19 +520,19 @@ class EagleProposer: """ # E.g. # common_attn_metadata.query_start_loc{_cpu}: - # [0, q1, q1 + q2, q1 + q2 + q3] + # [0, q1, q1 + q2, q1 + q2 + q3] # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3] # num_rejected_tokens: [n1, n2, n3] # This function computes the intermediate values: # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3] # And returns: # common_attn_metadata.query_start_loc{_cpu}: - # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] + # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] # common_attn_metadata.seq_lens{_cpu}: - # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] + # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] # token_indices: [0, 1, ..., q1 - n1 - 1, - # q1, q1 + 1, ..., q1 + q2 - n2 - 1, - # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] + # q1, q1 + 1, ..., q1 + q2 - n2 - 1, + # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] device = common_attn_metadata.query_start_loc.device query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu @@ -564,9 +576,9 @@ class EagleProposer: old_query_start_locs_expanded = np.repeat( query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np) # Final token indices are: - # [0, 1, // req 1 - # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 - # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 + # [0, 1, // req 1 + # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 + # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 token_indices_np = token_offests + old_query_start_locs_expanded token_indices = torch.from_numpy(token_indices_np).to( device, non_blocking=True) @@ -616,20 +628,18 @@ class EagleProposer: target_language_model = target_model # share embed_tokens with the target model if needed if get_pp_group().world_size == 1 \ - and self.model.model.embed_tokens.weight.shape \ - == target_language_model.model.embed_tokens.weight.shape: + and self.model.model.embed_tokens.weight.shape \ + == target_language_model.model.embed_tokens.weight.shape: logger.info( - "Assuming the EAGLE head shares the same vocab embedding" \ - " with the target model." - ) + "Assuming the EAGLE head shares the same vocab embedding" + " with the target model.") del self.model.model.embed_tokens self.model.model.embed_tokens = ( target_language_model.model.embed_tokens) else: logger.info( - "The EAGLE head's vocab embedding will be loaded separately" \ - " from the target model." - ) + "The EAGLE head's vocab embedding will be loaded separately" + " from the target model.") # share lm_head with the target model if needed # some model definition do not define lm_head explicitly From dfd2382039c38be80d6c2c9b56e441b5bd7cd0ad Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com> Date: Wed, 20 Aug 2025 09:52:59 -0700 Subject: [PATCH 07/57] [torch.compile] Support conditional torch.compile per module (#22269) Signed-off-by: Yong Hoon Shin --- .buildkite/test-pipeline.yaml | 2 + .../compile/piecewise/test_multiple_graphs.py | 135 +++------- tests/compile/test_decorator.py | 251 ++++++++++++++++++ vllm/compilation/decorators.py | 21 +- 4 files changed, 307 insertions(+), 102 deletions(-) create mode 100644 tests/compile/test_decorator.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 2f7f1db75bfb9..745420664010a 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -328,6 +328,7 @@ steps: - pytest -v -s compile/test_sequence_parallelism.py - pytest -v -s compile/test_async_tp.py - pytest -v -s compile/test_fusion_all_reduce.py + - pytest -v -s compile/test_decorator.py - label: PyTorch Fullgraph Smoke Test # 9min mirror_hardwares: [amdexperimental] @@ -341,6 +342,7 @@ steps: - pytest -v -s compile/piecewise/test_simple.py - pytest -v -s compile/piecewise/test_toy_llama.py - pytest -v -s compile/piecewise/test_full_cudagraph.py + - pytest -v -s compile/piecewise/test_multiple_graphs.py - label: PyTorch Fullgraph Test # 18min mirror_hardwares: [amdexperimental] diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py index e460d70951786..f5e2d9ddb7528 100644 --- a/tests/compile/piecewise/test_multiple_graphs.py +++ b/tests/compile/piecewise/test_multiple_graphs.py @@ -12,10 +12,9 @@ from vllm.compilation.backends import set_model_tag from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import (ignore_torch_compile, support_torch_compile) -from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, - set_current_vllm_config) -from vllm.envs import VLLM_USE_V1 -from vllm.forward_context import set_forward_context +from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, + VllmConfig, set_current_vllm_config) +from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.utils import direct_register_custom_op # create a library to hold the custom op @@ -164,104 +163,34 @@ class SimpleModelWithTwoGraphs(ParentModel): return x -def test_ignore_torch_compile_decorator(): - assert VLLM_USE_V1 - - # piecewise - vllm_config = VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - use_cudagraph=True, - splitting_ops=["silly.attention"], - cudagraph_capture_sizes=[1, 2], - )) - - @support_torch_compile - class A(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = '', - **kwargs) -> None: - super().__init__() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + x - attn_output = torch.empty_like(x) - torch.ops.silly.attention(x, x, x, attn_output) - x = attn_output - x = x * 3 - return x - - @ignore_torch_compile - class B(A): - ... - - @support_torch_compile - class C(B): - ... - - with set_current_vllm_config(vllm_config): - mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() - - # A has support_torch_compile - with compilation_counter.expect( - num_graphs_seen=1, - num_piecewise_graphs_seen=3, - num_piecewise_capturable_graphs_seen=2, - num_backend_compilations=2, - num_cudagraph_captured=4, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen - ), set_forward_context({}, vllm_config=vllm_config): - # first run is for compile - mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) - # run cudagraph captured sizes - mod_A(torch.randn(2, MLP_SIZE).cuda()) - mod_A(torch.randn(1, MLP_SIZE).cuda()) - - with set_current_vllm_config(vllm_config): - mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda() - - # B's ignore_torch_compile should override A's support_torch_compile - with compilation_counter.expect( - num_graphs_seen=0, - num_piecewise_graphs_seen=0, - num_piecewise_capturable_graphs_seen=0, - num_backend_compilations=0, - num_cudagraph_captured=0, - ), set_forward_context({}, vllm_config=vllm_config): - mod_B(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) - mod_B(torch.randn(2, MLP_SIZE).cuda()) - mod_B(torch.randn(1, MLP_SIZE).cuda()) - - with set_current_vllm_config(vllm_config): - mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda() - - # C's support_torch_compile should override B's ignore_torch_compile - with compilation_counter.expect( - num_graphs_seen=1, - num_piecewise_graphs_seen=3, - num_piecewise_capturable_graphs_seen=2, - num_backend_compilations=2, - num_cudagraph_captured=4, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen - ), set_forward_context({}, vllm_config=vllm_config): - mod_C(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) - mod_C(torch.randn(2, MLP_SIZE).cuda()) - mod_C(torch.randn(1, MLP_SIZE).cuda()) - - @torch.inference_mode -def run_model(vllm_config, model: nn.Module, inputs: torch.Tensor): +def run_model(vllm_config: VllmConfig, model: nn.Module, inputs: torch.Tensor, + cudagraph_runtime_mode: CUDAGraphMode): with set_forward_context({}, vllm_config=vllm_config): - # First run is for compile + # warmup for the model with cudagraph_mode NONE model(inputs) - # Run CUDAGraph captured sizes - model(inputs[:2]) - model(inputs[:1]) + # simulate cudagraphs capturing + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, )): + model(inputs[:2]) + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=1, )): + model(inputs[:1]) - output = model(inputs[:2]) + # simulate cudagraphs replay + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, )): + output = model(inputs[:2]) output = output.cpu() return output.cpu() @@ -277,6 +206,7 @@ def test_multi_graph_piecewise_compile_outputs_equal(): splitting_ops=["silly.attention"], cudagraph_capture_sizes=[1, 2], )) + cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE with set_current_vllm_config(vllm_config): model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, @@ -299,11 +229,13 @@ def test_multi_graph_piecewise_compile_outputs_equal(): num_cudagraph_captured=8, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): - outputs.append(run_model(vllm_config, model, inputs)) + outputs.append( + run_model(vllm_config, model, inputs, cudagraph_runtime_mode)) # no compile or cudagraph vllm_config = VllmConfig(compilation_config=CompilationConfig( level=CompilationLevel.NO_COMPILATION, )) + cudagraph_runtime_mode = CUDAGraphMode.NONE with set_current_vllm_config(vllm_config): model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, @@ -318,7 +250,8 @@ def test_multi_graph_piecewise_compile_outputs_equal(): num_backend_compilations=0, num_cudagraph_captured=0, ): - outputs.append(run_model(vllm_config, model, inputs)) + outputs.append( + run_model(vllm_config, model, inputs, cudagraph_runtime_mode)) # piecewise compile without CUDA graph vllm_config = VllmConfig(compilation_config=CompilationConfig( @@ -326,6 +259,7 @@ def test_multi_graph_piecewise_compile_outputs_equal(): use_cudagraph=False, splitting_ops=["silly.attention"], )) + cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE with set_current_vllm_config(vllm_config): model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, @@ -340,7 +274,8 @@ def test_multi_graph_piecewise_compile_outputs_equal(): num_backend_compilations=4, num_cudagraph_captured=0, # no cudagraph captured ): - outputs.append(run_model(vllm_config, model, inputs)) + outputs.append( + run_model(vllm_config, model, inputs, cudagraph_runtime_mode)) # Generally don't expect outputs with and without inductor # to be bitwise equivalent diff --git a/tests/compile/test_decorator.py b/tests/compile/test_decorator.py new file mode 100644 index 0000000000000..51f8ddd566d56 --- /dev/null +++ b/tests/compile/test_decorator.py @@ -0,0 +1,251 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +from torch import nn +from torch.library import Library + +from vllm.compilation.counter import compilation_counter +from vllm.compilation.decorators import (ignore_torch_compile, + support_torch_compile) +from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, + CUDAGraphMode, VllmConfig, set_current_vllm_config) +from vllm.forward_context import BatchDescriptor, set_forward_context +from vllm.utils import direct_register_custom_op + +# create a library to hold the custom op +silly_lib = Library("silly", "FRAGMENT") # noqa + +BATCH_SIZE = 32 +MLP_SIZE = 128 + + +def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + out.copy_(q) + out += k + out += v + + +def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + return + + +direct_register_custom_op( + op_name="attention", + op_func=silly_attention, + mutates_args=["out"], + fake_impl=silly_attention_fake, + target_lib=silly_lib, +) + + +@torch.inference_mode +def run_model(vllm_config: VllmConfig, model: nn.Module, + cudagraph_runtime_mode: CUDAGraphMode): + with set_forward_context({}, vllm_config=vllm_config): + # warmup for the model with cudagraph_mode NONE + model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) + + # simulate cudagraphs capturing + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, )): + model(torch.randn(2, MLP_SIZE).cuda()) + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=1, )): + model(torch.randn(1, MLP_SIZE).cuda()) + + # simulate cudagraphs replay + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, )): + output = model(torch.randn(2, MLP_SIZE).cuda()) + + output = output.cpu() + return output.cpu() + + +def test_ignore_torch_compile_decorator(): + # piecewise + vllm_config = VllmConfig(compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + splitting_ops=["silly.attention"], + cudagraph_capture_sizes=[1, 2], + )) + cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE + + @support_torch_compile + class A(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = '', + **kwargs) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + x + attn_output = torch.empty_like(x) + torch.ops.silly.attention(x, x, x, attn_output) + x = attn_output + x = x * 3 + return x + + @ignore_torch_compile + class B(A): + ... + + @support_torch_compile + class C(B): + ... + + with set_current_vllm_config(vllm_config): + mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() + + # A has support_torch_compile + with compilation_counter.expect( + num_graphs_seen=1, + num_piecewise_graphs_seen=3, + num_piecewise_capturable_graphs_seen=2, + num_backend_compilations=2, + num_cudagraph_captured=4, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + run_model(vllm_config, mod_A, cudagraph_runtime_mode) + + with set_current_vllm_config(vllm_config): + mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda() + + # B's ignore_torch_compile should override A's support_torch_compile + with compilation_counter.expect( + num_graphs_seen=0, + num_piecewise_graphs_seen=0, + num_piecewise_capturable_graphs_seen=0, + num_backend_compilations=0, + num_cudagraph_captured=0, + ): + run_model(vllm_config, mod_B, cudagraph_runtime_mode) + + with set_current_vllm_config(vllm_config): + mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda() + + # C's support_torch_compile should override B's ignore_torch_compile + with compilation_counter.expect( + num_graphs_seen=1, + num_piecewise_graphs_seen=3, + num_piecewise_capturable_graphs_seen=2, + num_backend_compilations=2, + num_cudagraph_captured=4, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + run_model(vllm_config, mod_C, cudagraph_runtime_mode) + + +#ย Only enable torch.compile if +# vllm_config.cache_config.kv_sharing_fast_prefill=True +@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config. + kv_sharing_fast_prefill) +class B(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = '', + **kwargs) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + x + attn_output = torch.empty_like(x) + torch.ops.silly.attention(x, x, x, attn_output) + x = attn_output + x = x + x + return x + + +#ย Only enable torch.compile if +# vllm_config.cache_config.kv_sharing_fast_prefill=False +@support_torch_compile(enable_if=lambda vllm_config: not vllm_config. + cache_config.kv_sharing_fast_prefill) +class A(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = '', + **kwargs) -> None: + super().__init__() + self.mod1 = B(vllm_config=vllm_config, prefix=prefix, **kwargs) + self.mod2 = B(vllm_config=vllm_config, prefix=prefix, **kwargs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.mod1(x) + attn_output = torch.empty_like(x) + torch.ops.silly.attention(x, x, x, attn_output) + x = attn_output + x = self.mod2(x) + return x + + +def test_conditional_compile_enable_if(): + vllm_config = VllmConfig(cache_config=CacheConfig( + kv_sharing_fast_prefill=True, ), + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + splitting_ops=["silly.attention"], + cudagraph_capture_sizes=[1, 2], + )) + cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE + + with set_current_vllm_config(vllm_config): + mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() + + # A has support_torch_compile but enable_if fn returns False + # enalbe_if will be True for B, so we expect mod1 and mod2 + # to be compiled + with compilation_counter.expect( + num_graphs_seen=2, + num_piecewise_graphs_seen=6, + # 3 piecewise graphs per instance of B() + num_piecewise_capturable_graphs_seen=4, + num_backend_compilations=4, + num_cudagraph_captured=8, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + run_model(vllm_config, mod_A, cudagraph_runtime_mode) + + # Set kv_sharing_fast_prefill=False + # which will cause A to be compiled and B to not be compiled + vllm_config = VllmConfig(cache_config=CacheConfig( + kv_sharing_fast_prefill=False, ), + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + splitting_ops=["silly.attention"], + cudagraph_capture_sizes=[1, 2], + )) + + with set_current_vllm_config(vllm_config): + mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() + + with compilation_counter.expect( + num_graphs_seen=1, + num_piecewise_graphs_seen=7, + # 3 attn ops and 4 non-attn ops + num_piecewise_capturable_graphs_seen=4, + num_backend_compilations=4, + num_cudagraph_captured=8, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + run_model(vllm_config, mod_A, cudagraph_runtime_mode) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 58f70ef9ef0aa..41d9fcb824b01 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -52,6 +52,14 @@ def _should_ignore_torch_compile(cls) -> bool: return getattr(cls, IGNORE_COMPILE_KEY, False) +@overload +def support_torch_compile( + *, + enable_if: Optional[Callable[[VllmConfig], bool]] = None, +) -> Callable[[_T], _T]: + ... + + @overload def support_torch_compile( *, @@ -69,6 +77,7 @@ def support_torch_compile( cls: Optional[_T] = None, *, dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None, + enable_if: Optional[Callable[[VllmConfig], bool]] = None, ) -> Union[Callable[[_T], _T], _T]: """ A decorator to add support for compiling the forward method of a class. @@ -118,6 +127,11 @@ def support_torch_compile( NOTE: if an argument is `None`, it should always be passed as `None` during the lifetime of the model, otherwise, it cannot be captured as a single computation graph. + + `enable_if` is a function that takes a `VllmConfig` object as input and + returns a boolean value indicating whether to compile the model or not. + This is useful if you want to compile the model only when certain + conditions are met. """ def cls_decorator_helper(cls: _T) -> _T: @@ -149,7 +163,8 @@ def support_torch_compile( if k not in sig.parameters: raise ValueError( f"Argument {k} not found in the forward method of {cls}") - return _support_torch_compile(cls, inferred_dynamic_arg_dims) + return _support_torch_compile(cls, inferred_dynamic_arg_dims, + enable_if) if cls is not None: # use `support_torch_compile` as a decorator without arguments @@ -162,6 +177,7 @@ def support_torch_compile( def _support_torch_compile( cls: _T, dynamic_arg_dims: dict[str, Union[int, list[int]]], + enable_if: Optional[Callable[[VllmConfig], bool]] = None, ) -> _T: """ A decorator to add support for compiling the forward method of a class. @@ -182,13 +198,14 @@ def _support_torch_compile( def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) self.vllm_config = vllm_config + enable_compile = enable_if is None or enable_if(vllm_config) # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner # will handle the compilation, so we don't need to do anything here. self.do_not_compile = \ vllm_config.compilation_config.level in [ CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS ] or not supports_dynamo() or _should_ignore_torch_compile( - self.__class__) + self.__class__) or not enable_compile if self.do_not_compile: return From c4477f55e581e5ef5f52bbe39cba6e0de1956444 Mon Sep 17 00:00:00 2001 From: Benji Beck Date: Wed, 20 Aug 2025 10:37:29 -0700 Subject: [PATCH 08/57] Migrate Mistral3ImagePixelInputs to TensorSchema (#21945) Signed-off-by: Benji Beck Co-authored-by: Cyrus Leung --- vllm/model_executor/models/mistral3.py | 38 ++++++++++++-------------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index a647292d3a68b..438513433d3b2 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -3,7 +3,7 @@ from abc import abstractmethod from collections.abc import Iterable, Mapping, Sequence -from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar, +from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar, Union) import torch @@ -32,6 +32,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) @@ -42,15 +43,23 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .vision import get_vision_encoder_info -class Mistral3ImagePixelInputs(TypedDict): - type: Literal["pixel_values_pixtral"] - pixel_values: Union[torch.Tensor, list[torch.Tensor]] +class Mistral3ImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height of each image + - w: Width of each image """ - Shape: `(batch_size * num_images, num_channels, height, width)` - Note that `height` or `width` may be different per batch and image, - in which case the data is passed as a list instead of a batched tensor. - """ + type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral" + + # Note that `height` or `width` may be different per batch and image, + # in which case the data is passed as a list instead of a batched tensor. + pixel_values: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"}), + ] class Mistral3PatchMerger(nn.Module): @@ -456,19 +465,6 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - actual_dims = tuple(data.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. " - f"You supplied {tuple(data.shape)}.") - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[Mistral3ImagePixelInputs]: pixel_values = kwargs.pop("pixel_values", None) From f77a0802b758a32c5b9f7bc04e9498d77e8d99e0 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 20 Aug 2025 13:57:37 -0400 Subject: [PATCH 09/57] Limit HTTP header count and size (#23267) Signed-off-by: Taneem Ibrahim Signed-off-by: Russell Bryant Co-authored-by: Taneem Ibrahim --- vllm/entrypoints/constants.py | 10 ++++++++++ vllm/entrypoints/launcher.py | 21 +++++++++++++++++++++ vllm/entrypoints/openai/api_server.py | 2 ++ vllm/entrypoints/openai/cli_args.py | 8 ++++++++ 4 files changed, 41 insertions(+) create mode 100644 vllm/entrypoints/constants.py diff --git a/vllm/entrypoints/constants.py b/vllm/entrypoints/constants.py new file mode 100644 index 0000000000000..b5bcccc35d6c8 --- /dev/null +++ b/vllm/entrypoints/constants.py @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Shared constants for vLLM entrypoints. +""" + +# HTTP header limits for h11 parser +# These constants help mitigate header abuse attacks +H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT = 4194304 # 4 MB +H11_MAX_HEADER_COUNT_DEFAULT = 256 diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index 9f4dc19fb4ab7..4e852ba594930 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -14,6 +14,8 @@ from vllm import envs from vllm.engine.async_llm_engine import AsyncEngineDeadError from vllm.engine.multiprocessing import MQEngineDeadError from vllm.engine.protocol import EngineClient +from vllm.entrypoints.constants import (H11_MAX_HEADER_COUNT_DEFAULT, + H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT) from vllm.entrypoints.ssl import SSLCertRefresher from vllm.logger import init_logger from vllm.utils import find_process_using_port @@ -26,6 +28,11 @@ async def serve_http(app: FastAPI, sock: Optional[socket.socket], enable_ssl_refresh: bool = False, **uvicorn_kwargs: Any): + """ + Start a FastAPI app using Uvicorn, with support for custom Uvicorn config + options. Supports http header limits via h11_max_incomplete_event_size and + h11_max_header_count. + """ logger.info("Available routes are:") for route in app.routes: methods = getattr(route, "methods", None) @@ -36,7 +43,21 @@ async def serve_http(app: FastAPI, logger.info("Route: %s, Methods: %s", path, ', '.join(methods)) + # Extract header limit options if present + h11_max_incomplete_event_size = uvicorn_kwargs.pop( + "h11_max_incomplete_event_size", None) + h11_max_header_count = uvicorn_kwargs.pop("h11_max_header_count", None) + + # Set safe defaults if not provided + if h11_max_incomplete_event_size is None: + h11_max_incomplete_event_size = H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT + if h11_max_header_count is None: + h11_max_header_count = H11_MAX_HEADER_COUNT_DEFAULT + config = uvicorn.Config(app, **uvicorn_kwargs) + # Set header limits + config.h11_max_incomplete_event_size = h11_max_incomplete_event_size + config.h11_max_header_count = h11_max_header_count config.load() server = uvicorn.Server(config) _add_shutdown_handlers(app, server) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 765327da3b306..24148bcef2353 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1922,6 +1922,8 @@ async def run_server_worker(listen_address, ssl_certfile=args.ssl_certfile, ssl_ca_certs=args.ssl_ca_certs, ssl_cert_reqs=args.ssl_cert_reqs, + h11_max_incomplete_event_size=args.h11_max_incomplete_event_size, + h11_max_header_count=args.h11_max_header_count, **uvicorn_kwargs, ) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index e15f65b43082c..6e4eff5c80243 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -20,6 +20,8 @@ from vllm.config import config from vllm.engine.arg_utils import AsyncEngineArgs, optional_type from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, validate_chat_template) +from vllm.entrypoints.constants import (H11_MAX_HEADER_COUNT_DEFAULT, + H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT) from vllm.entrypoints.openai.serving_models import LoRAModulePath from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.logger import init_logger @@ -172,6 +174,12 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`""" enable_log_outputs: bool = False """If set to True, enable logging of model outputs (generations) in addition to the input logging that is enabled by default.""" + h11_max_incomplete_event_size: int = H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT + """Maximum size (bytes) of an incomplete HTTP event (header or body) for + h11 parser. Helps mitigate header abuse. Default: 4194304 (4 MB).""" + h11_max_header_count: int = H11_MAX_HEADER_COUNT_DEFAULT + """Maximum number of HTTP headers allowed in a request for h11 parser. + Helps mitigate header abuse. Default: 256.""" @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: From ebe56a0064f7a72a5c51d4cd6bcca165590c5bed Mon Sep 17 00:00:00 2001 From: dongluw <108290936+dongluw@users.noreply.github.com> Date: Wed, 20 Aug 2025 14:15:18 -0400 Subject: [PATCH 10/57] Small fix for Command-A-Vision (#23268) Signed-off-by: donglu --- vllm/model_executor/models/cohere2_vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/cohere2_vision.py b/vllm/model_executor/models/cohere2_vision.py index 4682a8a428a03..fca1aee835b89 100644 --- a/vllm/model_executor/models/cohere2_vision.py +++ b/vllm/model_executor/models/cohere2_vision.py @@ -348,7 +348,7 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, vllm_config=vllm_config, hf_config=config.text_config, prefix=maybe_prefix(prefix, "language_model"), - architectures=["Cohere2ForCausalLM"]) + architectures=config.text_config.architectures) @property def dtype(self): From 0cdbf5e61ce3fd97d33b31b775d2faaadc99fbc5 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Wed, 20 Aug 2025 15:13:36 -0400 Subject: [PATCH 11/57] [Kernel/Quant] Remove the original marlin format and qqq (#23204) Signed-off-by: mgoin --- .../configs/Meta-Llama-3-8B-QQQ.yaml | 12 - .../lm-eval-harness/configs/models-large.txt | 1 - CMakeLists.txt | 2 - benchmarks/kernels/benchmark_machete.py | 23 +- csrc/quantization/machete/generate.py | 139 +- csrc/quantization/marlin/dense/LICENSE | 209 --- csrc/quantization/marlin/dense/common/base.h | 32 - csrc/quantization/marlin/dense/common/mem.h | 89 -- .../marlin/dense/marlin_cuda_kernel.cu | 1073 -------------- .../marlin/qqq/marlin_qqq_gemm_kernel.cu | 1248 ----------------- csrc/torch_bindings.cpp | 17 - tests/compile/test_full_graph.py | 6 - tests/kernels/quantization/test_machete_mm.py | 34 +- .../kernels/quantization/test_marlin_gemm.py | 83 -- tests/quantization/test_configs.py | 10 - tests/quantization/test_lm_head.py | 6 +- tests/weight_loading/models.txt | 4 - vllm/_custom_ops.py | 36 - vllm/config/__init__.py | 7 +- vllm/lora/layers.py | 3 - vllm/model_executor/layers/linear.py | 1 - .../layers/quantization/__init__.py | 6 - .../layers/quantization/marlin.py | 263 ---- .../model_executor/layers/quantization/qqq.py | 275 ---- .../utils/marlin_utils_test_qqq.py | 126 -- .../layers/quantization/utils/quant_utils.py | 85 -- 26 files changed, 92 insertions(+), 3698 deletions(-) delete mode 100644 .buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml delete mode 100644 csrc/quantization/marlin/dense/LICENSE delete mode 100644 csrc/quantization/marlin/dense/common/base.h delete mode 100644 csrc/quantization/marlin/dense/common/mem.h delete mode 100644 csrc/quantization/marlin/dense/marlin_cuda_kernel.cu delete mode 100644 csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu delete mode 100644 vllm/model_executor/layers/quantization/marlin.py delete mode 100644 vllm/model_executor/layers/quantization/qqq.py delete mode 100644 vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml deleted file mode 100644 index 56ec933c9cc0e..0000000000000 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml +++ /dev/null @@ -1,12 +0,0 @@ -# For vllm script, with -t option (tensor parallel size). -# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1 -model_name: "HandH1998/QQQ-Llama-3-8b-g128" -tasks: -- name: "gsm8k" - metrics: - - name: "exact_match,strict-match" - value: 0.419 - - name: "exact_match,flexible-extract" - value: 0.416 -limit: 1000 -num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/models-large.txt b/.buildkite/lm-eval-harness/configs/models-large.txt index 27a1a9a82bd35..37eeac85c933b 100644 --- a/.buildkite/lm-eval-harness/configs/models-large.txt +++ b/.buildkite/lm-eval-harness/configs/models-large.txt @@ -3,4 +3,3 @@ Meta-Llama-3-70B-Instruct.yaml Mixtral-8x7B-Instruct-v0.1.yaml Qwen2-57B-A14-Instruct.yaml DeepSeek-V2-Lite-Chat.yaml -Meta-Llama-3-8B-QQQ.yaml diff --git a/CMakeLists.txt b/CMakeLists.txt index bcbd1b52a06c6..a1deefb07f09c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -357,9 +357,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC}) set(MARLIN_SRCS - "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" - "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" "csrc/quantization/gptq_marlin/awq_marlin_repack.cu") diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index 975d10f2e92ec..a9c4d30d9b189 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -253,28 +253,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: else: assert bt.a.dtype == torch.int8 assert bt.wtype == scalar_types.uint4b8 - - if bt.w_ch_s is not None: - s_ch = bt.w_ch_s.to(torch.float32) - else: - s_ch = torch.ones(bt.w_ref.shape[1], dtype=torch.float32, device=device) - - if bt.w_tok_s is not None: - s_tok = bt.w_tok_s.to(torch.float32) - else: - s_tok = torch.ones(bt.a.shape[0], dtype=torch.float32, device=device) - - fn = lambda: ops.marlin_qqq_gemm( - a=bt.a, - b_q_weight=w_q, - s_group=w_s, - s_tok=s_tok, - s_ch=s_ch, - workspace=workspace.scratch, - size_m=bt.a.shape[0], - size_n=bt.w_ref.shape[1], - size_k=bt.w_ref.shape[0], - ) + raise NotImplementedError("QQQ is not supported anymore") return fn diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index 88b3f9c734a30..0d14ba15937c6 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -571,78 +571,79 @@ def generate(): itertools.repeat(default_heuristic)) ] - # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk)) - # TODO (LucasWilkinson): Further tuning required - qqq_tile_heuristic_config = { - #### M = 257+ - # ((128, 256), (2, 1, 1)) Broken for QQQ types - # TODO (LucasWilkinson): Investigate further - # "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)), - # "M > 256": ((128, 256), (2, 1, 1)), - "M > 256": ((128, 128), (2, 1, 1)), - #### M = 129-256 - "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)), - "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)), - # ((128, 256), (2, 1, 1)) Broken for QQQ types - # TODO (LucasWilkinson): Investigate further - # "M > 128": ((128, 256), (2, 1, 1)), - "M > 128": ((128, 128), (2, 1, 1)), - #### M = 65-128 - "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)), - "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)), - "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)), - "M > 64": ((128, 128), (2, 1, 1)), - #### M = 33-64 - "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)), - # Broken for QQQ types - # TODO (LucasWilkinson): Investigate further - #"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)), - "M > 32": ((128, 64), (2, 1, 1)), - #### M = 17-32 - "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)), - "M > 16": ((256, 32), (2, 1, 1)), - #### M = 1-16 - "N >= 26624": ((256, 16), (1, 1, 1)), - None: ((128, 16), (1, 1, 1)), - } + # TODO: Support W4A8 when ready + # # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk)) + # # TODO (LucasWilkinson): Further tuning required + # qqq_tile_heuristic_config = { + # #### M = 257+ + # # ((128, 256), (2, 1, 1)) Broken for QQQ types + # # TODO (LucasWilkinson): Investigate further + # # "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)), + # # "M > 256": ((128, 256), (2, 1, 1)), + # "M > 256": ((128, 128), (2, 1, 1)), + # #### M = 129-256 + # "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)), + # "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)), + # # ((128, 256), (2, 1, 1)) Broken for QQQ types + # # TODO (LucasWilkinson): Investigate further + # # "M > 128": ((128, 256), (2, 1, 1)), + # "M > 128": ((128, 128), (2, 1, 1)), + # #### M = 65-128 + # "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)), + # "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)), + # "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)), + # "M > 64": ((128, 128), (2, 1, 1)), + # #### M = 33-64 + # "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)), + # # Broken for QQQ types + # # TODO (LucasWilkinson): Investigate further + # #"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)), + # "M > 32": ((128, 64), (2, 1, 1)), + # #### M = 17-32 + # "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)), + # "M > 16": ((256, 32), (2, 1, 1)), + # #### M = 1-16 + # "N >= 26624": ((256, 16), (1, 1, 1)), + # None: ((128, 16), (1, 1, 1)), + # } - # For now we use the same heuristic for all types - # Heuristic is currently tuned for H100s - qqq_heuristic = [ - (cond, ScheduleConfig(*tile_config, - **sch_common_params)) # type: ignore - for cond, tile_config in qqq_tile_heuristic_config.items() - ] + # # For now we use the same heuristic for all types + # # Heuristic is currently tuned for H100s + # qqq_heuristic = [ + # (cond, ScheduleConfig(*tile_config, + # **sch_common_params)) # type: ignore + # for cond, tile_config in qqq_tile_heuristic_config.items() + # ] - QQQ_kernel_types = [ - *(TypeConfig( - a=DataType.s8, - b=VLLMDataType.u4b8, - b_group_scale=b_group_scale, - b_group_zeropoint=DataType.void, - b_channel_scale=DataType.f32, - a_token_scale=DataType.f32, - out=DataType.f16, - accumulator=DataType.s32, - ) for b_group_scale in (DataType.f16, DataType.void)), - *(TypeConfig( - a=DataType.e4m3, - b=VLLMDataType.u4b8, - b_group_scale=b_group_scale, - b_group_zeropoint=DataType.void, - b_channel_scale=DataType.f32, - a_token_scale=DataType.f32, - out=DataType.f16, - accumulator=DataType.f32, - ) for b_group_scale in (DataType.f16, DataType.void)), - ] + # QQQ_kernel_types = [ + # *(TypeConfig( + # a=DataType.s8, + # b=VLLMDataType.u4b8, + # b_group_scale=b_group_scale, + # b_group_zeropoint=DataType.void, + # b_channel_scale=DataType.f32, + # a_token_scale=DataType.f32, + # out=DataType.f16, + # accumulator=DataType.s32, + # ) for b_group_scale in (DataType.f16, DataType.void)), + # *(TypeConfig( + # a=DataType.e4m3, + # b=VLLMDataType.u4b8, + # b_group_scale=b_group_scale, + # b_group_zeropoint=DataType.void, + # b_channel_scale=DataType.f32, + # a_token_scale=DataType.f32, + # out=DataType.f16, + # accumulator=DataType.f32, + # ) for b_group_scale in (DataType.f16, DataType.void)), + # ] - impl_configs += [ - ImplConfig(x[0], x[1], x[2]) - for x in zip(QQQ_kernel_types, - itertools.repeat(get_unique_schedules(qqq_heuristic)), - itertools.repeat(qqq_heuristic)) - ] + # impl_configs += [ + # ImplConfig(x[0], x[1], x[2]) + # for x in zip(QQQ_kernel_types, + # itertools.repeat(get_unique_schedules(qqq_heuristic)), + # itertools.repeat(qqq_heuristic)) + # ] output_dir = os.path.join(SCRIPT_DIR, "generated") diff --git a/csrc/quantization/marlin/dense/LICENSE b/csrc/quantization/marlin/dense/LICENSE deleted file mode 100644 index 1d1e4cf9c8233..0000000000000 --- a/csrc/quantization/marlin/dense/LICENSE +++ /dev/null @@ -1,209 +0,0 @@ -Contains code from https://github.com/IST-DASLab/marlin - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "{}" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright {yyyy} {name of copyright owner} - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ------------------------------------------------------------------------------------- - -This product bundles various third-party components under other open source licenses. -This section summarizes those components and their licenses. See licenses/ -for text of these licenses. diff --git a/csrc/quantization/marlin/dense/common/base.h b/csrc/quantization/marlin/dense/common/base.h deleted file mode 100644 index 68c83d5478cf8..0000000000000 --- a/csrc/quantization/marlin/dense/common/base.h +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Modified by HandH1998 - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } - -// Instances of `Vec` are used to organize groups of >>registers<<, as needed -// for instance as inputs to tensor core operations. Consequently, all -// corresponding index accesses must be compile-time constants, which is why we -// extensively use `#pragma unroll` throughout the kernel code to guarantee -// this. -template -struct Vec { - T elems[n]; - __device__ T& operator[](int i) { return elems[i]; } -}; diff --git a/csrc/quantization/marlin/dense/common/mem.h b/csrc/quantization/marlin/dense/common/mem.h deleted file mode 100644 index 64f9c393d77ce..0000000000000 --- a/csrc/quantization/marlin/dense/common/mem.h +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Modified by HandH1998 - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -// Predicated asynchronous global->shared copy; used for inputs A where we apply -// predication to handle batchsizes that are not multiples of 16. -__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, - bool pred = true) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); -} - -// Asynchronous global->shared copy -__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -// Async copy fence. -__device__ inline void cp_async_fence() { - asm volatile("cp.async.commit_group;\n" ::); -} - -// Wait until at most `n` async copy stages are still pending. -template -__device__ inline void cp_async_wait() { - asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} diff --git a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu deleted file mode 100644 index ea96326ed7e61..0000000000000 --- a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu +++ /dev/null @@ -1,1073 +0,0 @@ -/* - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include -#include -#include -#include - -#include - -#include "common/base.h" -#include "core/registration.h" - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - #include "common/mem.h" -#endif - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace marlin_dense { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - -using I4 = Vec; -// Matrix fragments for tensor core instructions; their precise layout is -// documented here: -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type -using FragA = Vec; -using FragB = Vec; -using FragC = Vec; -using FragS = Vec; // quantization scales - -// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. -__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, - FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 -// values. We mostly follow the strategy in the link below, with some small -// changes: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. -__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape - // (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts in - // the middle of group. - if (group_blocks != -1) - iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - } - - // Compute all information about the current slice which is required for - // synchronization. - auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * prob_k / 8; - C += 16 * thread_m_blocks * prob_n / 8; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory - // We typically use `constexpr` to indicate that this value is a compile-time - // constant - constexpr int a_sh_stride = - 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory - constexpr int a_gl_rd_delta_o = - 16 * thread_k_blocks / - 8; // delta between subsequent A tiles in global memory - int a_gl_rd_delta_i = - a_gl_stride * - (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile - constexpr int a_sh_wr_delta = - a_sh_stride * - (threads / a_gl_rd_delta_o); // between shared memory writes - constexpr int a_sh_rd_delta_o = - 2 * ((threads / 32) / - (thread_n_blocks / 4)); // between shared memory tile reads - constexpr int a_sh_rd_delta_i = - a_sh_stride * 16; // within a shared memory tile - constexpr int a_sh_stage = - a_sh_stride * (16 * thread_m_blocks); // overall size of a tile - constexpr int a_sh_wr_iters = - ceildiv(a_sh_stage, - a_sh_wr_delta); // number of shared write iterations for a tile - - int b_gl_stride = 16 * prob_n / 32; - constexpr int b_sh_stride = 32 * thread_n_blocks / 4; - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); - constexpr int b_sh_wr_delta = threads; - constexpr int b_sh_rd_delta = threads; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_sh_stage = s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = - b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - auto b_sh_wr = threadIdx.x; - auto b_sh_rd = threadIdx.x; - - int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - auto s_sh_wr = threadIdx.x; - int s_sh_rd; - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. - if (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_s = sh_b + (stages * b_sh_stage); - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); - B_ptr[i] += b_gl_rd_delta_o; - } - // Only fetch scales if this tile starts a new group - if constexpr (group_blocks != -1) { - // This assumes group_blocks >= thread_k_blocks - // and would need to be modified to support smaller groups. - static_assert(group_blocks >= thread_k_blocks); - if (pipe % (group_blocks / thread_k_blocks) == 0) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); - s_gl_rd += s_gl_rd_delta; - } - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - // It may seem inefficient that we reload the groups for every sub-tile; - // however, this does not seem to be a significant bottleneck, while some - // theoretically better attempts have lead to bad instruction ordering by - // the compiler and correspondingly a noticeable drop in performance. - if constexpr (group_blocks != -1) { - // This assumes group_blocks >= thread_k_blocks - // and would need to be modified to support smaller groups. - static_assert(group_blocks >= thread_k_blocks); - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - frag_b_quant[k % 2] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - int b_quant = frag_b_quant[k % 2][j]; - int b_quant_shift = b_quant >> 8; - FragB frag_b0 = dequant(b_quant); - // If there are no groups, we can just scale the final output once and can - // avoid doing so for each weight. - if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0); - FragB frag_b1 = dequant(b_quant_shift); - if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1); - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride / 2; - if (red_off >= 1) { - auto red_idx = threadIdx.x / b_sh_stride; - constexpr int red_sh_stride = b_sh_stride * 4 * 2; - constexpr int red_sh_delta = b_sh_stride; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + - (threadIdx.x % b_sh_stride); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - auto c_sh_wr = threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - __half2float(reinterpret_cast<__half*>(&c_red)[j]); - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast<__half*>(&c)[j] = - __float2half(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); - } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = - c; - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. - auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = - c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = - (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { - half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - if (group_blocks == - -1) // for per-column quantization we finally apply the scale here - res = __hmul2(res, s[0]); - ((half2*)sh)[idx] = res; - }; - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - c_sh_wr += 16 * (4 * c_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - C[c_gl_wr] = sh[c_sh_rd]; - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - #pragma unroll - for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); - zero_accums(); - wait_for_stage(); - fetch_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - }; - start_pipes(); - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines have - // even length meaning that the next iteration will always start at index 0. - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) break; - } - a_gl_rd += a_gl_rd_delta_o * stages; - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if (group_blocks == -1 && last) { - if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); - cp_async_fence(); - } - thread_block_reduce(); - if (group_blocks == -1 && last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - start_pipes(); - } - } - } -} - -#else - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape - // (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - -#endif - -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. -const int USER_THREADS = - 256; // Note: This is only used with user-provided thread_k/n -const int STAGES = 4; // 4 pipeline stages fit into shared memory -const int SHARED_MEM = - 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) - -static constexpr int min_thread_n = 64; -static constexpr int min_thread_k = 64; - -static constexpr int tile_size = 16; -static constexpr int max_par = 16; - -static constexpr int pack_factor_4bit = - 8; // We have 8 4-bit vals inside a 32 bit - -#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - GROUP_BLOCKS, NUM_THREADS) \ - else if (thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute(Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, \ - SHARED_MEM); \ - Marlin<<>>( \ - A_ptr, B_ptr, C_ptr, s_ptr, prob_m, prob_n, prob_k, locks); \ - } - -typedef struct { - int thread_k; - int thread_n; - int num_threads; -} thread_config_t; - -thread_config_t small_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {128, 128, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 256, 256}, // Reduce K 2X, increase N 2X - {64, 128, 128}, // Reduce K 2X, same N -}; - -thread_config_t large_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {64, 256, 256}, // Default - {128, 128, 256}, // Reduce N 2X, increase K 2X - {64, 128, 128}, // Reduce N 2X, same K - {128, 64, 128}, // Reduce N 4X, increase K 2X -}; - -bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, - int prob_k) { - // Sanity - if (th_config.thread_k == -1 || th_config.thread_n == -1 || - th_config.num_threads == -1) { - return false; - } - - // Verify K/N are divisible by thread K/N - if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { - return false; - } - - // thread_k can be only 128 or 64 (because it must be less than groupsize - // which is 128) - if (th_config.thread_k != 128 && th_config.thread_k != 64) { - return false; - } - - // Verify min for thread K/N - if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { - return false; - } - - // num_threads must be at least 128 (= 4 warps) - if (th_config.num_threads < 128) { - return false; - } - - return true; -} - -thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - } - - return thread_config_t{-1, -1, -1}; -} - -#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) - -void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m, - int prob_n, int prob_k, void* workspace, int groupsize = -1, - int dev = 0, cudaStream_t stream = 0, int thread_k = -1, - int thread_n = -1, int sms = -1, int max_par = 16) { - int tot_m = prob_m; - int tot_m_blocks = ceildiv(tot_m, 16); - int pad = 16 * tot_m_blocks - tot_m; - - if (sms == -1) - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - - // Set thread config - thread_config_t th_config; - if (thread_k != -1 && thread_n != -1) { - // User-defined config - th_config = thread_config_t{thread_k, thread_n, USER_THREADS}; - } else { - // Auto config - th_config = determine_thread_config(prob_m, prob_n, prob_k); - } - - if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) { - throw std::runtime_error( - "Invalid thread config: thread_k = " + str(th_config.thread_k) + - ", thread_n = " + str(th_config.thread_n) + - ", num_threads = " + str(th_config.num_threads) + " for MKN = [" + - str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]"); - } - - // Uncomment for debug - // std::cout << "Using thread_config: thread_k = " + str(th_config.thread_k) + - // ", thread_n = " + str(th_config.thread_n) + - // ", num_threads = " + str(th_config.num_threads) + " for - // MKN = [" + str(prob_m) + - // ", " + str(prob_k) + ", " + str(prob_n) + "]\n"; - - int num_threads = th_config.num_threads; - thread_k = th_config.thread_k; - thread_n = th_config.thread_n; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 16; - int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; - int blocks = sms; - - if (prob_m == 0 || prob_n == 0 || prob_k == 0) { - return; - } - - TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, - " is not divisible by thread_n = ", thread_n); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - if (group_blocks != -1) { - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } - - const int4* A_ptr = (const int4*)A; - const int4* B_ptr = (const int4*)B; - int4* C_ptr = (int4*)C; - const int4* s_ptr = (const int4*)s; - - int* locks = (int*)workspace; - - for (int i = 0; i < tot_m_blocks; i += 4) { - int thread_m_blocks = tot_m_blocks - i; - prob_m = tot_m - 16 * i; - int par = 1; - if (thread_m_blocks > 4) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * thread_m_blocks - pad) / 64; - if (par > max_par) par = max_par; - prob_m = 64 * par; - i += 4 * (par - 1); - thread_m_blocks = 4; - } - - // For compilation speed, we only define the kernel configurations that have - // seemed useful (in terms of performance) in our testing, however many more - // are, in principle, possible. - if (false) { - } - CALL_IF(8, 8, 256) - CALL_IF(16, 4, 256) - CALL_IF(8, 4, 128) - CALL_IF(4, 8, 128) - else { - throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + - ", " + str(prob_k) + ", " + str(prob_n) + "]" + - ", groupsize = " + str(groupsize) + - ", thread_m_blocks = " + str(thread_m_blocks) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); - } - - A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; - C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; - } -} - -} // namespace marlin_dense - -torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& workspace, - int64_t size_m, int64_t size_n, int64_t size_k) { - // Verify M - TORCH_CHECK(size_m == a.size(0), - "Shape mismatch: a.size(0) = " + str(a.size(0)) + - ", size_m = " + str(size_m)); - - // Verify K - TORCH_CHECK(size_k == a.size(1), - "Shape mismatch: a.size(1) = " + str(a.size(1)) + - ", size_k = " + str(size_k)); - TORCH_CHECK(size_k % marlin_dense::tile_size == 0, - "size_k = " + str(size_k) + " is not divisible by tile_size = " + - str(marlin_dense::tile_size)); - TORCH_CHECK((size_k / marlin_dense::tile_size) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = " + - str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + - ", tile_size = " + str(marlin_dense::tile_size)); - - // Verify N - TORCH_CHECK(b_scales.size(1) == size_n, - "b_scales.size(1) = " + str(b_scales.size(1)) + - ", size_n = " + str(size_n)); - TORCH_CHECK( - b_q_weight.size(1) % marlin_dense::tile_size == 0, - "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + - " is not divisible by tile_size = " + str(marlin_dense::tile_size)); - - int actual_size_n = (b_q_weight.size(1) / marlin_dense::tile_size) * - marlin_dense::pack_factor_4bit; - TORCH_CHECK( - size_n == actual_size_n, - "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n)); - - // Verify A device and strides - TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); - TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); - - // Verify B device and strides - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - - // Verify scales device and strides - TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); - TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); - - // Alloc C matrix - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - torch::Tensor c = torch::empty({size_m, size_n}, options); - - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; - // sms: number of SMs to use for the kernel (can usually be left as auto -1) - int sms = -1; - - // Detect groupsize - if (b_scales.size(0) != 1) { - TORCH_CHECK(size_k % b_scales.size(0) == 0, - "size_k = " + str(size_k) + - ", is not divisible by b_scales.size(0) = " + - str(b_scales.size(0))); - } - int groupsize = b_scales.size(0) == 1 ? -1 : size_k / b_scales.size(0); - - // Verify groupsize - TORCH_CHECK(groupsize == -1 || groupsize == 128, - "Unexpected groupsize = " + str(groupsize)); - - // Verify workspace size - TORCH_CHECK(size_n % marlin_dense::min_thread_n == 0, - "size_n = " + str(size_n) + - ", is not divisible by min_thread_n = " + - str(marlin_dense::min_thread_n)); - int min_workspace_size = - (size_n / marlin_dense::min_thread_n) * marlin_dense::max_par; - TORCH_CHECK(workspace.numel() >= min_workspace_size, - "workspace.numel = " + str(workspace.numel()) + - " is below min_workspace_size = " + str(min_workspace_size)); - - int dev = a.get_device(); - marlin_dense::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - b_scales.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), groupsize, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, - thread_n, sms, marlin_dense::max_par); - - return c; -} - -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("marlin_gemm", &marlin_gemm); -} diff --git a/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu b/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu deleted file mode 100644 index c96d68d9b29aa..0000000000000 --- a/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu +++ /dev/null @@ -1,1248 +0,0 @@ -/* - * Adapted from - * https://github.com/IST-DASLab/marlin/blob/master/marlin/marlin_cuda_kernel.cu - * https://github.com/IST-DASLab/marlin/blob/master/marlin/marlin_cuda.cpp - * Modified by HandH1998 - * Copyright (C) 2024 HandH1998 - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include -#include -#include -#include - -#include - -#include "../dense/common/base.h" -#include "core/registration.h" - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - #include "../dense/common/mem.h" -#endif - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - -using I4 = Vec; -// Matrix fragments for tensor core instructions; their precise layout is -// documented here: -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-integer-type -using FragA = Vec; -using FragB = Vec; -using FragC = Vec; -using FragS_GROUP = Vec; // weight per-group quantization scales -using FragS_CHANNEL = - Vec; // weight per-channel quantization scales or activaton - // per-token quantization scales - -// NOTE(HandH1998): cp.async.cg only support BYTES = 16, however, -// cp.async.ca can support BYTES = 4, 8, 16; -// as s_tok's shape is equal to prob_m, we need set s_tok to float type, -// and cp_size = 1 float, i.e., 4 BYTES -// Asynchronous global->shared copy for activation quantizaton scales s_tok -__device__ inline void cp_async1(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 4; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.ca.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -// m16n8k16 tensor core mma instruction with int8 inputs and int32 -// output/accumulation. -__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, - FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - int* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.satfinite.s32.s8.s8.s32 " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" - : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]), - "r"(c[3])); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in int8 tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" - : "=r"(a[0]), "=r"(a[1]) - : "r"(smem)); -} - -inline __device__ half2 float2_to_half2(float2 f) { - uint32_t res; - // NOTE(HandH1998): h0,h1 should be uint16_t, not half - uint16_t h0, h1; - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h0) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h1) : "f"(f.y)); - asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(res) : "h"(h0), "h"(h1)); - return reinterpret_cast(res); -} - -inline __device__ float int32_to_float(int h) { - float res; - asm volatile("cvt.rn.f32.s32 %0, %1;\n" : "=f"(res) : "r"(h)); - return res; -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 int8 values -// for weight per channel dequant. -__device__ inline FragB dequant_per_channel(int q) { - static constexpr int MASK = 0xf0f0f0f0; - FragB frag_b; - frag_b[0] = (q & MASK); - return frag_b; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 int8 values -// for weight per group dequant. -__device__ inline FragB dequant_per_group(int q, FragS_GROUP& frag_s, int i) { - static constexpr uint32_t LO = 0x000f000f; - static constexpr uint32_t HI = 0x00f000f0; - static constexpr uint32_t EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - uint32_t t0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - uint32_t t1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - static constexpr uint32_t SUB = 0x64086408; - static constexpr uint32_t MUL = 0x2c002c00; - static constexpr uint32_t ADD = 0xd480d480; - *reinterpret_cast(&t0) = __hsub2( - *reinterpret_cast(&t0), *reinterpret_cast(&SUB)); - *reinterpret_cast(&t1) = __hfma2( - *reinterpret_cast(&t1), *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - - uint16_t s = reinterpret_cast(&frag_s)[i]; - uint32_t double_s; - // pack 2xfp16 to half2 - asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(double_s) : "h"(s), "h"(s)); - // dequant and convert 4 half to 4 uint8 (be placed at the low 8 bits of 4 - // half, respectively) - static constexpr uint32_t MAGIC_NUM = 0x64806480; - *reinterpret_cast(&t0) = __hfma2( - *reinterpret_cast(&t0), *reinterpret_cast(&double_s), - *reinterpret_cast(&MAGIC_NUM)); - *reinterpret_cast(&t1) = __hfma2( - *reinterpret_cast(&t1), *reinterpret_cast(&double_s), - *reinterpret_cast(&MAGIC_NUM)); - // take out the 4 uint8 from 4 half, then convert them to 4 int8 and pack 4 - // int8 into 1 uint32 - FragB frag_b; - uint32_t uint8s; - static constexpr uint32_t MASK_0246 = 0x6420; - static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080; - asm volatile("prmt.b32 %0,%1,%2,%3;\n" - : "=r"(uint8s) - : "r"(t0), "r"(t1), "n"(MASK_0246)); - frag_b[0] = (uint8s ^ UINT8s_TO_INT8s_MASK); - return frag_b; -} - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // int8 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // int32 global_reduce buffer of shape - // (max_par*16*4)xn, as int8 tensor core's output is - // int32 dtype - int4* __restrict__ D, // fp16 output buffer of shape mxn - const float* __restrict__ s_tok, // fp32 activation per-token quantization - // scales of shape mx1 - const int4* __restrict__ s_ch, // fp32 weight per-channel quantization - // scales of shape 1xn - const int4* __restrict__ s_group, // fp16 weight per-group quantization - // scales of shape (k/groupsize)xn, when - // group_blocks=-1, it should be nullptr - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts in - // the middle of group. - if constexpr (group_blocks != -1) - iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 16; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 4; - D += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - s_tok += (slice_col_par / n_tiles) * 16 * thread_m_blocks; - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - } - - // Compute all information about the current slice which is required for - // synchronization. - auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * prob_k / 16; - C += 16 * thread_m_blocks * prob_n / 4; - D += 16 * thread_m_blocks * prob_n / 8; - s_tok += 16 * thread_m_blocks; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - int a_gl_stride = prob_k / 16; // stride of the A matrix in global memory - // We typically use `constexpr` to indicate that this value is a compile-time - // constant - constexpr int a_sh_stride = - 16 * thread_k_blocks / 16; // stride of an A matrix tile in shared memory - constexpr int a_gl_rd_delta_o = - 16 * thread_k_blocks / - 16; // delta between subsequent A tiles in global memory - int a_gl_rd_delta_i = - a_gl_stride * - (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile - constexpr int a_sh_wr_delta = - a_sh_stride * - (threads / a_gl_rd_delta_o); // between shared memory writes - constexpr int a_sh_rd_delta_o = - 1 * ((threads / 32) / - (thread_n_blocks / 4)); // between shared memory tile reads - constexpr int a_sh_rd_delta_i = - a_sh_stride * 16; // within a shared memory tile - constexpr int a_sh_stage = - a_sh_stride * (16 * thread_m_blocks); // overall size of a tile - constexpr int a_sh_wr_iters = - ceildiv(a_sh_stage, - a_sh_wr_delta); // number of shared write iterations for a tile - - int b_gl_stride = 16 * prob_n / 32; - constexpr int b_sh_stride = 32 * thread_n_blocks / 4; - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); - constexpr int b_sh_wr_delta = threads; - constexpr int b_sh_rd_delta = threads; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - constexpr int s_tok_sh_stride = 16 * thread_m_blocks; - - constexpr int s_ch_sh_stride = 16 * thread_n_blocks / 4; - - int s_group_gl_stride = prob_n / 8; - constexpr int s_group_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_group_sh_stage = s_group_sh_stride; - int s_group_gl_rd_delta = s_group_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - // NOTE(HandH1998): int8 input a only need 16 threads to load 16x16 matrix - int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16); - a_sh_rd += 1 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = - b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - auto b_sh_wr = threadIdx.x; - auto b_sh_rd = threadIdx.x; - - auto s_tok_gl_rd = threadIdx.x; - // NOTE(HandH1998): activation scale s_tok need shuffle to [0, 8, 1, 9, 2, 10, - // 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] for example, 0, 8 row scales serve for - // thread 0, 1, 2, 3. For more details, refer to mma operand A layout as - // s_tok's size is not fixed, we can not shuffle before inference we shuffle - // it when fetching s_tok from global memory to shared memory, that's why - // s_tok_sh_wr is like this - int s_tok_sh_wr = - (threadIdx.x / 16) * 16 + (threadIdx.x % 8) * 2 + (threadIdx.x % 16) / 8; - int s_tok_sh_rd = (threadIdx.x % 32) / 4; - bool s_tok_sh_wr_pred = threadIdx.x < prob_m; - - auto s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x; - auto s_ch_sh_wr = threadIdx.x; - int s_ch_sh_rd = 16 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - 2 * ((threadIdx.x % 32) % 4); - bool s_ch_sh_wr_pred = threadIdx.x < s_ch_sh_stride; - - int s_group_gl_rd, s_group_sh_wr, s_group_sh_rd; - bool s_group_sh_wr_pred; - if constexpr (group_blocks != -1) { - s_group_gl_rd = - s_group_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_group_sh_stride * slice_col + threadIdx.x; - s_group_sh_wr = threadIdx.x; - // NOTE(HandH1998): s_group_sh_rd is related to mma output C - s_group_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - s_group_sh_wr_pred = threadIdx.x < s_group_sh_stride; - } - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - // NOTE(HandH1998): stages need >= 4, otherwise, sh_s_tok = sh + max(stages * - // a_sh_stage + stages * b_sh_stage, 4 * stages * a_sh_stage) - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_s_tok = sh_b + (stages * b_sh_stage); - int4* sh_s_ch = sh_s_tok + s_tok_sh_stride; - int4* sh_s_group = sh_s_ch + s_ch_sh_stride; - - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2]; - FragC frag_c[thread_m_blocks][4][2]; - FragS_GROUP frag_s_group[2][4]; - FragS_CHANNEL frag_s_tok[thread_m_blocks]; - FragS_CHANNEL frag_s_ch[2][4]; - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); - B_ptr[i] += b_gl_rd_delta_o; - } - // Only fetch scales if this tile starts a new group - if constexpr (group_blocks != -1) { - if (pipe % (group_blocks / thread_k_blocks) == 0) { - int4* sh_s_group_stage = sh_s_group + s_group_sh_stage * pipe; - if (s_group_sh_wr_pred) - cp_async4(&sh_s_group_stage[s_group_sh_wr], - &s_group[s_group_gl_rd]); - s_group_gl_rd += s_group_gl_rd_delta; - } - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - // It may seem inefficient that we reload the groups for every sub-tile; - // however, this does not seem to be a significant bottleneck, while some - // theoretically better attempts have lead to bad instruction ordering by - // the compiler and correspondingly a noticeable drop in performance. - if constexpr (group_blocks != -1) { - int4* sh_s_group_stage = - sh_s_group + - s_group_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s_group[k % 2])[0] = - sh_s_group_stage[s_group_sh_rd]; - } - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - frag_b_quant[k % 2] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - int b_quant = frag_b_quant[k % 2][j]; - // int b_quant_shift = b_quant << 4; - FragB frag_b0, frag_b1; - // If there are no groups, we can just scale the final output once and can - // avoid doing so for each weight. - if constexpr (group_blocks != -1) { - int b_quant_shift = b_quant >> 8; - frag_b0 = dequant_per_group(b_quant, frag_s_group[k % 2][j], 0); - frag_b1 = dequant_per_group(b_quant_shift, frag_s_group[k % 2][j], 1); - } else { - int b_quant_shift = b_quant << 4; - frag_b0 = dequant_per_channel(b_quant); - frag_b1 = dequant_per_channel(b_quant_shift); - } - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride / 2; - if (red_off >= 1) { - auto red_idx = threadIdx.x / b_sh_stride; - constexpr int red_sh_stride = b_sh_stride * 4 * 2; - constexpr int red_sh_delta = b_sh_stride; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + - (threadIdx.x % b_sh_stride); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - int* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - int* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - int* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - // global_reduce works on INT32 elements, which are the results of INT8 GEMM. - // This is why we need another INT32 maxtrix `C` to reduce instead of the - // original half matrix `D`. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 4; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 8 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 8 * (threadIdx.x / 32) + (threadIdx.x % 4) * 2; - c_gl_wr += (4 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads * 2; - auto c_sh_wr = 2 * threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i + 1], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2) + 1], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { - if (!first) { - int4 d_red1 = sh[c_sh_wr + i * c_sh_wr_delta]; - int4 d_red2 = sh[c_sh_wr + i * c_sh_wr_delta + 1]; - #pragma unroll - for (int j = 0; j < 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - reinterpret_cast(&d_red1)[j]; - } - #pragma unroll - for (int j = 0; j < 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)] += - reinterpret_cast(&d_red2)[j]; - } - } - if (!last) { - int4 d1, d2; - #pragma unroll - for (int j = 0; j < 4; j++) { - reinterpret_cast(&d1)[j] = reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]; - } - #pragma unroll - for (int j = 0; j < 4; j++) { - reinterpret_cast(&d2)[j] = reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)]; - } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = - d1; - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2) + - 1] = d2; - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. - auto write_result = [&]() { - int d_gl_stride = prob_n / 8; - constexpr int d_sh_stride = 2 * thread_n_blocks + 1; - int d_gl_wr_delta = d_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int d_sh_rd_delta = - d_sh_stride * (threads / (2 * thread_n_blocks)); - - int d_gl_wr = d_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - d_gl_wr += (2 * thread_n_blocks) * slice_col; - int d_sh_wr = - (4 * d_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - d_sh_wr += 32 * (threadIdx.x / 32); - int d_sh_rd = d_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int d_gl_wr_end = d_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, int c0, int c1, float a_s, FragS_CHANNEL& w_s) { - float2 deq_res; - deq_res.x = int32_to_float(c0) * w_s[0] * a_s; - deq_res.y = int32_to_float(c1) * w_s[1] * a_s; - ((half2*)sh)[idx] = float2_to_half2(deq_res); - }; - - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = d_sh_wr + 8 * j; - write(wr + (4 * d_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s_tok[i][0], - frag_s_ch[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * d_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s_tok[i][1], - frag_s_ch[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * d_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s_tok[i][0], - frag_s_ch[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * d_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s_tok[i][1], - frag_s_ch[j / 2][2 * (j % 2) + 1]); - } - d_sh_wr += 16 * (4 * d_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (d_gl_wr < d_gl_wr_end) { - D[d_gl_wr] = sh[d_sh_rd]; - d_gl_wr += d_gl_wr_delta; - d_sh_rd += d_sh_rd_delta; - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - #pragma unroll - for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); - zero_accums(); - wait_for_stage(); - fetch_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - }; - start_pipes(); - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines have - // even length meaning that the next iteration will always start at index 0. - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) break; - } - a_gl_rd += a_gl_rd_delta_o * stages; - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if (last) { - if (s_tok_sh_wr_pred) { - cp_async1(&sh_s_tok[s_tok_sh_wr], &s_tok[s_tok_gl_rd]); - } - if (s_ch_sh_wr_pred) { - cp_async4(&sh_s_ch[s_ch_sh_wr], &s_ch[s_ch_gl_rd]); - } - cp_async_fence(); - } - thread_block_reduce(); - if (last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - frag_s_tok[i][0] = - *reinterpret_cast(&sh_s_tok[16 * i + 2 * s_tok_sh_rd]); - frag_s_tok[i][1] = *reinterpret_cast( - &sh_s_tok[16 * i + 2 * s_tok_sh_rd + 1]); - } - reinterpret_cast(&frag_s_ch)[0] = sh_s_ch[s_ch_sh_rd + 0]; - reinterpret_cast(&frag_s_ch)[1] = sh_s_ch[s_ch_sh_rd + 1]; - reinterpret_cast(&frag_s_ch)[2] = sh_s_ch[s_ch_sh_rd + 8]; - reinterpret_cast(&frag_s_ch)[3] = sh_s_ch[s_ch_sh_rd + 9]; - } - } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - s_group_gl_rd = s_group_sh_stride * slice_col + threadIdx.x; - s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x; - start_pipes(); - } - } - } -} - -#else - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // int8 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // int32 global_reduce buffer of shape - // (max_par*16*4)xn, as int8 tensor core's output is - // int32 dtype - int4* __restrict__ D, // fp16 output buffer of shape mxn - const float* __restrict__ s_tok, // fp32 activation per-token quantization - // scales of shape mx1 - const int4* __restrict__ s_ch, // fp32 weight per-channel quantization - // scales of shape 1xn - const int4* __restrict__ s_group, // fp16 weight per-group quantization - // scales of shape (k/groupsize)xn, when - // group_blocks=-1, it should be nullptr - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - -#endif - -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. -const int USER_THREADS = - 256; // Note: This is only used with user-provided thread_k/n -const int STAGES = 4; // 4 pipeline stages fit into shared memory - -static constexpr int min_thread_n = 64; -static constexpr int min_thread_k = 64; - -static constexpr int tile_size = 16; -static constexpr int max_par = 16; - -static constexpr int pack_factor_4bit = - 8; // We have 8 4-bit vals inside a 32 bit - -#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - GROUP_BLOCKS, NUM_THREADS) \ - else if (thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute(Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, \ - max_shared_mem); \ - Marlin \ - <<>>( \ - A_ptr, B_ptr, C_ptr, D_ptr, s_tok_ptr, s_ch_ptr, s_group_ptr, \ - prob_m, prob_n, prob_k, locks); \ - } - -typedef struct { - int thread_k; - int thread_n; - int num_threads; -} thread_config_t; - -thread_config_t small_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {128, 128, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 256, 256}, // Reduce K 2X, increase N 2X - {64, 128, 128}, // Reduce K 2X, same N -}; - -thread_config_t large_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {64, 256, 256}, // Default - {128, 128, 256}, // Reduce N 2X, increase K 2X - {64, 128, 128}, // Reduce N 2X, same K - {128, 64, 128}, // Reduce N 4X, increase K 2X -}; - -bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, - int prob_k) { - // Sanity - if (th_config.thread_k == -1 || th_config.thread_n == -1 || - th_config.num_threads == -1) { - return false; - } - - // Verify K/N are divisible by thread K/N - if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { - return false; - } - - // thread_k can be only 128 or 64 (because it must be less than groupsize - // which is 128) - if (th_config.thread_k != 128 && th_config.thread_k != 64) { - return false; - } - - // Verify min for thread K/N - if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { - return false; - } - - // num_threads must be at least 128 (= 4 warps) - if (th_config.num_threads < 128) { - return false; - } - - return true; -} - -thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - } - - return thread_config_t{-1, -1, -1}; -} - -#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) - -void marlin_qqq_cuda(const void* A, const void* B, void* C, void* D, - void* s_tok, void* s_ch, void* s_group, int prob_m, - int prob_n, int prob_k, void* workspace, - int groupsize = -1, int dev = 0, cudaStream_t stream = 0, - int thread_k = -1, int thread_n = -1, int sms = -1, - int max_par = 16) { - int tot_m = prob_m; - int tot_m_blocks = ceildiv(tot_m, 16); - int pad = 16 * tot_m_blocks - tot_m; - - if (sms == -1) - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - - // Set thread config - thread_config_t th_config; - if (thread_k != -1 && thread_n != -1) { - // User-defined config - th_config = thread_config_t{thread_k, thread_n, USER_THREADS}; - } else { - // Auto config - th_config = determine_thread_config(prob_m, prob_n, prob_k); - } - - if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) { - throw std::runtime_error( - "Invalid thread config: thread_k = " + str(th_config.thread_k) + - ", thread_n = " + str(th_config.thread_n) + - ", num_threads = " + str(th_config.num_threads) + " for MKN = [" + - str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]"); - } - - int num_threads = th_config.num_threads; - thread_k = th_config.thread_k; - thread_n = th_config.thread_n; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 16; - int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; - int blocks = sms; - - if (prob_m == 0 || prob_n == 0 || prob_k == 0) { - return; - } - - TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, - " is not divisible by thread_n = ", thread_n); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - if (group_blocks != -1) { - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } - - const int4* A_ptr = (const int4*)A; - const int4* B_ptr = (const int4*)B; - int4* C_ptr = (int4*)C; - int4* D_ptr = (int4*)D; - const float* s_tok_ptr = (const float*)s_tok; - const int4* s_ch_ptr = (const int4*)s_ch; - const int4* s_group_ptr = (const int4*)s_group; - - int* locks = (int*)workspace; - - for (int i = 0; i < tot_m_blocks; i += 4) { - int thread_m_blocks = tot_m_blocks - i; - prob_m = tot_m - 16 * i; - int par = 1; - if (thread_m_blocks > 4) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * thread_m_blocks - pad) / 64; - if (par > max_par) par = max_par; - prob_m = 64 * par; - i += 4 * (par - 1); - thread_m_blocks = 4; - } - - // For compilation speed, we only define the kernel configurations that have - // seemed useful (in terms of performance) in our testing, however many more - // are, in principle, possible. - if (false) { - } - CALL_IF(8, 8, 256) - CALL_IF(16, 4, 256) - CALL_IF(8, 4, 128) - CALL_IF(4, 8, 128) - else { - throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + - ", " + str(prob_k) + ", " + str(prob_n) + "]" + - ", groupsize = " + str(groupsize) + - ", thread_m_blocks = " + str(thread_m_blocks) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); - } - - A_ptr += 16 * thread_m_blocks * (prob_k / 16) * par; - D_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; - s_tok_ptr += 16 * thread_m_blocks * par; - } -} -} // anonymous namespace - -torch::Tensor marlin_qqq_gemm(torch::Tensor const& a, - torch::Tensor const& b_q_weight, - torch::Tensor const& s_tok, - torch::Tensor const& s_ch, - torch::Tensor const& s_group, - torch::Tensor& workspace, int64_t size_m, - int64_t size_n, int64_t size_k) { - // Verify M - TORCH_CHECK(size_m == a.size(0), - "Shape mismatch: a.size(0) = " + str(a.size(0)) + - ", size_m = " + str(size_m)); - TORCH_CHECK(size_m == s_tok.numel(), - "Shape mismatch: s_tok.numel() = " + str(s_tok.numel()) + - ", size_m = " + str(size_m)); - - // Verify K - TORCH_CHECK(size_k == a.size(1), - "Shape mismatch: a.size(1) = " + str(a.size(1)) + - ", size_k = " + str(size_k)); - TORCH_CHECK(size_k % tile_size == 0, - "size_k = " + str(size_k) + - " is not divisible by tile_size = " + str(tile_size)); - TORCH_CHECK( - (size_k / tile_size) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = " + str(b_q_weight.size(0)) + - ", size_k = " + str(size_k) + ", tile_size = " + str(tile_size)); - - int groupsize = (s_group.numel() == 0) ? -1 : size_k / s_group.size(0); - // Verify groupsize - TORCH_CHECK(groupsize == -1 || groupsize == 128, - "Unexpected groupsize = " + str(groupsize)); - - // Verify N - TORCH_CHECK(s_ch.numel() == size_n, - "Shape mismatch: s_ch.numel() = " + str(s_ch.numel()) + - ", size_n = " + str(size_n)); - TORCH_CHECK(b_q_weight.size(1) % tile_size == 0, - "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + - " is not divisible by tile_size = " + str(tile_size)); - if (groupsize != -1) { - TORCH_CHECK(s_group.size(1) == size_n, - "Shape mismatch: s_group.size(1) = " + str(s_group.size(1)) + - ", size_n = " + str(size_n)); - TORCH_CHECK( - size_k % s_group.size(0) == 0, - "size_k = " + str(size_k) + - ", is not divisible by s_group.size(0) = " + str(s_group.size(0))); - } - - int actual_size_n = (b_q_weight.size(1) / tile_size) * pack_factor_4bit; - TORCH_CHECK(size_n == actual_size_n, - "Shape mismatch: size_n = " + str(size_n) + - ", actual_size_n = " + str(actual_size_n)); - - // Verify A device and strides - TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); - TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); - - // Verify B device and strides - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - - // Verify s_tok device, strides and dtype - TORCH_CHECK(s_tok.device().is_cuda(), "s_tok is not on GPU"); - TORCH_CHECK(s_tok.is_contiguous(), "s_tok is not contiguous"); - TORCH_CHECK(s_tok.dtype() == torch::kFloat32, "s_tok's dtype is not float32"); - - // Verify s_ch device, strides and dtype - TORCH_CHECK(s_ch.device().is_cuda(), "s_ch is not on GPU"); - TORCH_CHECK(s_ch.is_contiguous(), "s_ch is not contiguous"); - TORCH_CHECK(s_ch.dtype() == torch::kFloat32, "s_ch's dtype is not float32"); - - // Verify s_group device, strides and dtype - TORCH_CHECK(s_group.device().is_cuda(), "s_group is not on GPU"); - TORCH_CHECK(s_group.is_contiguous(), "s_group is not contiguous"); - TORCH_CHECK(s_group.dtype() == torch::kFloat16, - "s_group's dtype is not float16"); - - // Verify workspace size - TORCH_CHECK(size_n % min_thread_n == 0, - "size_n = " + str(size_n) + - ", is not divisible by min_thread_n = " + str(min_thread_n)); - int min_workspace_size = (size_n / min_thread_n) * max_par; - TORCH_CHECK(workspace.numel() >= min_workspace_size, - "workspace.numel = " + str(workspace.numel()) + - " is below min_workspace_size = " + str(min_workspace_size)); - - // Alloc C matrix - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options_c = torch::TensorOptions().dtype(torch::kInt).device(a.device()); - torch::Tensor c = torch::empty({max_par * 64, size_n}, options_c); - - // Alloc D matrix - auto options_d = - torch::TensorOptions().dtype(torch::kFloat16).device(a.device()); - torch::Tensor d = torch::empty({size_m, size_n}, options_d); - - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; - // sms: number of SMs to use for the kernel (can usually be left as auto -1) - int sms = -1; - - int dev = a.get_device(); - marlin_qqq_cuda( - a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), d.data_ptr(), - s_tok.data_ptr(), s_ch.data_ptr(), s_group.data_ptr(), size_m, size_n, - size_k, workspace.data_ptr(), groupsize, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par); - - return d; -} - -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("marlin_qqq_gemm", &marlin_qqq_gemm); -} diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 3a0ff6eaa7904..60710f62c064b 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -241,14 +241,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // custom types: // https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA - // Marlin (Dense) Optimized Quantized GEMM for GPTQ. - ops.def( - "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, " - "Tensor! workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> " - "Tensor", - {stride_tag}); - // conditionally compiled so impl in source file - // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ. ops.def( "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, " @@ -353,15 +345,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size); #ifndef USE_ROCM - // marlin_qqq_gemm for QQQ. - ops.def( - "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, " - "Tensor s_tok, Tensor s_ch, Tensor s_group, " - "Tensor! workspace, SymInt size_m, SymInt size_n, " - "SymInt size_k) -> Tensor", - {stride_tag}); - // conditionally compiled so impl registration is in source file - // CUTLASS nvfp4 block scaled GEMM ops.def( "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b," diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index a2fc6ffeb8b26..84178344a5f36 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -53,12 +53,6 @@ def models_list(*, all: bool = True, keywords: Optional[list[str]] = None): "quantization": "gptq_marlin_24" })) - if is_quant_method_supported("marlin"): - TEST_MODELS.append( - ("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", { - "quantization": "marlin" - })) - if not current_platform.is_rocm() and is_quant_method_supported("awq"): TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", { "quantization": "AWQ" diff --git a/tests/kernels/quantization/test_machete_mm.py b/tests/kernels/quantization/test_machete_mm.py index a842d2f1cbe8d..0e09661c955e4 100644 --- a/tests/kernels/quantization/test_machete_mm.py +++ b/tests/kernels/quantization/test_machete_mm.py @@ -95,23 +95,23 @@ TEST_TYPES = [ token_scale_type=None) for w_type in [scalar_types.uint4, scalar_types.uint8] for a_type in [torch.float16, torch.bfloat16]), - # QQQ style - *(TypeConfig(act_type=torch.int8, - weight_type=scalar_types.uint4b8, - output_type=torch.float16, - group_scale_type=group_scale_type, - group_zero_type=None, - channel_scale_type=torch.float, - token_scale_type=torch.float) - for group_scale_type in [None, torch.float16]), - *(TypeConfig(act_type=torch.float8_e4m3fn, - weight_type=scalar_types.uint4b8, - output_type=torch.float16, - group_scale_type=group_scale_type, - group_zero_type=None, - channel_scale_type=torch.float, - token_scale_type=torch.float) - for group_scale_type in [None, torch.float16]), + # # QQQ style + # *(TypeConfig(act_type=torch.int8, + # weight_type=scalar_types.uint4b8, + # output_type=torch.float16, + # group_scale_type=group_scale_type, + # group_zero_type=None, + # channel_scale_type=torch.float, + # token_scale_type=torch.float) + # for group_scale_type in [None, torch.float16]), + # *(TypeConfig(act_type=torch.float8_e4m3fn, + # weight_type=scalar_types.uint4b8, + # output_type=torch.float16, + # group_scale_type=group_scale_type, + # group_zero_type=None, + # channel_scale_type=torch.float, + # token_scale_type=torch.float) + # for group_scale_type in [None, torch.float16]), ] # TODO: in future PR refactor this and `is_quant_method_supported` in the kernel diff --git a/tests/kernels/quantization/test_marlin_gemm.py b/tests/kernels/quantization/test_marlin_gemm.py index cea7700ac3293..ad077e0b94732 100644 --- a/tests/kernels/quantization/test_marlin_gemm.py +++ b/tests/kernels/quantization/test_marlin_gemm.py @@ -13,11 +13,7 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) -from vllm.model_executor.layers.quantization.qqq import ( - MARLIN_QQQ_MAX_PARALLEL, MARLIN_QQQ_MIN_THREAD_N, - MARLIN_QQQ_SUPPORTED_GROUP_SIZES, MARLIN_QQQ_SUPPORTED_NUM_BITS) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx, marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales, query_marlin_supported_quant_types) @@ -31,8 +27,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( marlin_weights) from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( marlin_24_quantize) -from vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq import ( # noqa: E501 - marlin_qqq_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights) from vllm.scalar_type import scalar_types @@ -449,68 +443,6 @@ def test_hqq_marlin_gemm( assert max_diff < 0.04 -@pytest.mark.skipif(not is_quant_method_supported("qqq"), - reason="Marlin is not supported on this GPU type.") -@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) -@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("num_bits", MARLIN_QQQ_SUPPORTED_NUM_BITS) -@pytest.mark.parametrize("group_size", MARLIN_QQQ_SUPPORTED_GROUP_SIZES) -@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_marlin_qqq_gemm( - k_chunk, - n_chunk, - num_bits, - group_size, - mnk_factors, -): - int8_traits = torch.iinfo(torch.int8) - m_factor, n_factor, k_factor = mnk_factors - - size_m = m_factor - size_k = k_chunk * k_factor - size_n = n_chunk * n_factor - - a_input = rand_data((size_m, size_k)) - b_weight = rand_data((size_k, size_n)) - - # Quantize activations - s_a = a_input.abs().max(dim=-1, keepdim=True)[0].div(int8_traits.max).to( - torch.float) - q_a = (a_input / s_a).round().clamp(int8_traits.min, - int8_traits.max).to(torch.int8) - - # Quantize weights - w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = \ - marlin_qqq_quantize(b_weight, num_bits, group_size) - - workspace = MarlinWorkspace(size_n, MARLIN_QQQ_MIN_THREAD_N, - MARLIN_QQQ_MAX_PARALLEL) - - opcheck(torch.ops._C.marlin_qqq_gemm, - (q_a, marlin_qqq_q_w, s_a, marlin_qqq_s_channel, - marlin_qqq_s_group, workspace.scratch, a_input.shape[0], - b_weight.shape[1], a_input.shape[1])) - - output = ops.marlin_qqq_gemm( - q_a, - marlin_qqq_q_w, - s_a, - marlin_qqq_s_channel, - marlin_qqq_s_group, - workspace.scratch, - a_input.shape[0], - b_weight.shape[1], - a_input.shape[1], - ) - output_ref = torch.matmul(q_a.half() * s_a.half(), w_ref) - - torch.cuda.synchronize() - - max_diff = compute_max_diff(output, output_ref) - - assert max_diff < 0.04 - - def test_marlin_gemm_subset_input(): quant_type = scalar_types.uint4b8 group_size = 128 @@ -602,18 +534,3 @@ def test_marlin_gemm_with_bias(size_m): max_diff = compute_max_diff(output, output_ref) assert max_diff < 0.04 - - -def test_marlin_gemm_opcheck(): - size_m = 2048 - size_n = 4096 - size_k = 4096 - a = torch.rand((size_m, size_n), device='cuda', dtype=torch.float16) - w = torch.randint(-5, 5, (256, 8192), device='cuda', dtype=torch.int32) - s = torch.full((32, size_k), 0.125, device='cuda', dtype=torch.float16) - wk = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL).scratch - x = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k) - y = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k) - torch.testing.assert_close(x, y) - opcheck(torch.ops._C.marlin_gemm, (a, w, s, wk, size_m, size_n, size_k)) diff --git a/tests/quantization/test_configs.py b/tests/quantization/test_configs.py index 8cf8402436ff5..1843bffd21159 100644 --- a/tests/quantization/test_configs.py +++ b/tests/quantization/test_configs.py @@ -22,22 +22,12 @@ class ModelPair: MODEL_ARG_EXPTYPES = [ # AUTOGPTQ # compat: autogptq <=0.7.1 is_marlin_format: bool - # Model Serialized in Marlin Format should always use Marlin kernel. - ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", None, "marlin"), - ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "marlin", "marlin"), - ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "gptq", "marlin"), - ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "awq", "ERROR"), # Model Serialized in Exllama Format. ("TheBloke/Llama-2-7B-Chat-GPTQ", None, "gptq_marlin"), ("TheBloke/Llama-2-7B-Chat-GPTQ", "marlin", "gptq_marlin"), ("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq", "gptq"), ("TheBloke/Llama-2-7B-Chat-GPTQ", "awq", "ERROR"), # compat: autogptq >=0.8.0 use checkpoint_format: str - # Model Serialized in Marlin Format should always use Marlin kernel. - ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", None, "marlin"), - ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "marlin", "marlin"), - ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "gptq", "marlin"), - ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "awq", "ERROR"), # Model Serialized in Exllama Format. ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", None, "gptq_marlin"), ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "marlin", "gptq_marlin"), diff --git a/tests/quantization/test_lm_head.py b/tests/quantization/test_lm_head.py index 11f78a23bb4c0..5ec8b27c1571f 100644 --- a/tests/quantization/test_lm_head.py +++ b/tests/quantization/test_lm_head.py @@ -11,7 +11,6 @@ import torch from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinLinearMethod) -from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod from vllm.model_executor.layers.vocab_parallel_embedding import ( UnquantizedEmbeddingMethod) @@ -19,9 +18,7 @@ PROMPT = "On the surface of Mars, we found" MODELS_QUANT = [ ("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head", True), - ("ModelCloud/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit-10-25-2024", False), ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", False), - ("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", False) ] @@ -41,8 +38,7 @@ def test_lm_head( lm_head_layer = model.lm_head if lm_head_quantized: assert isinstance(lm_head_layer.quant_method, - (GPTQLinearMethod, GPTQMarlinLinearMethod, - MarlinLinearMethod)) + (GPTQLinearMethod, GPTQMarlinLinearMethod)) else: assert isinstance(lm_head_layer.quant_method, UnquantizedEmbeddingMethod) diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt index 1b797074096ed..cc18c9ff1f096 100644 --- a/tests/weight_loading/models.txt +++ b/tests/weight_loading/models.txt @@ -26,9 +26,5 @@ compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-W8A8-testing awq, casperhansen/mixtral-instruct-awq, main awq_marlin, casperhansen/mixtral-instruct-awq, main fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main -marlin, nm-testing/zephyr-beta-7b-marlin-g128, main -marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main -qqq, HandH1998/QQQ-Llama-3-8b-g128, main -qqq, HandH1998/QQQ-Llama-3-8b, main hqq, nm-testing/Llama-3.2-1B-Instruct-HQQ, main None, mgleize/fairseq2-dummy-Llama-3.2-1B, main \ No newline at end of file diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 39da08847b2e7..59f2d7737f19d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -387,14 +387,6 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) -# marlin -def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, - size_n: int, size_k: int) -> torch.Tensor: - return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, - size_n, size_k) - - # marlin_24 def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_meta: torch.Tensor, b_scales: torch.Tensor, @@ -437,25 +429,6 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): is_zp_float: bool = False) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - @register_fake("_C::marlin_qqq_gemm") - def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - s_tok: torch.Tensor, s_ch: torch.Tensor, - s_group: torch.Tensor, workspace: torch.Tensor, - size_m: torch.SymInt, size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), - dtype=torch.float16, - device=a.device) - - @register_fake("_C::marlin_gemm") - def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - size_m: torch.SymInt, size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), - dtype=torch.float16, - device=a.device) - @register_fake("_C::awq_dequantize") def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, split_k_iters: torch.SymInt, @@ -1348,15 +1321,6 @@ def scaled_int8_quant( return output, input_scales, input_azp -# qqq ops -def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - s_tok: torch.Tensor, s_ch: torch.Tensor, - s_group: torch.Tensor, workspace: torch.Tensor, - size_m: int, size_n: int, size_k: int) -> torch.Tensor: - return torch.ops._C.marlin_qqq_gemm(a, b_q_weight, s_tok, s_ch, s_group, - workspace, size_m, size_n, size_k) - - # gguf def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int, dtype: Optional[torch.dtype]) -> torch.Tensor: diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 5b5d477ef066b..62dfd4333bee8 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -1112,9 +1112,9 @@ class ModelConfig: def _verify_quantization(self) -> None: supported_quantization = me_quant.QUANTIZATION_METHODS optimized_quantization_methods = [ - "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", - "awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8", - "quark", "modelopt_fp4", "bitblas", "gptq_bitblas", "inc" + "fp8", "modelopt", "gptq_marlin_24", "gptq_marlin", "awq_marlin", + "fbgemm_fp8", "compressed-tensors", "experts_int8", "quark", + "modelopt_fp4", "bitblas", "gptq_bitblas", "inc" ] if self.quantization is not None: self.quantization = cast(me_quant.QuantizationMethods, @@ -1137,7 +1137,6 @@ class ModelConfig: # `override_quantization_method` method) must be checked in order # of preference (this is particularly important for GPTQ). overrides = [ - "marlin", "bitblas", "gptq_marlin_24", "gptq_marlin", diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index de5933d6d41e5..24a05d310d108 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -48,9 +48,6 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device: # GPTQ/AWQ elif hasattr(base_layer, "qweight"): return base_layer.qweight.device - # marlin - elif hasattr(base_layer, "B"): - return base_layer.B.device # HQQ marlin elif hasattr(base_layer, "W_q"): return base_layer.W_q.device diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index d3b6b2089f426..654e2ec7b2fa0 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -42,7 +42,6 @@ WEIGHT_LOADER_V2_SUPPORTED = [ "GPTQMarlinLinearMethod", "Fp8LinearMethod", "MarlinLinearMethod", - "QQQLinearMethod", "GPTQMarlin24LinearMethod", "TPUInt8LinearMethod", "GPTQLinearMethod", diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index a4c2671225f57..ea51468422dcd 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -15,7 +15,6 @@ QuantizationMethods = Literal[ "fbgemm_fp8", "modelopt", "modelopt_fp4", - "marlin", "bitblas", "gguf", "gptq_marlin_24", @@ -25,7 +24,6 @@ QuantizationMethods = Literal[ "gptq", "compressed-tensors", "bitsandbytes", - "qqq", "hqq", "experts_int8", "neuron_quant", @@ -106,13 +104,11 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .hqq_marlin import HQQMarlinConfig from .inc import INCConfig from .ipex_quant import IPEXConfig - from .marlin import MarlinConfig from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config from .moe_wna16 import MoeWNA16Config from .mxfp4 import Mxfp4Config from .neuron_quant import NeuronQuantConfig from .ptpc_fp8 import PTPCFp8Config - from .qqq import QQQConfig from .rtn import RTNConfig from .torchao import TorchAOConfig from .tpu_int8 import Int8TpuConfig @@ -125,7 +121,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "fbgemm_fp8": FBGEMMFp8Config, "modelopt": ModelOptFp8Config, "modelopt_fp4": ModelOptNvFp4Config, - "marlin": MarlinConfig, "bitblas": BitBLASConfig, "gguf": GGUFConfig, "gptq_marlin_24": GPTQMarlin24Config, @@ -136,7 +131,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "compressed-tensors": CompressedTensorsConfig, "bitsandbytes": BitsAndBytesConfig, "ptpc_fp8": PTPCFp8Config, - "qqq": QQQConfig, "hqq": HQQMarlinConfig, "experts_int8": ExpertsInt8Config, "neuron_quant": NeuronQuantConfig, diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py deleted file mode 100644 index 18d1c13373df9..0000000000000 --- a/vllm/model_executor/layers/quantization/marlin.py +++ /dev/null @@ -1,263 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Any, Optional - -import torch -from torch.nn.parameter import Parameter - -from vllm import _custom_ops as ops -from vllm.logger import init_logger -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.parameter import (BasevLLMParameter, - ChannelQuantScaleParameter, - GroupQuantScaleParameter, - PackedvLLMParameter) - -logger = init_logger(__name__) - - -class MarlinConfig(QuantizationConfig): - """Config class for Marlin. - - Reference: https://github.com/IST-DASLab/marlin/tree/master - """ - - def __init__( - self, - group_size: int, - lm_head_quantized: bool, - ) -> None: - super().__init__() - - # Group size for the quantization. - self.group_size = group_size - self.lm_head_quantized = lm_head_quantized - if self.group_size != 128 and self.group_size != -1: - raise ValueError( - "Currently, only group size 128 and -1 (channelwise) " - "is supported for Marlin, but got group_size of " - f"{self.group_size}") - - # 4 Bits packed into 32 bit datatype. - self.pack_factor = 32 // 4 - - # Tile size used by marlin kernels. - self.tile_size = 16 - - # Min out_features dim - self.min_n_threads = 64 - - # Min in_features dim - self.min_k_threads = 128 - - # Max parallel problems to solve at once (improves large - # batch performance) - self.max_parallel = 16 - - # Permutation length used by the marlin kernels. - self.perm_len = 1024 - - def __repr__(self) -> str: - return (f"MarlinConfig(group_size={self.group_size}, " - f"lm_head_quantized={self.lm_head_quantized})") - - @classmethod - def get_name(cls) -> QuantizationMethods: - return "marlin" - - @classmethod - def get_supported_act_dtypes(cls) -> list[torch.dtype]: - return [torch.half] - - @classmethod - # Need to figure it out - def get_min_capability(cls) -> int: - return 80 - - @classmethod - def get_config_filenames(cls) -> list[str]: - return ["quantize_config.json"] - - @classmethod - def from_config(cls, config: dict[str, Any]) -> "MarlinConfig": - group_size = cls.get_from_keys(config, ["group_size"]) - lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], - default=False) - return cls(group_size, lm_head_quantized) - - @classmethod - def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: - # compat: autogptq >=0.8.0 use checkpoint_format: str - # compat: autogptq <=0.7.1 is_marlin_format: bool - is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin" - or hf_quant_cfg.get("is_marlin_format", False)) - - is_valid_user_quant = (user_quant is None or user_quant == "gptq" - or user_quant == "marlin") - - if is_marlin_format and is_valid_user_quant: - msg = ("The model is serialized in {} format. Using {} kernel.". - format(cls.get_name(), cls.get_name())) - logger.info(msg) - return cls.get_name() - - return None - - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["MarlinLinearMethod"]: - if (isinstance(layer, LinearBase) or - (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): - return MarlinLinearMethod(self) - return None - - -class MarlinLinearMethod(LinearMethodBase): - """Linear method for Marlin. - - Args: - quant_config: The Marlin quantization config. - """ - - def __init__(self, quant_config: MarlinConfig): - self.quant_config = quant_config - - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - del output_size # Unused. - weight_loader = extra_weight_attrs["weight_loader"] - - if params_dtype != torch.float16: - raise ValueError( - f"The params dtype must be float16, but got {params_dtype}") - - # Validate output_size_per_partition - output_size_per_partition = sum(output_partition_sizes) - if output_size_per_partition % self.quant_config.min_n_threads != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f"min_n_threads = {self.quant_config.min_n_threads}.") - if output_size_per_partition % self.quant_config.pack_factor != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f"pack_factor = {self.quant_config.pack_factor}.") - - # Validate input_size_per_partition - if input_size_per_partition % self.quant_config.min_k_threads != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible by " - f"min_k_threads = {self.quant_config.min_k_threads}.") - if (self.quant_config.group_size != -1 and - input_size_per_partition % self.quant_config.group_size != 0): - raise ValueError(f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible by " - f"group_size = {self.quant_config.group_size}.") - - # Check that we have at least 4 tiles horizontally in the shard - num_tiles_per_perm = self.quant_config.perm_len // ( - self.quant_config.tile_size**2) - if output_size_per_partition % num_tiles_per_perm != 0: - raise ValueError( - "Each permutation group must reside on the same gpu") - - # Quantized 4Bit weights packed into Int32. - qweight = PackedvLLMParameter( - data=torch.empty( - input_size_per_partition // self.quant_config.tile_size, - output_size_per_partition * self.quant_config.tile_size // - self.quant_config.pack_factor, - device="cuda", - dtype=torch.int32, - ), - input_dim=0, - output_dim=1, - packed_dim=1, - packed_factor=self.quant_config.pack_factor, - marlin_tile_size=self.quant_config.tile_size, - weight_loader=weight_loader) - - # Determine if channelwise or not - input_groups = (1 if self.quant_config.group_size == -1 else - input_size_per_partition // - self.quant_config.group_size) - - weight_scale_args = { - "data": - torch.empty( - input_groups, - output_size_per_partition, - device="cuda", - dtype=params_dtype, - ), - "weight_loader": - weight_loader - } - if input_groups == 1: - scales = ChannelQuantScaleParameter(output_dim=1, - **weight_scale_args) - else: - scales = GroupQuantScaleParameter(output_dim=1, - input_dim=0, - **weight_scale_args) - - # Allocate workspace (Used for internal locking mechanism) - max_workspace_size = ( - output_size_per_partition // - self.quant_config.min_n_threads) * self.quant_config.max_parallel - - workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size, - device="cuda", - dtype=torch.int), - weight_loader=weight_loader) - - layer.register_parameter("B", qweight) - layer.register_parameter("s", scales) - layer.register_parameter("workspace", workspace) - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # required by torch.compile - layer.B = Parameter(layer.B.data, requires_grad=False) - layer.s = Parameter(layer.s.data, requires_grad=False) - layer.workspace = Parameter(layer.workspace.data, requires_grad=False) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - qweight = layer.B - scales = layer.s - workspace = layer.workspace - - x_2d = x.view(-1, x.shape[-1]) - - size_m = x_2d.shape[0] - size_k = x_2d.shape[1] - size_n = scales.shape[1] - - output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m, - size_n, size_k) - - output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) - - if bias is not None: - output.add_(bias) # In-place add - - return output diff --git a/vllm/model_executor/layers/quantization/qqq.py b/vllm/model_executor/layers/quantization/qqq.py deleted file mode 100644 index 25978cb13b3ab..0000000000000 --- a/vllm/model_executor/layers/quantization/qqq.py +++ /dev/null @@ -1,275 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Any, Optional - -import torch -from torch.nn.parameter import Parameter - -from vllm import _custom_ops as ops -from vllm.logger import init_logger -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.parameter import (BasevLLMParameter, - ChannelQuantScaleParameter, - GroupQuantScaleParameter, - PackedvLLMParameter) - -logger = init_logger(__name__) - -MARLIN_QQQ_TILE = 16 -MARLIN_QQQ_MIN_THREAD_N = 64 -MARLIN_QQQ_MIN_THREAD_K = 128 -MARLIN_QQQ_MAX_PARALLEL = 16 - -MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] -MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] -MARLIN_QQQ_SUPPORTED_SYM = [True] - - -class QQQConfig(QuantizationConfig): - """Config class for QQQ - - Reference: https://arxiv.org/pdf/2406.09904 - """ - - def __init__( - self, - weight_bits: int, - group_size: int, - is_sym: bool = True, - ) -> None: - super().__init__() - self.weight_bits = weight_bits - self.group_size = group_size - self.is_sym = is_sym - - # Verify - if self.weight_bits not in MARLIN_QQQ_SUPPORTED_NUM_BITS: - raise ValueError( - f"QQQ does not support weight_bits = {self.weight_bits}. " - f"Only weight_bits = {MARLIN_QQQ_SUPPORTED_NUM_BITS} " - "are supported.") - if self.group_size not in MARLIN_QQQ_SUPPORTED_GROUP_SIZES: - raise ValueError( - f"QQQ does not support group_size = {self.group_size}. " - f"Only group_sizes = {MARLIN_QQQ_SUPPORTED_GROUP_SIZES} " - "are supported.") - if self.is_sym not in MARLIN_QQQ_SUPPORTED_SYM: - raise ValueError( - f"QQQ does not support is_sym = {self.is_sym}. " - f"Only sym = {MARLIN_QQQ_SUPPORTED_SYM} are supported.") - - # 4 Bits packed into 32 bit datatype. - self.pack_factor = 32 // self.weight_bits - - # Tile size used by QQQ kernels. - self.tile_size = MARLIN_QQQ_TILE - - # Min out_features dim - self.min_n_threads = MARLIN_QQQ_MIN_THREAD_N - - # Min in_features dim - self.min_k_threads = MARLIN_QQQ_MIN_THREAD_K - - # Max parallel problems to solve at once (improves large - # batch performance) - self.max_parallel = MARLIN_QQQ_MAX_PARALLEL - - # Permutation length used by the QQQ kernels. - self.perm_len = 1024 - - def __repr__(self) -> str: - return "QQQConfig(weight_bits={}, group_size={})".format( - self.weight_bits, self.group_size) - - @classmethod - def get_name(cls) -> QuantizationMethods: - return "qqq" - - @classmethod - def get_supported_act_dtypes(cls) -> list[torch.dtype]: - return [torch.half] - - @classmethod - def get_min_capability(cls) -> int: - return 80 - - @classmethod - def get_config_filenames(cls) -> list[str]: - """List of filenames to search for in the model directory.""" - return [ - "quant_config.json", - "quantize_config.json", - ] - - @classmethod - def from_config(cls, config: dict[str, Any]) -> "QQQConfig": - weight_bits = cls.get_from_keys(config, ["wbits"]) - group_size = cls.get_from_keys(config, ["group_size"]) - return cls(weight_bits, group_size) - - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QQQLinearMethod"]: - if isinstance(layer, LinearBase): - return QQQLinearMethod(self) - return None - - -class QQQLinearMethod(LinearMethodBase): - """Linear method for QQQ. - - Args: - quant_config: The QQQ quantization config. - """ - - def __init__(self, quant_config: QQQConfig): - self.quant_config = quant_config - - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - weight_loader = extra_weight_attrs["weight_loader"] - if params_dtype != torch.float16: - raise ValueError( - f"The params dtype must be float16, but got {params_dtype}") - - # Validate output_size_per_partition - output_size_per_partition = sum(output_partition_sizes) - if output_size_per_partition % self.quant_config.min_n_threads != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f"min_n_threads = {self.quant_config.min_n_threads}.") - if output_size_per_partition % self.quant_config.pack_factor != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f"pack_factor = {self.quant_config.pack_factor}.") - - # Validate input_size_per_partition - if input_size_per_partition % self.quant_config.min_k_threads != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible by " - f"min_k_threads = {self.quant_config.min_k_threads}.") - if (self.quant_config.group_size != -1 and - input_size_per_partition % self.quant_config.group_size != 0): - raise ValueError(f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible by " - f"group_size = {self.quant_config.group_size}.") - - # Check that we have at least 4 tiles horizontally in the shard - num_tiles_per_perm = self.quant_config.perm_len // ( - self.quant_config.tile_size**2) - if output_size_per_partition % num_tiles_per_perm != 0: - raise ValueError( - "Each permutation group must reside on the same gpu") - - # Quantized 4Bit weights packed into Int32. - qweight = PackedvLLMParameter( - data=torch.empty( - input_size_per_partition // self.quant_config.tile_size, - output_size_per_partition * self.quant_config.tile_size // - self.quant_config.pack_factor, - device="cuda", - dtype=torch.int32, - ), - input_dim=0, - output_dim=1, - packed_dim=1, - packed_factor=self.quant_config.pack_factor, - marlin_tile_size=self.quant_config.tile_size, - weight_loader=weight_loader) - - s_channel = ChannelQuantScaleParameter(data=torch.empty( - 1, - output_size_per_partition, - device="cuda", - dtype=torch.float, - ), - weight_loader=weight_loader, - output_dim=1) - - if self.quant_config.group_size == -1: - s_group_data = torch.tensor( - [], - device="cuda", - dtype=torch.half, - ) - else: - s_group_data = torch.empty( - input_size_per_partition // self.quant_config.group_size, - output_size_per_partition, - device="cuda", - dtype=torch.half, - ) - - s_group_attr = {"data": s_group_data, "weight_loader": weight_loader} - - if self.quant_config.group_size == -1: - s_group = BasevLLMParameter(**s_group_attr) - else: - s_group = GroupQuantScaleParameter(output_dim=1, - input_dim=0, - **s_group_attr) - - # Allocate workspace (Used for internal locking mechanism) - max_workspace_size = ( - output_size_per_partition // - self.quant_config.min_n_threads) * self.quant_config.max_parallel - - workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size, - device="cuda", - dtype=torch.int), - weight_loader=weight_loader) - - layer.register_parameter("B", qweight) - layer.register_parameter("s_channel", s_channel) - layer.register_parameter("s_group", s_group) - layer.register_parameter("workspace", workspace) - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # required by torch.compile - layer.B = Parameter(layer.B.data, requires_grad=False) - layer.s_channel = Parameter(layer.s_channel.data, requires_grad=False) - layer.s_group = Parameter(layer.s_group.data, requires_grad=False) - layer.workspace = Parameter(layer.workspace.data, requires_grad=False) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - qweight = layer.B - s_ch = layer.s_channel - s_group = layer.s_group - workspace = layer.workspace - - x_2d = x.view(-1, x.shape[-1]) - - size_m = x_2d.shape[0] - size_k = x_2d.shape[1] - size_n = s_ch.shape[1] - - x_int8, s_tok, _ = ops.scaled_int8_quant(x_2d) - - output_2d = ops.marlin_qqq_gemm(x_int8, qweight, s_tok, s_ch, s_group, - workspace, size_m, size_n, size_k) - - output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) - - if bias is not None: - output.add_(bias) # In-place add - - return output diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py deleted file mode 100644 index 8a64bebae04c9..0000000000000 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py +++ /dev/null @@ -1,126 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import numpy -import torch - -from .marlin_utils_test import marlin_permute_weights -from .quant_utils import get_pack_factor, qqq_quantize_weights - - -def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): - # Permute - q_w = marlin_permute_weights(q_w, size_k, size_n, perm) - - # Pack - pack_factor = get_pack_factor(num_bits) - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), - dtype=numpy.uint32) - if group_size == size_k: - for i in range(pack_factor): - q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i - else: - for i in range(pack_factor): - q_packed |= q_w[:, i::pack_factor] << num_bits * i - - q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) - - return q_packed - - -def get_qqq_scale_perms(): - scale_perm: list[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: list[int] = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501 -def get_qqq_weight_perm(num_bits: int, quant_type: str): - perm_list: list[int] = [] - for i in range(32): - perm1: list[int] = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 4 * (i % 4), - 4 * (i % 4) + 1, - 4 * (i % 4) + 2, - 4 * (i % 4) + 3, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm_list.extend([p + 256 * j for p in perm1]) - - perm = numpy.array(perm_list) - - assert quant_type in ["per-channel", - "per-group"], "not supported quantization type" - if num_bits == 4: - if quant_type == "per-channel": - interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) - else: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - else: - raise Exception("num_bits must be 4, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): - scale_perm, scale_perm_single = get_qqq_scale_perms() - if group_size < size_k and group_size != -1: - s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] - s_channel = s_channel.reshape( - (-1, len(scale_perm_single)))[:, scale_perm_single] - s_group = s_group.reshape((-1, size_n)).contiguous() - else: - s_channel = s_channel.reshape( - (-1, len(scale_perm_single)))[:, scale_perm_single] - s_channel = s_channel.reshape((-1, size_n)).contiguous() - - return s_group, s_channel - - -def marlin_qqq_quantize( - w: torch.Tensor, - num_bits: int, - group_size: int, -): - size_k, size_n = w.shape - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - quant_type = "per-channel" if group_size == size_k else "per-group" - - # Quantize - w_ref, q_w, s_group, s_channel = qqq_quantize_weights( - w, num_bits, group_size) - - # Reformat to marlin_qqq - weight_perm = get_qqq_weight_perm(num_bits, quant_type) - marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, - weight_perm, group_size) - marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( - s_group, s_channel, size_k, size_n, group_size) - - # Create result - res_list = [ - w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel - ] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 428e9e99aa881..3cfaca6230b12 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -9,8 +9,6 @@ import numpy import torch from vllm._custom_ops import cutlass_scaled_mm_supports_fp4 -from vllm.model_executor.layers.quantization.qqq import ( - MARLIN_QQQ_SUPPORTED_NUM_BITS) from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types @@ -386,89 +384,6 @@ def gptq_quantize_weights(w: torch.Tensor, return w_ref, w_q, w_s, g_idx, rand_perm -# QQQ employs different quant schemes for per-group and -# per-channel quantization. -def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): - orig_device = w.device - size_k, size_n = w.shape - - assert w.is_floating_point(), "w must be float" - assert num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS, \ - f"Unsupported num_bits = {num_bits}" - assert group_size in SUPPORTED_GROUP_SIZES + [ - size_k - ], f"Unsupported groupsize = {group_size}" - - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - if group_size < size_k: - # Reshape to [groupsize, -1] - w = w.reshape((-1, group_size, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((group_size, -1)) - - max_q_val = 2**num_bits - 1 - half_q_val = (max_q_val + 1) // 2 - - # Compute scale for each group - s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] - s_group *= 2 / max_q_val # 2 => symmetric - - # Quantize - q_w = torch.round(w / s_group).int() - q_w += half_q_val - q_w = torch.clamp(q_w, 0, max_q_val) - # Compute ref (dequantized) - w_ref = (q_w - half_q_val).half() * s_group - - # Restore original shapes - def reshape_w(w): - w = w.reshape((group_size, -1, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((size_k, size_n)).contiguous() - return w - - q_w = reshape_w(q_w) - w_ref = reshape_w(w_ref) - - # Compute int8 quantization scale for each channel - s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] - s_channel /= 127.0 - t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) - w_ref = t_int8.half() * s_channel - s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) - - # Fuse scales - s_group = (s_group.reshape(-1, size_n).contiguous() / - s_channel).to(dtype=torch.half) - else: - max_q_val = 2**(num_bits - 1) - 1 - - # Compute scale for each channel - s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] - s_channel /= max_q_val - - # Quantize - q_w = torch.round(w / s_channel).int() - q_w = torch.clamp(q_w, -max_q_val, max_q_val) - # Compute ref (dequantized) - w_ref = q_w.half() * s_channel - - s_group = torch.tensor([], dtype=torch.half) - # div 2 ** (8 - self.bits)) to offset right shift in unpacking - s_channel /= (2**(8 - num_bits)) - s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) - - return ( - w_ref.to(device=orig_device), - q_w.to(device=orig_device), - s_group.to(device=orig_device), - s_channel.to(device=orig_device), - ) - - def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): orig_device = q_w.device From 582bbe6bd708d01d74d6d02d6ef59b4c3c34a7b1 Mon Sep 17 00:00:00 2001 From: bigmoyan Date: Thu, 21 Aug 2025 03:59:54 +0800 Subject: [PATCH 12/57] [Fix] correct tool_id for kimi-k2 when use tool_choice=required (#21259) Co-authored-by: wangzhengtao --- .../test_completion_with_function_calling.py | 314 +++++++++++------- tests/utils.py | 10 +- vllm/entrypoints/chat_utils.py | 17 +- vllm/entrypoints/openai/protocol.py | 4 +- vllm/entrypoints/openai/serving_chat.py | 64 +++- .../tool_parsers/deepseekv3_tool_parser.py | 4 +- .../granite_20b_fc_tool_parser.py | 4 +- .../tool_parsers/granite_tool_parser.py | 4 +- .../openai/tool_parsers/hermes_tool_parser.py | 4 +- .../tool_parsers/internlm2_tool_parser.py | 4 +- .../openai/tool_parsers/jamba_tool_parser.py | 4 +- .../openai/tool_parsers/llama_tool_parser.py | 4 +- .../tool_parsers/minimax_tool_parser.py | 4 +- .../tool_parsers/phi4mini_tool_parser.py | 4 +- .../openai/tool_parsers/xlam_tool_parser.py | 4 +- 15 files changed, 283 insertions(+), 166 deletions(-) diff --git a/tests/entrypoints/openai/test_completion_with_function_calling.py b/tests/entrypoints/openai/test_completion_with_function_calling.py index a5b081f861074..4ef5d4e8a699a 100644 --- a/tests/entrypoints/openai/test_completion_with_function_calling.py +++ b/tests/entrypoints/openai/test_completion_with_function_calling.py @@ -13,6 +13,127 @@ from ...utils import RemoteOpenAIServer # any model with a chat template should work here MODEL_NAME = "Qwen/Qwen3-0.6B" +tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": + "The city to find the weather for, e.g. 'Vienna'", + "default": "Vienna", + }, + "country": { + "type": + "string", + "description": + "The country that the city is in, e.g. 'Austria'", + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, + "options": { + "$ref": "#/$defs/WeatherOptions", + "description": "Optional parameters for weather query", + }, + }, + "required": ["country", "unit"], + "$defs": { + "WeatherOptions": { + "title": "WeatherOptions", + "type": "object", + "additionalProperties": False, + "properties": { + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "default": "celsius", + "description": "Temperature unit", + "title": "Temperature Unit", + }, + "include_forecast": { + "type": "boolean", + "default": False, + "description": + "Whether to include a 24-hour forecast", + "title": "Include Forecast", + }, + "language": { + "type": "string", + "default": "zh-CN", + "description": "Language of the response", + "title": "Language", + "enum": ["zh-CN", "en-US", "ja-JP"], + }, + }, + }, + }, + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_forecast", + "description": "Get the weather forecast for a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": + "The city to get the forecast for, e.g. 'Vienna'", + "default": "Vienna", + }, + "country": { + "type": + "string", + "description": + "The country that the city is in, e.g. 'Austria'", + }, + "days": { + "type": + "integer", + "description": + "Number of days to get the forecast for (1-7)", + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["country", "days", "unit"], + }, + }, + }, +] + +messages = [ + { + "role": "user", + "content": "Hi! How are you doing today?" + }, + { + "role": "assistant", + "content": "I'm doing well! How can I help you?" + }, + { + "role": + "user", + "content": + "Can you tell me what the current weather is in Berlin and the "\ + "forecast for the next 5 days, in fahrenheit?", + }, +] + @pytest.fixture(scope="module") def server(): # noqa: F811 @@ -27,6 +148,8 @@ def server(): # noqa: F811 "hermes", "--reasoning-parser", "qwen3", + "--gpu-memory-utilization", + "0.4" ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -54,129 +177,6 @@ async def client(server): async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, stream: bool, tool_choice: Union[str, dict], enable_thinking: bool): - tools = [ - { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": - "The city to find the weather for, e.g. 'Vienna'", - "default": "Vienna", - }, - "country": { - "type": - "string", - "description": - "The country that the city is in, e.g. 'Austria'", - }, - "unit": { - "type": "string", - "description": - "The unit to fetch the temperature in", - "enum": ["celsius", "fahrenheit"], - }, - "options": { - "$ref": "#/$defs/WeatherOptions", - "description": - "Optional parameters for weather query", - }, - }, - "required": ["country", "unit"], - "$defs": { - "WeatherOptions": { - "title": "WeatherOptions", - "type": "object", - "additionalProperties": False, - "properties": { - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - "default": "celsius", - "description": "Temperature unit", - "title": "Temperature Unit", - }, - "include_forecast": { - "type": "boolean", - "default": False, - "description": - "Whether to include a 24-hour forecast", - "title": "Include Forecast", - }, - "language": { - "type": "string", - "default": "zh-CN", - "description": "Language of the response", - "title": "Language", - "enum": ["zh-CN", "en-US", "ja-JP"], - }, - }, - }, - }, - }, - }, - }, - { - "type": "function", - "function": { - "name": "get_forecast", - "description": "Get the weather forecast for a given location", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": - "The city to get the forecast for, e.g. 'Vienna'", - "default": "Vienna", - }, - "country": { - "type": - "string", - "description": - "The country that the city is in, e.g. 'Austria'", - }, - "days": { - "type": - "integer", - "description": - "Number of days to get the forecast for (1-7)", - }, - "unit": { - "type": "string", - "description": - "The unit to fetch the temperature in", - "enum": ["celsius", "fahrenheit"], - }, - }, - "required": ["country", "days", "unit"], - }, - }, - }, - ] - - messages = [ - { - "role": "user", - "content": "Hi! How are you doing today?" - }, - { - "role": "assistant", - "content": "I'm doing well! How can I help you?" - }, - { - "role": - "user", - "content": - "Can you tell me what the current weather is in Berlin and the "\ - "forecast for the next 5 days, in fahrenheit?", - }, - ] if not stream: # Non-streaming test chat_completion = await client.chat.completions.create( @@ -216,3 +216,71 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, output.extend(chunk.choices[0].delta.tool_calls) assert len(output) > 0 + + +@pytest.fixture(scope="module") +def k2_server(): # noqa: F811 + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "half", + "--enable-auto-tool-choice", + "--guided-decoding-backend", + "xgrammar", + "--tool-call-parser", + "hermes", + "--reasoning-parser", + "qwen3", + "--gpu-memory-utilization", + "0.4", + ] + # hack to test kimi_k2 tool use tool_id format. + # avoid error in is_deepseek_mla check by setting kv_lora_rank=null + with RemoteOpenAIServer(MODEL_NAME, + args, + override_hf_configs={ + "model_type": 'kimi_k2', + 'kv_lora_rank': None + }) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def k2_client(k2_server): + async with k2_server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("stream", [True, False]) +@pytest.mark.parametrize("tool_choice", ["required"]) +async def test_tool_id_kimi_k2(k2_client: openai.AsyncOpenAI, model_name: str, + stream: bool, tool_choice: str): + + if not stream: + # Non-streaming test + chat_completion = await k2_client.chat.completions.create( + messages=messages, + model=model_name, + tools=tools, + tool_choice=tool_choice) + assert chat_completion.choices[0].message.tool_calls is not None + assert len(chat_completion.choices[0].message.tool_calls) > 0 + assert chat_completion.choices[0].message.tool_calls[ + 0].id == 'functions.get_current_weather:0' + else: + # Streaming test + output_stream = await k2_client.chat.completions.create( + messages=messages, + model=model_name, + tools=tools, + tool_choice=tool_choice, + stream=True) + + output = [] + async for chunk in output_stream: + if chunk.choices and chunk.choices[0].delta.tool_calls: + output.extend(chunk.choices[0].delta.tool_calls) + for o in output: + assert o.id is None or o.id == 'functions.get_current_weather:0' diff --git a/tests/utils.py b/tests/utils.py index e98707fb44475..4dba5494665a3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,6 +5,7 @@ import asyncio import copy import functools import importlib +import json import os import signal import subprocess @@ -101,7 +102,8 @@ class RemoteOpenAIServer: env_dict: Optional[dict[str, str]] = None, seed: Optional[int] = 0, auto_port: bool = True, - max_wait_seconds: Optional[float] = None) -> None: + max_wait_seconds: Optional[float] = None, + override_hf_configs: Optional[dict[str, Any]] = None) -> None: if auto_port: if "-p" in vllm_serve_args or "--port" in vllm_serve_args: raise ValueError("You have manually specified the port " @@ -120,6 +122,12 @@ class RemoteOpenAIServer: vllm_serve_args = vllm_serve_args + ["--seed", str(seed)] + if override_hf_configs is not None: + vllm_serve_args = vllm_serve_args + [ + "--hf-overrides", + json.dumps(override_hf_configs) + ] + parser = FlexibleArgumentParser( description="vLLM's remote OpenAI server.") subparsers = parser.add_subparsers(required=False, dest="subparser") diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 74c8093f49674..87772a499f423 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -1345,5 +1345,18 @@ def apply_mistral_chat_template( "template") raise ValueError(str(e)) from e -def random_tool_call_id() -> str: - return f"chatcmpl-tool-{random_uuid()}" +def get_history_tool_calls_cnt(conversation: list[ConversationMessage]): + idx = 0 + for msg in conversation: + if msg['role'] == 'assistant': + tool_calls = msg.get('tool_calls') + idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa + return idx + +def make_tool_call_id(id_type:str='random', func_name=None, idx=None): + + if id_type=='kimi_k2': + return f'functions.{func_name}:{idx}' + else: + # by default return random + return f"chatcmpl-tool-{random_uuid()}" diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 39facd4d53d32..a44868973f5d8 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -38,7 +38,7 @@ from typing_extensions import TypeAlias from vllm import envs from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, - random_tool_call_id) + make_tool_call_id) from vllm.entrypoints.score_utils import (ScoreContentPartParam, ScoreMultiModalParam) from vllm.logger import init_logger @@ -1634,7 +1634,7 @@ class FunctionCall(OpenAIBaseModel): class ToolCall(OpenAIBaseModel): - id: str = Field(default_factory=random_tool_call_id) + id: str = Field(default_factory=make_tool_call_id) type: Literal["function"] = "function" function: FunctionCall diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index d57868847eedd..65aac23ee618e 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -19,7 +19,8 @@ from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, ConversationMessage, - random_tool_call_id) + get_history_tool_calls_cnt, + make_tool_call_id) from vllm.entrypoints.harmony_utils import ( get_developer_message, get_stop_tokens_for_assistant_actions, get_streamable_parser_for_assistant, get_system_message, parse_chat_input, @@ -133,6 +134,10 @@ class OpenAIServingChat(OpenAIServing): source = "model" if source == "auto" else source logger.info("Using default chat sampling params from %s: %s", source, self.default_sampling_params) + if self.model_config.hf_config.model_type == 'kimi_k2': + self.tool_call_id_type = 'kimi_k2' + else: + self.tool_call_id_type = 'random' self.use_harmony = model_config.hf_config.model_type == "gpt_oss" if self.use_harmony: @@ -379,6 +384,7 @@ class OpenAIServingChat(OpenAIServing): current_text: Optional[str], delta_text: str, function_name_returned: bool, + tool_call_idx: Optional[int] = None ) -> tuple[Optional[DeltaMessage], bool]: if current_text is None or current_text == "": # if the current text is empty, we cannot parse it @@ -424,8 +430,12 @@ class OpenAIServingChat(OpenAIServing): current_tool_call = obj[-2] function_name_returned = True + tool_call_id = make_tool_call_id( + id_type=self.tool_call_id_type, + func_name=current_tool_call["name"], + idx=tool_call_idx) delta_message = DeltaMessage(tool_calls=[ - DeltaToolCall(id=random_tool_call_id(), + DeltaToolCall(id=tool_call_id, function=DeltaFunctionCall( name=current_tool_call["name"], arguments=arguments), @@ -491,6 +501,10 @@ class OpenAIServingChat(OpenAIServing): all_previous_token_ids: Optional[list[list[int]]] function_name_returned = [False] * num_choices + if self.tool_call_id_type == 'kimi_k2': + history_tool_call_cnt = get_history_tool_calls_cnt(conversation) + else: + history_tool_call_cnt = 0 # Always track previous_texts for comprehensive output logging previous_texts = [""] * num_choices @@ -673,7 +687,6 @@ class OpenAIServingChat(OpenAIServing): previous_text = previous_texts[i] previous_token_ids = all_previous_token_ids[i] current_text = previous_text + delta_text - # avoid the None + list error. if previous_token_ids: current_token_ids = previous_token_ids + as_list( @@ -733,7 +746,7 @@ class OpenAIServingChat(OpenAIServing): index=i) else: delta_tool_call = DeltaToolCall( - id=random_tool_call_id(), + id=make_tool_call_id(), type="function", function=DeltaFunctionCall( name=tool_choice_function_name, @@ -764,7 +777,11 @@ class OpenAIServingChat(OpenAIServing): previous_text=previous_text, current_text=content, delta_text=delta_text, - function_name_returned=fn_name_returned)) + function_name_returned=fn_name_returned, + tool_call_idx=history_tool_call_cnt)) + if (delta_message and delta_message.tool_calls and + delta_message.tool_calls[0].id is not None): + history_tool_call_cnt += 1 # update the previous values for the next iteration previous_texts[i] = current_text @@ -1089,6 +1106,10 @@ class OpenAIServingChat(OpenAIServing): assert final_res is not None choices: list[ChatCompletionResponseChoice] = [] + if self.tool_call_id_type == 'kimi_k2': + history_tool_call_cnt = get_history_tool_calls_cnt(conversation) + else: + history_tool_call_cnt = 0 role = self.get_chat_request_role(request) for output in final_res.outputs: @@ -1194,17 +1215,26 @@ class OpenAIServingChat(OpenAIServing): assert content is not None tool_calls = TypeAdapter( list[FunctionDefinition]).validate_json(content) + tool_call_ids = [] + for tool_call in tool_calls: + tool_call_ids.append( + make_tool_call_id(id_type=self.tool_call_id_type, + func_name=tool_call.name, + idx=history_tool_call_cnt)) + history_tool_call_cnt += 1 message = ChatMessage( role=role, content="", - reasoning_content=reasoning_content, tool_calls=[ - tool_call_class(function=FunctionCall( - name=tool_call.name, - arguments=json.dumps(tool_call.parameters, - ensure_ascii=False))) - for tool_call in tool_calls - ]) + tool_call_class(id=tool_call_ids[i], + function=FunctionCall( + name=tool_call.name, + arguments=json.dumps( + tool_call.parameters, + ensure_ascii=False))) + for i, tool_call in enumerate(tool_calls) + ], + reasoning_content=reasoning_content) # if the request doesn't use tool choice # OR specifies to not use a tool @@ -1248,7 +1278,6 @@ class OpenAIServingChat(OpenAIServing): if (tool_call_info.content and len(tool_call_info.content) > 0): ret_content = tool_call_info.content - message = ChatMessage(role=role, reasoning_content=reasoning_content, content=ret_content) @@ -1327,12 +1356,11 @@ class OpenAIServingChat(OpenAIServing): elif choice.message.tool_calls: # For tool calls, log the function name and arguments tool_call_descriptions = [] - for tool_call in choice.message.tool_calls: - if hasattr(tool_call.function, "name") and hasattr( - tool_call.function, "arguments"): + for tc in choice.message.tool_calls: + if hasattr(tc.function, "name") and hasattr( + tc.function, "arguments"): tool_call_descriptions.append( - f"{tool_call.function.name}({tool_call.function.arguments})" - ) + f"{tc.function.name}({tc.function.arguments})") tool_calls_str = ", ".join(tool_call_descriptions) output_text = f"[tool_calls: {tool_calls_str}]" diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py index da4760ad1b642..ac272b0c3b205 100644 --- a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py @@ -6,7 +6,7 @@ from typing import Union import regex as re -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -267,7 +267,7 @@ class DeepSeekV3ToolParser(ToolParser): DeltaToolCall( index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True), diff --git a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py index 5508ba6a39408..824b100f357b5 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py @@ -10,7 +10,7 @@ import partial_json_parser import regex as re from partial_json_parser.core.options import Allow -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -203,7 +203,7 @@ class Granite20bFCToolParser(ToolParser): delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py index fcc5b7edda83f..ac517616a95b4 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py @@ -8,7 +8,7 @@ from typing import Union import partial_json_parser from partial_json_parser.core.options import Allow -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -185,7 +185,7 @@ class GraniteToolParser(ToolParser): delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index d126130ab9bc3..a6ce33af6bd00 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -9,7 +9,7 @@ import partial_json_parser import regex as re from partial_json_parser.core.options import Allow -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -307,7 +307,7 @@ class Hermes2ProToolParser(ToolParser): return DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py index 92004de030d14..6ef8fadf59ac5 100644 --- a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -8,7 +8,7 @@ from typing import Union import partial_json_parser from partial_json_parser.core.options import Allow -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -107,7 +107,7 @@ class Internlm2ToolParser(ToolParser): delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py index 66b483d8b0f66..3b41f6034704c 100644 --- a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py @@ -9,7 +9,7 @@ import partial_json_parser import regex as re from partial_json_parser.core.options import Allow -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -222,7 +222,7 @@ class JambaToolParser(ToolParser): delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py index 194a144ad576e..31b19c8db4163 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py @@ -10,7 +10,7 @@ import regex as re from partial_json_parser.core.options import Allow from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -213,7 +213,7 @@ class Llama3JsonToolParser(ToolParser): delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py index 226309ef293a9..283e6095013d6 100644 --- a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py @@ -7,7 +7,7 @@ from typing import Any, Optional, Union import regex as re -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -394,7 +394,7 @@ class MinimaxToolParser(ToolParser): sent_tools.append({ "sent_name": False, "sent_arguments": "", - "id": random_tool_call_id(), + "id": make_tool_call_id(), }) while len(tool_ids) < tool_count: diff --git a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py index 5501028cf36b8..85dd56213c6ac 100644 --- a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py @@ -8,7 +8,7 @@ from typing import Any, Optional import regex as re from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaMessage, ExtractedToolCallInformation, @@ -74,7 +74,7 @@ class Phi4MiniJsonToolParser(ToolParser): tool_calls: list[ToolCall] = [ ToolCall( - id=random_tool_call_id(), + id=make_tool_call_id(), type="function", function=FunctionCall( name=raw_function_call["name"], diff --git a/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py index 321718b1c950b..87cd413b37200 100644 --- a/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py @@ -7,7 +7,7 @@ from typing import Any, Optional, Union import regex as re -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -226,7 +226,7 @@ class xLAMToolParser(ToolParser): function_name = name_match.group(1) # The test expects us to send just the name first - tool_id = random_tool_call_id() + tool_id = make_tool_call_id() delta = DeltaMessage(tool_calls=[ DeltaToolCall( index=0, From b95697d7310637399998ebf1f21a26b523aa6611 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 20 Aug 2025 13:03:37 -0700 Subject: [PATCH 13/57] [Frontend] improve error logging of chat completion (#22957) Signed-off-by: Chen Zhang --- vllm/entrypoints/openai/api_server.py | 74 +++++++++++++++++++++------ 1 file changed, 57 insertions(+), 17 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 24148bcef2353..14ba8aa641837 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -600,8 +600,11 @@ async def create_responses(request: ResponsesRequest, raw_request: Request): if handler is None: return base(raw_request).create_error_response( message="The model does not support Responses API") - - generator = await handler.create_responses(request, raw_request) + try: + generator = await handler.create_responses(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), @@ -618,7 +621,11 @@ async def retrieve_responses(response_id: str, raw_request: Request): return base(raw_request).create_error_response( message="The model does not support Responses API") - response = await handler.retrieve_responses(response_id) + try: + response = await handler.retrieve_responses(response_id) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(response, ErrorResponse): return JSONResponse(content=response.model_dump(), @@ -633,7 +640,11 @@ async def cancel_responses(response_id: str, raw_request: Request): return base(raw_request).create_error_response( message="The model does not support Responses API") - response = await handler.cancel_responses(response_id) + try: + response = await handler.cancel_responses(response_id) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(response, ErrorResponse): return JSONResponse(content=response.model_dump(), @@ -667,9 +678,11 @@ async def create_chat_completion(request: ChatCompletionRequest, if handler is None: return base(raw_request).create_error_response( message="The model does not support Chat Completions API") - - generator = await handler.create_chat_completion(request, raw_request) - + try: + generator = await handler.create_chat_completion(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.error.code) @@ -742,7 +755,11 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): return base(raw_request).create_error_response( message="The model does not support Embeddings API") - generator = await handler.create_embedding(request, raw_request) + try: + generator = await handler.create_embedding(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), @@ -770,8 +787,11 @@ async def create_pooling(request: PoolingRequest, raw_request: Request): if handler is None: return base(raw_request).create_error_response( message="The model does not support Pooling API") - - generator = await handler.create_pooling(request, raw_request) + try: + generator = await handler.create_pooling(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.error.code) @@ -791,7 +811,11 @@ async def create_classify(request: ClassificationRequest, return base(raw_request).create_error_response( message="The model does not support Classification API") - generator = await handler.create_classify(request, raw_request) + try: + generator = await handler.create_classify(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.error.code) @@ -820,7 +844,11 @@ async def create_score(request: ScoreRequest, raw_request: Request): return base(raw_request).create_error_response( message="The model does not support Score API") - generator = await handler.create_score(request, raw_request) + try: + generator = await handler.create_score(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.error.code) @@ -878,8 +906,12 @@ async def create_transcriptions(raw_request: Request, message="The model does not support Transcriptions API") audio_data = await request.file.read() - generator = await handler.create_transcription(audio_data, request, - raw_request) + try: + generator = await handler.create_transcription(audio_data, request, + raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), @@ -919,8 +951,12 @@ async def create_translations(request: Annotated[TranslationRequest, message="The model does not support Translations API") audio_data = await request.file.read() - generator = await handler.create_translation(audio_data, request, - raw_request) + try: + generator = await handler.create_translation(audio_data, request, + raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), @@ -949,7 +985,11 @@ async def do_rerank(request: RerankRequest, raw_request: Request): if handler is None: return base(raw_request).create_error_response( message="The model does not support Rerank (Score) API") - generator = await handler.do_rerank(request, raw_request) + try: + generator = await handler.do_rerank(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.error.code) From bf7c99dfc40bff6844b2ae57554516922eb93b71 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Wed, 20 Aug 2025 13:17:11 -0700 Subject: [PATCH 14/57] [Perf] Speed up function `_convert_tokens_to_string_with_added_encoders` by 13.7x (#20413) Signed-off-by: Saurabh Misra Signed-off-by: Aseem Saxena Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com> Co-authored-by: Aseem Saxena --- vllm/transformers_utils/detokenizer_utils.py | 25 ++++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/vllm/transformers_utils/detokenizer_utils.py b/vllm/transformers_utils/detokenizer_utils.py index be1040c3e0147..101f31d39cc1f 100644 --- a/vllm/transformers_utils/detokenizer_utils.py +++ b/vllm/transformers_utils/detokenizer_utils.py @@ -23,27 +23,32 @@ def _convert_tokens_to_string_with_added_encoders( # NOTE(woosuk): The following code is slow because it runs a for loop over # the output_tokens. In Python, running a for loop over a list can be slow # even when the loop body is very simple. + # Performance improvements: avoid repeated attribute and function lookups; + # localize frequently used objects; + sub_texts: list[str] = [] current_sub_text: list[str] = [] - all_special_tokens = set(tokenizer.all_special_tokens) + convert_tokens_to_string = tokenizer.convert_tokens_to_string + added_vocab_set = set(tokenizer.get_added_vocab()) + all_special_tokens = set( + tokenizer.all_special_tokens) if skip_special_tokens else () + for token in output_tokens: - if skip_special_tokens and token in all_special_tokens: + # Use precomputed set for skip-special check + if token in all_special_tokens: continue - if token in tokenizer.get_added_vocab(): + if token in added_vocab_set: if current_sub_text: - sub_text = tokenizer.convert_tokens_to_string(current_sub_text) - sub_texts.append(sub_text) - current_sub_text = [] + sub_texts.append(convert_tokens_to_string(current_sub_text)) + current_sub_text.clear() sub_texts.append(token) else: current_sub_text.append(token) if current_sub_text: - sub_text = tokenizer.convert_tokens_to_string(current_sub_text) - sub_texts.append(sub_text) + sub_texts.append(convert_tokens_to_string(current_sub_text)) if spaces_between_special_tokens: return " ".join(sub_texts) - else: - return "".join(sub_texts) + return "".join(sub_texts) # 5 is an arbitrary value that should work for all From 4e51fa8cbaba2c6fd516b4615a533b0a94796516 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 20 Aug 2025 16:28:30 -0400 Subject: [PATCH 15/57] Do not use eval() to convert unknown types (#23266) Signed-off-by: Russell Bryant --- .../openai/tool_parsers/qwen3coder_tool_parser.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py index cf4d0b231aee1..2501d6739e8f6 100644 --- a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py @@ -208,15 +208,10 @@ class Qwen3CoderToolParser(ToolParser): "valid JSON object in tool '%s', will try other " "methods to parse it.", param_value, param_name, func_name) - try: - converted_value = eval(param_value) - return converted_value - except Exception: - logger.warning( - "Parsed value '%s' of parameter '%s' cannot be " - "converted via Python `eval()` in tool '%s', " - "degenerating to string.", param_value, param_name, - func_name) + logger.warning( + "Parameter '%s' has unknown type '%s'. " + "The value will be treated as a string.", param_name, + param_type) return param_value # Extract function name From 4fbda0b20cc539f72314375c2abc6100ebac8392 Mon Sep 17 00:00:00 2001 From: "rongfu.leng" Date: Thu, 21 Aug 2025 05:07:28 +0800 Subject: [PATCH 16/57] [Feature] use --eplb_config to set eplb param (#20562) Signed-off-by: rongfu.leng Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: rongfu.leng Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config/__init__.py | 3 +- vllm/config/parallel.py | 108 +++++++++++++++++----- vllm/distributed/eplb/eplb_state.py | 4 +- vllm/engine/arg_utils.py | 63 +++++++++---- vllm/model_executor/models/deepseek_v2.py | 4 +- vllm/model_executor/models/glm4_moe.py | 4 +- vllm/model_executor/models/qwen3_moe.py | 7 +- vllm/v1/worker/gpu_model_runner.py | 4 +- vllm/v1/worker/gpu_worker.py | 4 +- 9 files changed, 149 insertions(+), 52 deletions(-) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 62dfd4333bee8..959f111ced22e 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -33,7 +33,8 @@ from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType, PrefixCachingHashAlgo) from vllm.config.compilation import (CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig) -from vllm.config.parallel import DistributedExecutorBackend, ParallelConfig +from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig, + ParallelConfig) from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy from vllm.config.utils import ConfigType, config from vllm.logger import init_logger diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 7a9e68f0ea332..2b716a77066ac 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -6,7 +6,7 @@ from dataclasses import field from typing import TYPE_CHECKING, Any, Literal, Optional, Union import torch -from pydantic import model_validator +from pydantic import TypeAdapter, model_validator from pydantic.dataclasses import dataclass from torch.distributed import ProcessGroup, ReduceOp from typing_extensions import Self @@ -32,6 +32,38 @@ logger = init_logger(__name__) DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"] +@config +@dataclass +class EPLBConfig: + """Configuration for Expert Parallel Load Balancing (EP).""" + + window_size: int = 1000 + """Window size for expert load recording.""" + step_interval: int = 3000 + """ + Interval for rearranging experts in expert parallelism. + + Note that if this is greater than the EPLB window size, only the metrics + of the last `lb_window_size` steps will be used for rearranging experts. + """ + + num_redundant_experts: int = 0 + """Number of redundant experts to use for expert parallelism.""" + + log_balancedness: bool = False + """ + Log the balancedness each step of expert parallelism. + This is turned off by default since it will cause communication overhead. + """ + + @classmethod + def from_cli(cls, cli_value: str) -> "EPLBConfig": + """Parse the CLI value for the compilation config. + -O1, -O2, -O3, etc. is handled in FlexibleArgumentParser. + """ + return TypeAdapter(EPLBConfig).validate_json(cli_value) + + @config @dataclass class ParallelConfig: @@ -75,22 +107,24 @@ class ParallelConfig: """Use expert parallelism instead of tensor parallelism for MoE layers.""" enable_eplb: bool = False """Enable expert parallelism load balancing for MoE layers.""" - num_redundant_experts: int = 0 - """Number of redundant experts to use for expert parallelism.""" - eplb_window_size: int = 1000 - """Window size for expert load recording.""" - eplb_step_interval: int = 3000 - """ - Interval for rearranging experts in expert parallelism. - - Note that if this is greater than the EPLB window size, only the metrics - of the last `eplb_window_size` steps will be used for rearranging experts. - """ - eplb_log_balancedness: bool = False - """ - Log the balancedness each step of expert parallelism. - This is turned off by default since it will cause communication overhead. - """ + eplb_config: EPLBConfig = field(default_factory=EPLBConfig) + """Expert parallelism configuration.""" + num_redundant_experts: Optional[int] = None + """`num_redundant_experts` is deprecated and has been replaced with + `eplb_config.num_redundant_experts`. This will be removed in v0.12.0. + Please use `eplb_config.num_redundant_experts` instead.""" + eplb_window_size: Optional[int] = None + """`eplb_window_size` is deprecated and has been replaced with + `eplb_config.window_size`. This will be removed in v0.12.0. + Please use `eplb_config.window_size` instead.""" + eplb_step_interval: Optional[int] = None + """`eplb_step_interval` is deprecated and has been replaced with + `eplb_config.step_interval`. This will be removed in v0.12.0. + Please use `eplb_config.step_interval` instead.""" + eplb_log_balancedness: Optional[bool] = None + """`eplb_log_balancedness` is deprecated and has been replaced with + `eplb_config.log_balancedness`. This will be removed in v0.12.0. + Please use `eplb_config.log_balancedness` instead.""" max_parallel_loading_workers: Optional[int] = None """Maximum number of parallel loading workers when loading model @@ -237,6 +271,38 @@ class ParallelConfig: return hashlib.sha256(str(factors).encode()).hexdigest() def __post_init__(self) -> None: + # Forward deprecated fields to their new location + if self.num_redundant_experts is not None: + self.eplb_config.num_redundant_experts = ( + self.num_redundant_experts) + logger.warning_once( + "num_redundant_experts is deprecated and has been replaced " + "with eplb_config.num_redundant_experts. This will be removed " + "in v0.12.0. Changing this field after initialization will " + "have no effect.") + if self.eplb_window_size is not None: + self.eplb_config.window_size = self.eplb_window_size + logger.warning_once( + "eplb_window_size is deprecated and has been replaced " + "with eplb_config.window_size. This will be removed " + "in v0.12.0. Changing this field after initialization will " + "have no effect.") + if self.eplb_step_interval is not None: + self.eplb_config.step_interval = self.eplb_step_interval + logger.warning_once( + "eplb_step_interval is deprecated and has been replaced " + "with eplb_config.step_interval. This will be removed " + "in v0.12.0. Changing this field after initialization will " + "have no effect.") + if self.eplb_log_balancedness is not None: + self.eplb_config.log_balancedness = self.eplb_log_balancedness + logger.warning_once( + "eplb_log_balancedness is deprecated and has been replaced " + "with eplb_config.log_balancedness. This will be removed " + "in v0.12.0. Changing this field after initialization will " + "have no effect.") + + # Continue with the rest of the initialization self.world_size = self.pipeline_parallel_size * \ self.tensor_parallel_size @@ -275,10 +341,10 @@ class ParallelConfig: raise ValueError( "Expert parallelism load balancing is only supported on " "CUDA devices now.") - if self.num_redundant_experts < 0: + if self.eplb_config.num_redundant_experts < 0: raise ValueError( "num_redundant_experts must be non-negative, but got " - f"{self.num_redundant_experts}.") + f"{self.eplb_config.num_redundant_experts}.") if not self.enable_expert_parallel: raise ValueError( "enable_expert_parallel must be True to use EPLB.") @@ -289,10 +355,10 @@ class ParallelConfig: f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}." ) else: - if self.num_redundant_experts != 0: + if self.eplb_config.num_redundant_experts != 0: raise ValueError( "num_redundant_experts should be used with EPLB." - f"{self.num_redundant_experts}.") + f"{self.eplb_config.num_redundant_experts}.") if self.distributed_executor_backend is None and self.world_size > 1: # We use multiprocessing by default if world_size fits on the # current node and we aren't in a ray placement group. diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 979f2a06cec9f..042acf40d67c2 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -244,7 +244,7 @@ class EplbState: dtype=torch.int32, device=device, ) - expert_load_window_size = parallel_config.eplb_window_size + expert_load_window_size = parallel_config.eplb_config.window_size expert_load_window = torch.zeros( (expert_load_window_size, model.num_moe_layers, model.num_physical_experts), @@ -253,7 +253,7 @@ class EplbState: ) # Set the initial progress of rearrangement to 3/4 - eplb_step_interval = parallel_config.eplb_step_interval + eplb_step_interval = parallel_config.eplb_config.step_interval expert_rearrangement_step = max( 0, eplb_step_interval - eplb_step_interval // 4) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 6869c3f23f315..dcf78758946f9 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -25,7 +25,7 @@ import vllm.envs as envs from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, ConfigFormat, ConfigType, ConvertOption, DecodingConfig, DetailedTraceModules, Device, - DeviceConfig, DistributedExecutorBackend, + DeviceConfig, DistributedExecutorBackend, EPLBConfig, GuidedDecodingBackend, HfOverrides, KVEventsConfig, KVTransferConfig, LoadConfig, LogprobsMode, LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig, @@ -305,11 +305,12 @@ class EngineArgs: data_parallel_hybrid_lb: bool = False data_parallel_backend: str = ParallelConfig.data_parallel_backend enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel + eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config") enable_eplb: bool = ParallelConfig.enable_eplb - num_redundant_experts: int = ParallelConfig.num_redundant_experts - eplb_window_size: int = ParallelConfig.eplb_window_size - eplb_step_interval: int = ParallelConfig.eplb_step_interval - eplb_log_balancedness: bool = ParallelConfig.eplb_log_balancedness + num_redundant_experts: int = EPLBConfig.num_redundant_experts + eplb_window_size: int = EPLBConfig.window_size + eplb_step_interval: int = EPLBConfig.step_interval + eplb_log_balancedness: bool = EPLBConfig.log_balancedness max_parallel_loading_workers: Optional[ int] = ParallelConfig.max_parallel_loading_workers block_size: Optional[BlockSize] = CacheConfig.block_size @@ -454,6 +455,9 @@ class EngineArgs: if isinstance(self.compilation_config, dict): self.compilation_config = CompilationConfig( **self.compilation_config) + if isinstance(self.eplb_config, dict): + self.eplb_config = EPLBConfig.from_cli(json.dumps( + self.eplb_config)) # Setup plugins from vllm.plugins import load_general_plugins load_general_plugins() @@ -661,14 +665,32 @@ class EngineArgs: **parallel_kwargs["enable_expert_parallel"]) parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"]) - parallel_group.add_argument("--num-redundant-experts", - **parallel_kwargs["num_redundant_experts"]) - parallel_group.add_argument("--eplb-window-size", - **parallel_kwargs["eplb_window_size"]) - parallel_group.add_argument("--eplb-step-interval", - **parallel_kwargs["eplb_step_interval"]) - parallel_group.add_argument("--eplb-log-balancedness", - **parallel_kwargs["eplb_log_balancedness"]) + parallel_group.add_argument("--eplb-config", + **parallel_kwargs["eplb_config"]) + parallel_group.add_argument( + "--num-redundant-experts", + type=int, + help= + "[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.", + deprecated=True) + parallel_group.add_argument( + "--eplb-window-size", + type=int, + help="[DEPRECATED] --eplb-window-size will be removed in v0.12.0.", + deprecated=True) + parallel_group.add_argument( + "--eplb-step-interval", + type=int, + help= + "[DEPRECATED] --eplb-step-interval will be removed in v0.12.0.", + deprecated=True) + parallel_group.add_argument( + "--eplb-log-balancedness", + action=argparse.BooleanOptionalAction, + help= + "[DEPRECATED] --eplb-log-balancedness will be removed in v0.12.0.", + deprecated=True) + parallel_group.add_argument( "--max-parallel-loading-workers", **parallel_kwargs["max_parallel_loading_workers"]) @@ -1244,6 +1266,16 @@ class EngineArgs: "Currently, speculative decoding is not supported with " "async scheduling.") + # Forward the deprecated CLI args to the EPLB config. + if self.num_redundant_experts is not None: + self.eplb_config.num_redundant_experts = self.num_redundant_experts + if self.eplb_window_size is not None: + self.eplb_config.window_size = self.eplb_window_size + if self.eplb_step_interval is not None: + self.eplb_config.step_interval = self.eplb_step_interval + if self.eplb_log_balancedness is not None: + self.eplb_config.log_balancedness = self.eplb_log_balancedness + parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, @@ -1257,10 +1289,7 @@ class EngineArgs: data_parallel_hybrid_lb=self.data_parallel_hybrid_lb, enable_expert_parallel=self.enable_expert_parallel, enable_eplb=self.enable_eplb, - num_redundant_experts=self.num_redundant_experts, - eplb_window_size=self.eplb_window_size, - eplb_step_interval=self.eplb_step_interval, - eplb_log_balancedness=self.eplb_log_balancedness, + eplb_config=self.eplb_config, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, ray_workers_use_nsight=self.ray_workers_use_nsight, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index f199da135ec76..d56224b4b7b30 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -132,10 +132,10 @@ class DeepseekV2MoE(nn.Module): # Load balancing settings. vllm_config = get_current_vllm_config() - parallel_config = vllm_config.parallel_config + eplb_config = vllm_config.parallel_config.eplb_config self.enable_eplb = enable_eplb - self.n_redundant_experts = parallel_config.num_redundant_experts + self.n_redundant_experts = eplb_config.num_redundant_experts self.n_logical_experts = self.n_routed_experts self.n_physical_experts = (self.n_logical_experts + self.n_redundant_experts) diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index aff491f9596c3..fe5e46a99826f 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -131,10 +131,10 @@ class Glm4MoE(nn.Module): # Load balancing settings. vllm_config = get_current_vllm_config() - parallel_config = vllm_config.parallel_config + eplb_config = vllm_config.parallel_config.eplb_config self.enable_eplb = enable_eplb - self.n_redundant_experts = parallel_config.num_redundant_experts + self.n_redundant_experts = eplb_config.num_redundant_experts self.n_logical_experts = self.n_routed_experts self.n_physical_experts = (self.n_logical_experts + self.n_redundant_experts) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 05bbb0d2e8995..2812f79a66b70 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -121,11 +121,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module): # Load balancing settings. vllm_config = get_current_vllm_config() - parallel_config = vllm_config.parallel_config + eplb_config = vllm_config.parallel_config.eplb_config self.enable_eplb = enable_eplb self.n_logical_experts = self.n_routed_experts - self.n_redundant_experts = parallel_config.num_redundant_experts + self.n_redundant_experts = eplb_config.num_redundant_experts self.n_physical_experts = (self.n_logical_experts + self.n_redundant_experts) self.n_local_physical_experts = self.n_physical_experts // self.ep_size @@ -363,7 +363,8 @@ class Qwen3MoeModel(nn.Module): quant_config = vllm_config.quant_config parallel_config = vllm_config.parallel_config enable_eplb = parallel_config.enable_eplb - self.num_redundant_experts = parallel_config.num_redundant_experts + eplb_config = parallel_config.eplb_config + self.num_redundant_experts = eplb_config.num_redundant_experts self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d9770226b14ee..33747d6917a5a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1435,7 +1435,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): model, is_dummy, is_profile, - log_stats=self.parallel_config.eplb_log_balancedness, + log_stats=self.parallel_config.eplb_config.log_balancedness, ) def get_dp_padding(self, @@ -1977,7 +1977,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): global_expert_load, old_global_expert_indices = ( EplbState.recv_state()) num_logical_experts = global_expert_load.shape[1] - self.parallel_config.num_redundant_experts = ( + self.parallel_config.eplb_config.num_redundant_experts = ( num_local_physical_experts * new_ep_size - num_logical_experts) assert old_global_expert_indices.shape[ 1] % num_local_physical_experts == 0 diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 22e639b97d09c..d61177d4245dd 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -515,7 +515,7 @@ class Worker(WorkerBase): assert self.model_runner.eplb_state is not None new_physical_experts = \ self.model_runner.eplb_state.physical_to_logical_map.shape[1] - parallel_config.num_redundant_experts = ( + parallel_config.eplb_config.num_redundant_experts = ( new_physical_experts - self.model_runner.eplb_state.logical_replica_count.shape[1]) global_expert_load = None @@ -531,7 +531,7 @@ class Worker(WorkerBase): assert self.model_runner.eplb_state is not None global_expert_load = self.model_runner.eplb_state.rearrange( self.model_runner.model, execute_shuffle=False) - parallel_config.num_redundant_experts = ( + parallel_config.eplb_config.num_redundant_experts = ( new_physical_experts - global_expert_load.shape[1]) prepare_communication_buffer_for_model(self.model_runner.model) self.model_runner.model.update_physical_experts_metadata( From 1b125004bea9f4cd120d3ce96dc1d3a2962ebace Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 21 Aug 2025 05:15:34 +0800 Subject: [PATCH 17/57] [misc] fix multiple arch wheels for the nightly index (#23110) Signed-off-by: youkaichao --- .buildkite/generate_index.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/.buildkite/generate_index.py b/.buildkite/generate_index.py index 7045d8810493e..6b5a2a99356aa 100644 --- a/.buildkite/generate_index.py +++ b/.buildkite/generate_index.py @@ -8,7 +8,8 @@ template = """

Links for vLLM

- {wheel}
+ {x86_wheel}
+ {arm_wheel}
""" @@ -21,7 +22,20 @@ filename = os.path.basename(args.wheel) with open("index.html", "w") as f: print(f"Generated index.html for {args.wheel}") + if "x86_64" in filename: + x86_wheel = filename + arm_wheel = filename.replace("x86_64", "aarch64") + elif "aarch64" in filename: + x86_wheel = filename.replace("aarch64", "x86_64") + arm_wheel = filename + else: + raise ValueError(f"Unsupported wheel: {filename}") # cloudfront requires escaping the '+' character f.write( - template.format(wheel=filename, wheel_html_escaped=filename.replace("+", "%2B")) + template.format( + x86_wheel=x86_wheel, + x86_wheel_html_escaped=x86_wheel.replace("+", "%2B"), + arm_wheel=arm_wheel, + arm_wheel_html_escaped=arm_wheel.replace("+", "%2B"), + ) ) From a4fbb32fab3d2f91b3672bf581565378aaa18d6c Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 20 Aug 2025 17:43:17 -0400 Subject: [PATCH 18/57] Remove chunked_prefill_enabled flag in V1 MLA (#23183) Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/common.py | 50 +++++++++++------------- 1 file changed, 23 insertions(+), 27 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index f2610671f769e..646e4fec836bd 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -416,7 +416,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): self.model_config = vllm_config.model_config cache_config = vllm_config.cache_config parallel_config = vllm_config.parallel_config - self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled self.num_heads = self.model_config.get_num_attention_heads( parallel_config) self.mla_dims = get_mla_dims(self.model_config) @@ -426,30 +425,28 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): if self.aot_schedule: self.page_size = self.kv_cache_spec.block_size - if self.chunked_prefill_enabled: - self.chunked_prefill_workspace_size = min( - # Max sure there is enough for 8 full length request or at least - # 4 pages of cache per request - max( - 8 * self.model_config.max_model_len, 4 * - scheduler_config.max_num_seqs * cache_config.block_size), - # For long-context models try not to over-allocate limiting - # kv-cache space, limiting it to 64k tokens, - # which would result in the workspace being: - # 2*(576)*(64*1024) = 144mb - # (assuming 576 MLA head dim, and fp16) - # which would result in up-projected context being - # 2*(192*128)*(64*1024) = 3gb - # (assuming 192 QK head dim, 128 heads, and fp16) - 128 * 1024) - assert self.chunked_prefill_workspace_size >= \ - scheduler_config.max_num_seqs * cache_config.block_size - self.chunked_prefill_workspace = torch.empty( - (self.chunked_prefill_workspace_size, - self.model_config.get_head_size()), - dtype=self.model_config.dtype, - device=device, - ) + self.chunked_prefill_workspace_size = min( + # Max sure there is enough for 8 full length request or at least + # 4 pages of cache per request + max(8 * self.model_config.max_model_len, + 4 * scheduler_config.max_num_seqs * cache_config.block_size), + # For long-context models try not to over-allocate limiting + # kv-cache space, limiting it to 64k tokens, + # which would result in the workspace being: + # 2*(576)*(64*1024) = 144mb + # (assuming 576 MLA head dim, and fp16) + # which would result in up-projected context being + # 2*(192*128)*(64*1024) = 3gb + # (assuming 192 QK head dim, 128 heads, and fp16) + 128 * 1024) + assert self.chunked_prefill_workspace_size >= \ + scheduler_config.max_num_seqs * cache_config.block_size + self.chunked_prefill_workspace = torch.empty( + (self.chunked_prefill_workspace_size, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=device, + ) self._use_cudnn_prefill = use_cudnn_prefill() self._use_fi_prefill = use_flashinfer_prefill() @@ -620,8 +617,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): reqs_start:] - query_start_loc[reqs_start] chunked_context_metadata = None - if self.chunked_prefill_enabled and num_prefills > 0 \ - and max_context_len_cpu > 0: + if max_context_len_cpu > 0: # NOTE: it is recommend you read the `Chunked Prefill` section # in the comment at the top of the file before trying to # understand the following code From 10cc12ba66834e33659f1ce3a00235506db20dd5 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 20 Aug 2025 17:46:47 -0400 Subject: [PATCH 19/57] Feature/mla tests (#23195) Signed-off-by: Matthew Bonanni Signed-off-by: Matthew Bonanni --- tests/v1/attention/test_attention_backends.py | 26 +- tests/v1/attention/test_mla_backends.py | 522 ++++++++++++++++++ tests/v1/attention/utils.py | 11 +- vllm/v1/attention/backends/mla/common.py | 16 +- 4 files changed, 551 insertions(+), 24 deletions(-) create mode 100644 tests/v1/attention/test_mla_backends.py diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index ac08b9052cd80..60e04ad9069e7 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -150,15 +150,15 @@ def create_and_prepopulate_kv_cache( # Permute the context blocks (excluding block 0 which is null) if randomize_blocks: - perm = torch.randperm( - blocks_end - 1) + 1 # Random permutation starting from block 1 + # Random permutation starting from block 1 + perm = torch.randperm(blocks_end - 1) + 1 else: - perm = torch.arange( - 1, blocks_end) # Sequential order starting from block 1 + # Sequential order starting from block 1 + perm = torch.arange(1, blocks_end) inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device) - inv_perm[1:] = torch.argsort( - perm) + 1 # Add 1 to account for starting from block 1 + # Add 1 to account for starting from block 1 + inv_perm[1:] = torch.argsort(perm) + 1 kv_cache[:, 1:blocks_end, ...] = kv_cache[:, perm, ...] # Construct the right block table @@ -281,7 +281,8 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, @pytest.mark.parametrize("batch_spec_name", [ "small_decode", "small_prefill", "mixed_small", "medium_decode", - "medium_prefill", "mixed_medium" + "medium_prefill", "mixed_medium", "large_decode", "large_prefill", + "single_decode", "single_prefill" ]) @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) def test_backend_correctness(batch_spec_name: str, model: str): @@ -302,7 +303,8 @@ def test_backend_correctness(batch_spec_name: str, model: str): """ batch_spec = BATCH_SPECS[batch_spec_name] vllm_config = create_vllm_config(model_name=model, - max_model_len=max(batch_spec.seq_lens)) + max_model_len=max(batch_spec.seq_lens), + num_gpu_blocks=8192) device = torch.device("cuda:0") kv_cache_spec = create_standard_kv_cache_spec(vllm_config) @@ -465,12 +467,6 @@ def test_backend_correctness(batch_spec_name: str, model: str): rtol=rtol, atol=atol) - if not all_close: - print(f"[{backend_name}] output differs from SDPA baseline. " - f"Max diff: {max_diff:.6f} (rel: {max_rel_diff:.6f})") - print(f"[{backend_name}] output: {backend_output}") - print(f"[{backend_name}] SDPA baseline: {sdpa_output}") - assert all_close, ( f"[{backend_name}] output differs from SDPA baseline. " - f"Max diff: {max_diff:.6f} (rel: {max_rel_diff:.6f})") + f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})") \ No newline at end of file diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py new file mode 100644 index 0000000000000..24070358799ef --- /dev/null +++ b/tests/v1/attention/test_mla_backends.py @@ -0,0 +1,522 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for v1 MLA backends without GPUModelRunner dependency.""" + +import pytest +import torch + +from tests.v1.attention.utils import (BatchSpec, _Backend, + create_common_attn_metadata, + create_standard_kv_cache_spec, + create_vllm_config, + get_attention_backend) +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv +from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.kv_cache_interface import FullAttentionSpec + +BACKENDS_TO_TEST = [ + _Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1, + _Backend.TRITON_MLA_VLLM_V1 +] + +# Remove CUTLASS_MLA from the list if not using sm100 +if not torch.cuda.is_available() or torch.cuda.get_device_properties( + 0).major < 10: + BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA) + +torch.manual_seed(42) + + +def _convert_dtype_to_torch(dtype): + """Convert ModelDType to torch.dtype.""" + if isinstance(dtype, str): + if dtype == "auto": + return torch.float16 # Default dtype for testing + elif dtype in STR_DTYPE_TO_TORCH_DTYPE: + return STR_DTYPE_TO_TORCH_DTYPE[dtype] + else: + raise ValueError(f"Unknown dtype: {dtype}") + elif isinstance(dtype, torch.dtype): + return dtype + else: + raise ValueError(f"Unknown dtype: {dtype}") + + +# Define common batch configurations +BATCH_SPECS = { + "small_decode": + BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]), + "small_prefill": + BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]), + "mixed_small": + BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]), + "medium_decode": + BatchSpec(seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024], + query_lens=[1, 1, 1, 1, 1, 1, 1, 1]), + "medium_prefill": + BatchSpec(seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]), + "mixed_medium": + BatchSpec(seq_lens=[512, 1024, 2048, 512, 1024, 2048], + query_lens=[1, 1, 1, 7, 7, 7]), + "large_decode": + BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32), + "large_prefill": + BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8), + "single_decode": + BatchSpec(seq_lens=[1024], query_lens=[1]), + "single_prefill": + BatchSpec(seq_lens=[1024], query_lens=[64]), +} + + +def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec, + device: torch.device, + num_blocks: int = 100) -> torch.Tensor: + """Create a dummy KV cache tensor for testing.""" + kv_cache = torch.randn( + num_blocks, + kv_cache_spec.block_size, + kv_cache_spec.head_size, # latent dimension + dtype=_convert_dtype_to_torch(kv_cache_spec.dtype), + device=device, + ) + return kv_cache + + +def create_and_prepopulate_kv_cache( + kv_c_contexts: list[torch.Tensor], + k_pe_contexts: list[torch.Tensor], + block_size: int, + num_kv_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + num_blocks: int, + common_attn_metadata: CommonAttentionMetadata, + randomize_blocks: bool = True) -> torch.Tensor: + """Create and prepopulate an MLA KV cache with context data. + + Args: + kv_c_contexts: List of latent KV context tensors for each sequence + k_pe_contexts: List of key positional embedding context tensors + for each sequence + block_size: Size of each block + num_kv_heads: Number of KV heads (should be 1 for MLA) + head_size: Size of each head (latent dimension) + dtype: Data type for the cache + device: Device to create the cache on + num_blocks: Total number of blocks in the cache + common_attn_metadata: Common attention metadata + randomize_blocks: Whether to randomly permute blocks + or use sequential order + + Returns: + MLA KV cache tensor + """ + batch_size = len(kv_c_contexts) + seq_lens = common_attn_metadata.seq_lens_cpu + query_lens = common_attn_metadata.query_start_loc_cpu[ + 1:] - common_attn_metadata.query_start_loc_cpu[:-1] + context_lens = common_attn_metadata.num_computed_tokens_cpu + block_table = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping + + # Create MLA KV cache: (num_blocks, block_size, head_size) + kv_cache = torch.empty(num_blocks, + block_size, + head_size, + dtype=dtype, + device=device) + kv_cache_flat = kv_cache.view(-1, head_size) + + # Populate the cache with the context tokens + # Start from block_id=1 since block_id=0 is considered the null block + start_block_idx = 1 + for i in range(batch_size): + kv_c_context, k_pe_context = kv_c_contexts[i], k_pe_contexts[i] + kv_context = torch.cat([kv_c_context, k_pe_context.squeeze(1)], dim=-1) + start = start_block_idx * block_size + end = start + kv_context.shape[0] + kv_cache_flat[start:end, ...] = kv_context + + # Stay block aligned and allocate enough blocks for the new tokens + start_block_idx += cdiv(int(seq_lens[i]), block_size) + + blocks_end = start_block_idx + + # Permute the context blocks (excluding block 0 which is null) + if randomize_blocks: + perm = torch.randperm( + blocks_end - 1) + 1 # Random permutation starting from block 1 + else: + perm = torch.arange( + 1, blocks_end) # Sequential order starting from block 1 + + inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device) + inv_perm[1:] = torch.argsort( + perm) + 1 # Add 1 to account for starting from block 1 + kv_cache[1:blocks_end, ...] = kv_cache[perm, ...] + + # Construct the right block table + # Start from block_id=1 since block_id=0 is considered the null block + start_block_idx = 1 + for i in range(batch_size): + num_blocks_for_seq = cdiv(int(seq_lens[i]), block_size) + start = start_block_idx + end = start + num_blocks_for_seq + block_table[i, :num_blocks_for_seq] = inv_perm[start:end] + start_block_idx += num_blocks_for_seq + + # Create a realistic slot mapping that corresponds to the block table + for i in range(batch_size): + token_offsets = torch.arange(int(query_lens[i])) + int(context_lens[i]) + block_indices = token_offsets // block_size + token_inter_block_offsets = token_offsets % block_size + start = common_attn_metadata.query_start_loc_cpu[i] + end = common_attn_metadata.query_start_loc_cpu[i + 1] + slot_mapping[start:end] = block_table[ + i, + block_indices] * block_size + token_inter_block_offsets.to(device) + + return kv_cache + + +class MockAttentionLayer: + """A mock attention layer for testing.""" + + def __init__(self, device: torch.device): + self._q_scale = torch.tensor(1.0, device=device) + self._k_scale = torch.tensor(1.0, device=device) + self._v_scale = torch.tensor(1.0, device=device) + + +def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, + layer_names: list[str], vllm_config, + device: torch.device, + common_attn_metadata: CommonAttentionMetadata, + query: torch.Tensor, kv_c: torch.Tensor, + k_pe: torch.Tensor, kv_cache: torch.Tensor, + kv_lora_rank: int, qk_nope_head_dim: int, + qk_rope_head_dim: int, v_head_dim: int, + mock_kv_b_proj) -> torch.Tensor: + """Run attention computation using the specified backend's AttentionImpl.""" + + builder_cls, impl_cls = get_attention_backend(backend) + + # Build metadata + builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) + attn_metadata = builder.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + + # Instantiate MLA implementation + num_heads = vllm_config.model_config.get_num_attention_heads( + vllm_config.parallel_config) + num_kv_heads = vllm_config.model_config.get_num_kv_heads( + vllm_config.parallel_config) + head_size = vllm_config.model_config.get_head_size() + scale = 1.0 / (head_size**0.5) + impl = impl_cls( + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="auto", + logits_soft_cap=None, + attn_type="decoder", + kv_sharing_target_layer_name=None, + q_lora_rank=None, + kv_lora_rank=kv_lora_rank, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + qk_head_dim=qk_nope_head_dim + qk_rope_head_dim, + v_head_dim=v_head_dim, + kv_b_proj=mock_kv_b_proj, + ) + + # Process weights to create W_UK_T and W_UV attributes needed by MLA + act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) + impl.process_weights_after_loading(act_dtype) + + # Create mock layer and output buffer + mock_layer = MockAttentionLayer(device) + num_tokens = query.shape[0] + output = torch.empty(num_tokens, + num_heads * v_head_dim, + dtype=query.dtype, + device=query.device) + + # Run forward pass + # NOTE: The query, key, and value are already shaped correctly + # in the calling test function. + output = impl.forward(mock_layer, + query, + kv_c, + k_pe, + kv_cache, + attn_metadata, + output=output) + + return output + + +@pytest.mark.parametrize("batch_spec_name", [ + "small_decode", "small_prefill", "mixed_small", "medium_decode", + "medium_prefill", "mixed_medium", "large_decode", "large_prefill", + "single_decode", "single_prefill" +]) +@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-V2-Lite-Chat"]) +def test_backend_correctness(dist_init, batch_spec_name: str, model: str): + """ + Test that all backends produce similar outputs to a reference implementation + using torch.nn.functional.scaled_dot_product_attention. + + This test works by: + 1. Generating a batch of sequences with specified context and query lengths. + 2. Computing a ground-truth attention output using torch.sdpa on + contiguous Q, K, and V tensors. + 3. Simulating vLLM's paged KV cache: It takes the context portion of the + K/V tensors and manually places them into a paged buffer according to + the test's (randomly generated) block table. + 4. Running each vLLM attention backend with the new queries and the + simulated paged KV cache. + 5. Comparing the vLLM backend's output to the ground-truth SDPA output. + """ + batch_spec = BATCH_SPECS[batch_spec_name] + vllm_config = create_vllm_config(model_name=model, + max_model_len=max(batch_spec.seq_lens), + num_gpu_blocks=2048) + device = torch.device("cuda:0") + + kv_cache_spec = create_standard_kv_cache_spec(vllm_config) + + # 1. Setup + batch_size = batch_spec.batch_size + seq_lens = batch_spec.seq_lens + query_lens = batch_spec.query_lens + num_q_heads = vllm_config.model_config.get_num_attention_heads( + vllm_config.parallel_config) + num_kv_heads = vllm_config.model_config.get_num_kv_heads( + vllm_config.parallel_config) + head_size = vllm_config.model_config.get_head_size() + dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) + block_size = vllm_config.cache_config.block_size + kv_lora_rank = 512 + qk_rope_head_dim = 64 + qk_nope_head_dim = 128 + v_head_dim = 128 + total_head_size = kv_lora_rank + qk_rope_head_dim + assert kv_lora_rank + qk_rope_head_dim == head_size, \ + f"MLA dimensions don't match: {total_head_size} != {head_size}" + scale = 1.0 / (total_head_size**0.5) + + # 2. Generate data and compute SDPA reference output for MLA + all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], [] + all_sdpa_outputs = [] + kv_c_contexts, k_pe_contexts = [], [] + + # Create shared MLA weight matrices for consistency across all sequences + W_UK = torch.randn(kv_lora_rank, + num_q_heads, + qk_nope_head_dim, + dtype=dtype, + device=device) + W_UV = torch.randn(kv_lora_rank, + num_q_heads, + v_head_dim, + dtype=dtype, + device=device) + kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1) + + for i in range(batch_size): + s_len = seq_lens[i] + q_len = query_lens[i] + context_len = s_len - q_len + + # Generate MLA tensors + # Q has both nope and rope components: + # [q_len, num_heads, qk_nope_head_dim + qk_rope_head_dim] + q_c = torch.randn(q_len, + num_q_heads, + qk_nope_head_dim + qk_rope_head_dim, + dtype=dtype, + device=device) + + # KV_C (latent K/V): [s_len, kv_lora_rank] + kv_c_full = torch.randn(s_len, + kv_lora_rank, + dtype=dtype, + device=device) + + # K_PE (rope component): [s_len, 1, qk_rope_head_dim] + k_pe_full = torch.randn(s_len, + 1, + qk_rope_head_dim, + dtype=dtype, + device=device) + + # Determine if this is decode (single token) + # or prefill (multiple tokens) + is_decode = q_len == 1 + + # Split q into nope and rope components + q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) + + if is_decode: + # Decode path: MQA-style attention in latent space + # Transform q_nope to latent space: q_nope @ W_UK + # q_nope: [1, num_heads, qk_nope_head_dim] + # W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim] + ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, + W_UK) # [1, num_heads, kv_lora_rank] + + # Build MQA attention inputs + # Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim] + q_mqa = torch.cat([ql_nope, q_pe], dim=-1) + # K: [s_len, kv_lora_rank + qk_rope_head_dim] + # (broadcasted to all heads) + k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1) + k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1) + # V: [s_len, kv_lora_rank] (broadcasted to all heads) + v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1) + + # SDPA expects (N, H, L, D) + q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2) + k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2) + v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2) + + sdpa_out_i = torch.nn.functional.scaled_dot_product_attention( + q_sdpa_in, k_sdpa_in, v_sdpa_in, is_causal=False, scale=scale) + sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze( + 0) # [1, num_heads, kv_lora_rank] + + # Project back to output space: sdpa_out @ W_UV + sdpa_out_i = torch.einsum("qnl,lnv->qnv", sdpa_out_i, W_UV) + sdpa_out_i = sdpa_out_i.flatten(start_dim=-2) + else: + # Prefill path: MHA-style attention with full sequence + # Apply kv_b_proj to the full kv_c tensor + kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full, + kv_b_proj_weight) + k_nope_full, v_full = kv_nope_full.split( + [qk_nope_head_dim, v_head_dim], dim=-1) + + # Build attention inputs for full sequence + q_mha = torch.cat([q_nope, q_pe], + dim=-1) # [q_len, num_heads, total_dim] + k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1) + k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1) + + # Create custom attention mask: + # - Query tokens can attend to all context tokens + # - Query tokens can only attend to query tokens up to their pos + attn_mask = torch.ones(q_len, + s_len, + dtype=torch.bool, + device=device) + # Apply causal mask only to the query portion (context_len onwards) + causal_mask = torch.tril(torch.ones(q_len, q_len, device=device)) + attn_mask[:, context_len:] = causal_mask + + # SDPA expects (N, H, L, D) + q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2) + k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2) + v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2) + + # Single attention call with custom mask + sdpa_out_i = torch.nn.functional.scaled_dot_product_attention( + q_sdpa_in, + k_sdpa_in, + v_sdpa_in, + attn_mask=attn_mask, + scale=scale) + sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze(0) + sdpa_out_i = sdpa_out_i.flatten(start_dim=-2) + + all_sdpa_outputs.append(sdpa_out_i) + + # Inputs for vLLM MLA backends are just the new tokens + all_q_vllm.append(q_c) + all_kv_c_vllm.append(kv_c_full[context_len:]) # New kv_c tokens + all_k_pe_vllm.append(k_pe_full[context_len:]) # New k_pe tokens + + # Contextual K/V data used to populate the paged cache (MLA format) + kv_c_contexts.append(kv_c_full[:context_len]) + k_pe_contexts.append(k_pe_full[:context_len]) + + # Concatenate all sequences (no reordering needed) + query_vllm = torch.cat(all_q_vllm, dim=0) + kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0) + k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0) + sdpa_output = torch.cat(all_sdpa_outputs, dim=0) + + # Create mock kv_b_proj using the same weights as reference implementation + from vllm.model_executor.layers.linear import ColumnParallelLinear + mock_kv_b_proj = ColumnParallelLinear(input_size=kv_lora_rank, + output_size=num_q_heads * + (qk_nope_head_dim + v_head_dim), + bias=False).to(device=device, + dtype=dtype) + + # Set the mock weights to match our reference implementation + # Reshape W_UK and W_UV to match the expected kv_b_proj format + # [kv_lora_rank, num_heads, qk_nope_head_dim + v_head_dim] + kv_b_proj_weight = kv_b_proj_weight.view( + kv_lora_rank, num_q_heads * (qk_nope_head_dim + v_head_dim)) + mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T) + + # Create metadata using original batch spec + common_attn_metadata = create_common_attn_metadata( + batch_spec, vllm_config.cache_config.block_size, device) + + # 3. Simulate Paged KV Cache and a realistic slot_mapping + kv_cache = create_and_prepopulate_kv_cache( + kv_c_contexts=kv_c_contexts, + k_pe_contexts=k_pe_contexts, + block_size=block_size, + num_kv_heads=num_kv_heads, + head_size=head_size, + dtype=dtype, + device=device, + num_blocks=vllm_config.cache_config.num_gpu_blocks, + common_attn_metadata=common_attn_metadata, + randomize_blocks=True) + + # 4. Run vLLM backends and compare + for backend_name in BACKENDS_TO_TEST: + backend_output = run_attention_backend( + backend_name, kv_cache_spec, ["placeholder"], vllm_config, device, + common_attn_metadata, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache, + kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, + mock_kv_b_proj) + + # Check shape and dtype consistency + assert backend_output.shape == sdpa_output.shape, ( + f"[{backend_name}] shape {backend_output.shape} != " + f"SDPA shape {sdpa_output.shape}") + assert backend_output.dtype == sdpa_output.dtype, ( + f"[{backend_name}] dtype {backend_output.dtype} != " + f"SDPA dtype {sdpa_output.dtype}") + + assert torch.isfinite(backend_output).all(), ( + f"[{backend_name}] produced non-finite values") + + # Check numerical similarity + rtol = 1e-2 + atol = 5e-1 + + max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item() + max_rel_diff = torch.max( + torch.abs(backend_output - sdpa_output) / + torch.abs(sdpa_output)).item() + all_close = torch.allclose(backend_output, + sdpa_output, + rtol=rtol, + atol=atol) + + assert all_close, ( + f"[{backend_name}] output differs from SDPA baseline. " + f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})") diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index e547e71e0cdb7..6a08cdc56f736 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -135,6 +135,12 @@ def get_attention_backend(backend_name: _Backend): "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend", _Backend.XFORMERS_VLLM_V1: "vllm.v1.attention.backends.xformers.XFormersAttentionBackend", + _Backend.CUTLASS_MLA: + "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend", + _Backend.FLASHMLA_VLLM_V1: + "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend", + _Backend.TRITON_MLA_VLLM_V1: + "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend", } if backend_name not in backend_map: @@ -167,9 +173,11 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", tensor_parallel_size: int = 1, max_model_len: int = 1024, dtype: Union[ModelDType, torch.dtype] = "auto", + num_gpu_blocks: int = 1000, block_size: int = 16, max_num_seqs: int = 256, max_num_batched_tokens: int = 8192, + enable_chunked_prefill: bool = True, add_mock_model_methods: bool = True) -> VllmConfig: """Create a VllmConfig for testing with reasonable defaults.""" @@ -189,7 +197,7 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", ) # Set cache blocks for testing # (these may be set during initialization normally) - cache_config.num_gpu_blocks = 1000 + cache_config.num_gpu_blocks = num_gpu_blocks cache_config.num_cpu_blocks = 0 parallel_config = ParallelConfig( @@ -198,6 +206,7 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", scheduler_config = SchedulerConfig( max_num_seqs=max_num_seqs, max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=enable_chunked_prefill, ) device_config = DeviceConfig() diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 646e4fec836bd..03028ebfe76ad 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -24,7 +24,7 @@ Main reference: DeepseekV2 paper, and FlashInfer Implementation (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). Deepseek's MLA attention works the following way: -* Use a single latent vector to represent the per-token entry of the KV cache. +* Use a single latent vector to represent the per-token entry of the KV cache. * For decode (i.e. the memory friendly approach) the attention "simulates" a multi-head attention, while the compute is similar to multi-query attention. @@ -82,7 +82,7 @@ spda_o = scaled_dot_product_attention( torch.cat([q_nope, q_pe], dim=-1), torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), v -) +) return spda_o @ W_O NOTE: in the actual code, @@ -120,20 +120,20 @@ return o.view(-1, N * V) @ self.num_heads @ W_O ## Chunked Prefill -For chunked prefill we want to use the compute friendly algorithm. We are -assuming sufficiently large Sq / Skv ratio, in the future may want to switch to +For chunked prefill we want to use the compute friendly algorithm. We are +assuming sufficiently large Sq / Skv ratio, in the future may want to switch to the data-movement friendly approach if the chunk (i.e. `Sq`) is small. However, the compute-friendly approach can potentially run out of memory if Skv is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)` -To mitigate this, we chunk the computation of attention with respect to the -current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a +To mitigate this, we chunk the computation of attention with respect to the +current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a fixed workspace size. The chunked prefill approach is as follows: -MCC Max chunk of context to process per iter, computed dynamically, +MCC Max chunk of context to process per iter, computed dynamically, used to bound the memory usage q_c = h_t @ W_DQ @@ -155,7 +155,7 @@ curr_o, curr_lse = scaled_dot_product_attention( new_v, casual=True, return_softmax_lse=True -) +) // Compute attention with the already existing context for chunk_idx in range(cdiv(C, MCC)): From c86af22f31838ee654c856279ac5110ae3fdb2cc Mon Sep 17 00:00:00 2001 From: shixianc <49539556+shixianc@users.noreply.github.com> Date: Wed, 20 Aug 2025 15:04:21 -0700 Subject: [PATCH 20/57] [Fix] remove is_marlin param in benchmark_moe (#23286) From 4b795020eda910ecf16c289a23c4a6c119a4b43b Mon Sep 17 00:00:00 2001 From: 22quinn <33176974+22quinn@users.noreply.github.com> Date: Wed, 20 Aug 2025 16:46:06 -0700 Subject: [PATCH 21/57] [EP] Add logging for experts map (#22685) Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com> Co-authored-by: Simon Mo --- vllm/model_executor/layers/fused_moe/layer.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index aa8ceda1bb25a..b16c21b7013a0 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -695,6 +695,26 @@ def determine_expert_map( return (local_num_experts, expert_map) +def get_compressed_expert_map(expert_map: torch.Tensor) -> str: + """ + Compresses the expert map by removing any -1 entries. + + Args: + expert_map (torch.Tensor): A tensor of shape (global_num_experts,) + mapping from global to local index. Contains -1 for experts not + assigned to the current rank. + + Returns: + str: A string mapping from local to global index. + Using str to support hashing for logging once only. + """ + global_indices = torch.where(expert_map != -1)[0] + local_indices = expert_map[global_indices] + return ", ".join( + f"{local_index.item()}->{global_index.item()}" + for local_index, global_index in zip(local_indices, global_indices)) + + @CustomOp.register("fused_moe") class FusedMoE(CustomOp): """FusedMoE layer for MoE models. @@ -795,6 +815,12 @@ class FusedMoE(CustomOp): ep_size=self.ep_size, ep_rank=self.ep_rank, global_num_experts=self.global_num_experts) + logger.info_once( + "[EP Rank %s/%s] Expert parallelism is enabled. Local/global" + " number of experts: %s/%s. Experts local to global index map:" + " %s.", self.ep_rank, self.ep_size, self.local_num_experts, + self.global_num_experts, + get_compressed_expert_map(self.expert_map)) else: self.local_num_experts, self.expert_map = (self.global_num_experts, None) From f5aa307d7795b8400d3719087c502c2a227030c7 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 20 Aug 2025 20:14:59 -0400 Subject: [PATCH 22/57] Remove duplicate entry in vllm.attention.__all__ (#23296) Signed-off-by: Russell Bryant --- vllm/attention/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 344040586a532..dcb2aa68fbee9 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -14,7 +14,6 @@ __all__ = [ "AttentionMetadata", "AttentionType", "AttentionMetadataBuilder", - "Attention", "AttentionState", "get_attn_backend", ] From bbea1cefdd1a29b53355b1655f5d2ae343921f85 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Wed, 20 Aug 2025 20:18:12 -0400 Subject: [PATCH 23/57] [CI Bugfix] Fix CI by fully removing --enable-prompt-adapter (#23284) Signed-off-by: mgoin --- vllm/engine/arg_utils.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index dcf78758946f9..f3afc015f669c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -888,12 +888,6 @@ class EngineArgs: parser.add_argument('--disable-log-stats', action='store_true', help='Disable logging statistics.') - parser.add_argument('--enable-prompt-adapter', - action='store_true', - deprecated=True, - help='[DEPRECATED] Prompt adapter has been ' - 'removed. Setting this flag to True or False' - ' has no effect on vLLM behavior.') return parser From b029de9902aa3ac58806c8c17776c7074175b6db Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 20 Aug 2025 18:25:56 -0700 Subject: [PATCH 24/57] [Optimization] Make new_block_ids None if empty (#23262) Signed-off-by: Woosuk Kwon --- vllm/v1/core/kv_cache_manager.py | 30 ++++++++++++++++++++++++++---- vllm/v1/core/sched/output.py | 2 +- vllm/v1/core/sched/scheduler.py | 24 ++++++++++++------------ vllm/v1/worker/gpu_model_runner.py | 14 +++++++++----- vllm/v1/worker/tpu_model_runner.py | 14 +++++++++----- 5 files changed, 57 insertions(+), 27 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index bfaa7ab08f5cf..fd0bdb2c80fc5 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Optional +from typing import Literal, Optional, overload from vllm.distributed.kv_events import KVCacheEvent from vllm.logger import init_logger @@ -37,7 +37,24 @@ class KVCacheBlocks: tuple(blk1 + blk2 for blk1, blk2 in zip(self.blocks, other.blocks))) - def get_block_ids(self) -> tuple[list[int], ...]: + @overload + def get_block_ids( + self, + allow_none: Literal[False] = False, + ) -> tuple[list[int], ...]: + ... + + @overload + def get_block_ids( + self, + allow_none: Literal[True] = True, + ) -> Optional[tuple[list[int], ...]]: + ... + + def get_block_ids( + self, + allow_none: bool = False, + ): """ Converts the KVCacheBlocks instance to block_ids. @@ -46,6 +63,8 @@ class KVCacheBlocks: * the outer tuple corresponds to KV cache groups * each inner list contains the block_ids of the blocks in that group """ + if allow_none and all(len(group) == 0 for group in self.blocks): + return None return tuple([blk.block_id for blk in group] for group in self.blocks) def get_unhashed_block_ids(self) -> list[int]: @@ -348,10 +367,13 @@ class KVCacheManager: """ return self.block_pool.take_events() + def get_blocks(self, request_id: str) -> KVCacheBlocks: + """Get the blocks of a request.""" + return KVCacheBlocks(self.coordinator.get_blocks(request_id)) + def get_block_ids(self, request_id: str) -> tuple[list[int], ...]: """Get the block ids of a request.""" - return KVCacheBlocks( - self.coordinator.get_blocks(request_id)).get_block_ids() + return self.get_blocks(request_id).get_block_ids() def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: """Cache the blocks for the request, if enabled.""" diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index fac07f97195bd..9ba7ec9d96932 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -91,7 +91,7 @@ class CachedRequestData: # NOTE(woosuk): new_token_ids is only used for pipeline parallelism. # When PP is not used, new_token_ids will be empty. new_token_ids: list[list[int]] - new_block_ids: list[tuple[list[int], ...]] + new_block_ids: list[Optional[tuple[list[int], ...]]] num_computed_tokens: list[int] @property diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 4b167da5c8f81..0b528587b9339 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -19,7 +19,7 @@ from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) -from vllm.v1.core.kv_cache_manager import KVCacheManager +from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) @@ -185,7 +185,7 @@ class Scheduler(SchedulerInterface): # uses structured decoding. structured_output_request_ids: dict[str, int] = {} - req_to_new_block_ids: dict[str, tuple[list[int], ...]] = {} + req_to_new_blocks: dict[str, KVCacheBlocks] = {} num_scheduled_tokens: dict[str, int] = {} token_budget = self.max_num_scheduled_tokens # Encoder-related. @@ -288,8 +288,7 @@ class Scheduler(SchedulerInterface): # Therefore, we might introduce some additional # cycle to fill in the bitmask, which could be a big no-op. structured_output_request_ids[request.request_id] = req_index - req_to_new_block_ids[request.request_id] = ( - new_blocks.get_block_ids()) + req_to_new_blocks[request.request_id] = new_blocks num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens req_index += 1 @@ -496,8 +495,8 @@ class Scheduler(SchedulerInterface): if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) - req_to_new_block_ids[request.request_id] = ( - self.kv_cache_manager.get_block_ids(request.request_id)) + req_to_new_blocks[request.request_id] = ( + self.kv_cache_manager.get_blocks(request.request_id)) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING @@ -546,8 +545,8 @@ class Scheduler(SchedulerInterface): ) # Construct the scheduler output. new_reqs_data = [ - NewRequestData.from_request(req, - req_to_new_block_ids[req.request_id]) + NewRequestData.from_request( + req, req_to_new_blocks[req.request_id].get_block_ids()) for req in scheduled_new_reqs ] cached_reqs_data = self._make_cached_request_data( @@ -555,7 +554,7 @@ class Scheduler(SchedulerInterface): scheduled_resumed_reqs, num_scheduled_tokens, scheduled_spec_decode_tokens, - req_to_new_block_ids, + req_to_new_blocks, ) scheduler_output = SchedulerOutput( scheduled_new_reqs=new_reqs_data, @@ -628,11 +627,11 @@ class Scheduler(SchedulerInterface): resumed_reqs: list[Request], num_scheduled_tokens: dict[str, int], spec_decode_tokens: dict[str, list[int]], - req_to_new_block_ids: dict[str, tuple[list[int], ...]], + req_to_new_blocks: dict[str, KVCacheBlocks], ) -> CachedRequestData: req_ids: list[str] = [] new_token_ids: list[list[int]] = [] - new_block_ids: list[tuple[list[int], ...]] = [] + new_block_ids: list[Optional[tuple[list[int], ...]]] = [] num_computed_tokens: list[int] = [] use_connector = self.connector is not None @@ -655,7 +654,8 @@ class Scheduler(SchedulerInterface): # out of bounds errors. TODO: Remove this once the KVConnector # is updated to handle token IDs properly. new_token_ids.append([]) - new_block_ids.append(req_to_new_block_ids[req_id]) + new_block_ids.append( + req_to_new_blocks[req_id].get_block_ids(allow_none=True)) num_computed_tokens.append(req.num_computed_tokens) # Because resumed_reqs is usually empty, it is more efficient to do # in-place appending so that we don't need to allocate a new list. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 33747d6917a5a..cc86f9826491f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -574,11 +574,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Update the block IDs. if not resumed_from_preemption: - # Append the new blocks to the existing block IDs. - for block_ids, new_ids in zip(req_state.block_ids, - new_block_ids): - block_ids.extend(new_ids) + if new_block_ids is not None: + # Append the new blocks to the existing block IDs. + for block_ids, new_ids in zip(req_state.block_ids, + new_block_ids): + block_ids.extend(new_ids) else: + assert new_block_ids is not None # The request is resumed from preemption. # Replace the existing block IDs with the new ones. req_state.block_ids = new_block_ids @@ -594,7 +596,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( num_computed_tokens) - self.input_batch.block_table.append_row(new_block_ids, req_index) + if new_block_ids is not None: + self.input_batch.block_table.append_row( + new_block_ids, req_index) # For the last rank, we don't need to update the token_ids_cpu # because the sampled tokens are already cached. diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 9196c62377b91..0f569500cdf6b 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -418,11 +418,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Update the cached states. req_state.num_computed_tokens = num_computed_tokens if not resumed_from_preemption: - # Append the new blocks to the existing block IDs. - for block_ids, new_ids in zip(req_state.block_ids, - new_block_ids): - block_ids.extend(new_ids) + if new_block_ids is not None: + # Append the new blocks to the existing block IDs. + for block_ids, new_ids in zip(req_state.block_ids, + new_block_ids): + block_ids.extend(new_ids) else: + assert new_block_ids is not None # The request is resumed from preemption. # Replace the existing block IDs with the new ones. req_state.block_ids = new_block_ids @@ -438,7 +440,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( num_computed_tokens) - self.input_batch.block_table.append_row(new_block_ids, req_index) + if new_block_ids is not None: + self.input_batch.block_table.append_row( + new_block_ids, req_index) # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. From 7be5d113d8784536b79f27f24cfa91958dc291b0 Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Thu, 21 Aug 2025 09:34:24 +0800 Subject: [PATCH 25/57] [CPU] Refactor CPU W8A8 scaled_mm (#23071) Signed-off-by: jiang1.li --- .../scripts/hardware_ci/run-cpu-test.sh | 7 +- cmake/cpu_extension.cmake | 59 +- csrc/cpu/cpu_types_x86.hpp | 8 +- csrc/cpu/dnnl_helper.cpp | 346 +++++++ csrc/cpu/dnnl_helper.h | 169 ++++ csrc/cpu/dnnl_helper.hpp | 206 ---- csrc/cpu/dnnl_kernels.cpp | 494 +++++++++ csrc/cpu/quant.cpp | 951 ------------------ csrc/cpu/torch_bindings.cpp | 92 +- tests/kernels/test_onednn.py | 144 +++ vllm/_custom_ops.py | 83 ++ vllm/model_executor/layers/fused_moe/layer.py | 11 +- vllm/model_executor/layers/linear.py | 8 +- .../kernels/scaled_mm/__init__.py | 4 +- .../quantization/kernels/scaled_mm/cpu.py | 206 ++++ .../quantization/kernels/scaled_mm/cutlass.py | 4 +- vllm/model_executor/layers/utils.py | 6 + 17 files changed, 1525 insertions(+), 1273 deletions(-) create mode 100644 csrc/cpu/dnnl_helper.cpp create mode 100644 csrc/cpu/dnnl_helper.h delete mode 100644 csrc/cpu/dnnl_helper.hpp create mode 100644 csrc/cpu/dnnl_kernels.cpp delete mode 100644 csrc/cpu/quant.cpp create mode 100644 tests/kernels/test_onednn.py create mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index 57a7bc4e5f5df..9dec9f8e9eb32 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -46,6 +46,11 @@ function cpu_tests() { set -e python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m" + # Run kernel tests + docker exec cpu-test-"$NUMA_NODE" bash -c " + set -e + pytest -v -s tests/kernels/test_onednn.py" + # Run basic model test docker exec cpu-test-"$NUMA_NODE" bash -c " set -e @@ -99,4 +104,4 @@ function cpu_tests() { # All of CPU tests are expected to be finished less than 40 mins. export -f cpu_tests -timeout 1.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" +timeout 2h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index e0da46e2accaa..cc38cd41a5b24 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -182,17 +182,17 @@ endif() # # Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms) # Flag to enable ACL kernels for AARCH64 platforms -if ( VLLM_BUILD_ACL STREQUAL "ON") +if (VLLM_BUILD_ACL STREQUAL "ON") set(USE_ACL ON) else() set(USE_ACL OFF) endif() -if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND) +if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND) FetchContent_Declare( oneDNN GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git - GIT_TAG v3.8.1 + GIT_TAG v3.9 GIT_PROGRESS TRUE GIT_SHALLOW TRUE ) @@ -204,7 +204,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND) endif() set(ONEDNN_AARCH64_USE_ACL "ON") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/") - endif() + endif() set(ONEDNN_LIBRARY_TYPE "STATIC") set(ONEDNN_BUILD_DOC "OFF") @@ -217,38 +217,23 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND) set(ONEDNN_ENABLE_ITT_TASKS "OFF") set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF") set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF") + set(ONEDNN_VERBOSE "OFF") set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) FetchContent_MakeAvailable(oneDNN) - - list(APPEND LIBS dnnl) -elseif(POWER10_FOUND) - FetchContent_Declare( - oneDNN - GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git - GIT_TAG v3.7.2 - GIT_PROGRESS TRUE - GIT_SHALLOW TRUE + add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp") + target_include_directories( + dnnl_ext + PUBLIC ${oneDNN_SOURCE_DIR}/include + PUBLIC ${oneDNN_BINARY_DIR}/include + PRIVATE ${oneDNN_SOURCE_DIR}/src ) - - set(ONEDNN_LIBRARY_TYPE "STATIC") - set(ONEDNN_BUILD_DOC "OFF") - set(ONEDNN_BUILD_EXAMPLES "OFF") - set(ONEDNN_BUILD_TESTS "OFF") - set(ONEDNN_ENABLE_WORKLOAD "INFERENCE") - set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER") - set(ONEDNN_BUILD_GRAPH "OFF") - set(ONEDNN_ENABLE_JIT_PROFILING "OFF") - set(ONEDNN_ENABLE_ITT_TASKS "OFF") - set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF") - set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF") - set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) - - set(DNNL_CPU_RUNTIME "OMP") - - FetchContent_MakeAvailable(oneDNN) - - list(APPEND LIBS dnnl) + target_link_libraries(dnnl_ext dnnl) + target_compile_options(dnnl_ext PRIVATE ${CXX_COMPILE_FLAGS} -fPIC) + list(APPEND LIBS dnnl_ext) + set(USE_ONEDNN ON) +else() + set(USE_ONEDNN OFF) endif() message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") @@ -275,7 +260,6 @@ set(VLLM_EXT_SRC if (AVX512_FOUND AND NOT AVX512_DISABLED) set(VLLM_EXT_SRC - "csrc/cpu/quant.cpp" "csrc/cpu/shm.cpp" ${VLLM_EXT_SRC}) if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI) @@ -289,14 +273,11 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) ${VLLM_EXT_SRC}) add_compile_definitions(-DCPU_CAPABILITY_AVX512) endif() -elseif(POWER10_FOUND) - set(VLLM_EXT_SRC - "csrc/cpu/quant.cpp" - ${VLLM_EXT_SRC}) endif() -if (ASIMD_FOUND) + +if(USE_ONEDNN) set(VLLM_EXT_SRC - "csrc/cpu/quant.cpp" + "csrc/cpu/dnnl_kernels.cpp" ${VLLM_EXT_SRC}) endif() diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp index 3952c43cbc727..982f7c07a13bd 100644 --- a/csrc/cpu/cpu_types_x86.hpp +++ b/csrc/cpu/cpu_types_x86.hpp @@ -89,7 +89,7 @@ struct FP16Vec16 : public Vec { explicit FP16Vec16(const FP32Vec16&); - void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; } + void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); } void save(void* ptr, const int elem_num) const { constexpr uint32_t M = 0xFFFFFFFF; @@ -126,7 +126,7 @@ struct BF16Vec16 : public Vec { explicit BF16Vec16(const FP32Vec16&); - void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; } + void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); } void save(void* ptr, const int elem_num) const { constexpr uint32_t M = 0xFFFFFFFF; @@ -180,8 +180,8 @@ struct BF16Vec32 : public Vec { (__m128i)vec8_data.reg, 1)) {} void save(void* ptr) const { - *reinterpret_cast<__m256i*>(ptr) = reg_low; - *reinterpret_cast<__m256i*>((__m256i*)ptr + 1) = reg_high; + _mm256_storeu_si256((__m256i*)ptr, reg_low); + _mm256_storeu_si256((__m256i*)ptr + 1, reg_high); } }; #endif diff --git a/csrc/cpu/dnnl_helper.cpp b/csrc/cpu/dnnl_helper.cpp new file mode 100644 index 0000000000000..f3f00edb36068 --- /dev/null +++ b/csrc/cpu/dnnl_helper.cpp @@ -0,0 +1,346 @@ +#include +#include + +#include "common/memory_desc.hpp" +#include "common/memory.hpp" + +#include "dnnl_helper.h" + +static dnnl::engine& default_engine() { + static dnnl::engine engine(dnnl::engine::kind::cpu, 0); + return engine; +} + +static dnnl::stream& default_stream() { + static dnnl::stream stream(default_engine()); + return stream; +} + +void release_dnnl_matmul_handler(int64_t handler) { + DNNLMatMulPrimitiveHandler* ptr = + reinterpret_cast(handler); + delete ptr; +} + +template +class DNNLPrimitiveCache { + public: + using cache_value_t = std::pair; + using result_value_t = VT; + using container_t = std::list; + using value_iterator_t = typename container_t::iterator; + using map_t = std::unordered_map; + using creator_t = VT (*)(); + + public: + DNNLPrimitiveCache(size_t capacity) + : capacity_(capacity), + values_(), + key_to_value_(std::min(256lu, capacity)) { + assert(capacity > 0); + } + + template + result_value_t get_or_create(const KT& key, F&& creator) { + std::optional value = get_value(key); + if (value.has_value()) { + return value.value()->second; + } else { + return add_value({key, creator()})->second; + } + } + + size_t size() const { return values_.size(); } + + private: + void dump_data() { + std::stringstream ss; + ss << "table_id: " << std::hex << reinterpret_cast(this) << std::dec + << "\n"; + ss << "container: ["; + for (auto&& iter : values_) { + ss << "(" << iter.first << ", " << std::hex + << reinterpret_cast(iter.second.get()) << "), " << std::dec; + } + ss << "]\n"; + + ss << "map: ["; + for (auto&& iter : key_to_value_) { + ss << "(" << iter.first << ", " << iter.second->first << ", " << std::hex + << reinterpret_cast(iter.second->second.get()) << std::dec + << "), "; + } + ss << "]\n"; + std::printf("%s\n", ss.str().c_str()); + } + + value_iterator_t add_value(cache_value_t&& new_value) { + if (size() == capacity_) { + cache_value_t& last_item = values_.back(); + key_to_value_.erase(last_item.first); + values_.pop_back(); + } + + auto& added_value_ = values_.emplace_front(std::move(new_value)); + key_to_value_.emplace(added_value_.first, values_.begin()); + return values_.begin(); + } + + std::optional get_value(const KT& key) { + if (key_to_value_.size() > 0 && key == values_.begin()->first) { + return values_.begin(); + } + + auto value_map_iterator = key_to_value_.find(key); + if (value_map_iterator != key_to_value_.end()) { + values_.splice(values_.begin(), values_, value_map_iterator->second); + return value_map_iterator->second; + } else { + return {}; + } + } + + private: + const size_t capacity_; + container_t values_; + map_t key_to_value_; +}; + +DNNLMatMulPrimitiveHandler::DNNLMatMulPrimitiveHandler( + const Args& args, dnnl::memory::data_type b_type) + : b_n_size_(args.b_n_size), + b_n_stride_(args.b_n_stride), + b_k_size_(args.b_k_size), + b_k_stride_(args.b_k_stride), + b_type_(b_type), + c_type_(args.c_type), + runtime_memory_ptrs_(8), + primitive_cache_size_(args.primitive_cache_size) { + assert(primitive_cache_size_ > 0); +} + +void DNNLMatMulPrimitiveHandler::prepack_weight( + void* original_b_ptr, dnnl::memory::desc b_target_mem_desc) { + dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_, + {b_k_stride_, b_n_stride_}); + dnnl::memory original_weight(original_b_md, default_engine(), original_b_ptr); + dnnl::memory packed_weight(b_target_mem_desc, default_engine()); + { + dnnl::reorder(original_weight, packed_weight) + .execute(default_stream(), original_weight, packed_weight); + default_stream().wait(); + } + memory_cache_[DNNL_ARG_WEIGHTS] = packed_weight; + b_target_mem_desc_ = b_target_mem_desc; +} + +void DNNLMatMulPrimitiveHandler::set_runtime_memory_ptr( + size_t index, dnnl_memory* memory_ptr) { + dnnl::impl::memory_storage_t* mem_storage_ptr = memory_ptr->memory_storage(); + dnnl_memory_desc* mem_desc = const_cast(memory_ptr->md()); + runtime_memory_ptrs_[index] = {mem_storage_ptr, mem_desc}; +} + +std::pair +DNNLMatMulPrimitiveHandler::get_runtime_memory_ptr(size_t index) { + return runtime_memory_ptrs_[index]; +} + +namespace std { +template <> +struct hash { + size_t operator()( + const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const { + return hash()(val.b_n_size) ^ hash()(val.b_k_size) ^ + hash()(static_cast(val.a_qs)) ^ + hash()(static_cast(val.b_qs)) ^ hash()(val.use_azp) ^ + hash()(static_cast(val.c_type)); + } +}; + +template <> +struct hash { + size_t operator()( + const W8A8MatMulPrimitiveHandler::MSizeCacheKey& val) const { + return hash()(val.a_m_size) ^ hash()(val.use_bias) ^ + hash()(static_cast(val.bias_type)); + } +}; +} // namespace std + +bool operator==(const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& l, + const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& r) { + return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size && + l.a_qs == r.a_qs && l.b_qs == r.b_qs && l.use_azp == r.use_azp && + l.c_type == r.c_type; +} + +bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l, + const W8A8MatMulPrimitiveHandler::MSizeCacheKey& r) { + return l.use_bias == r.use_bias && l.a_m_size == r.a_m_size && + l.bias_type == r.bias_type; +} + +static std::shared_ptr +get_w8a8_class_primitive_cache( + const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& key, + int64_t cache_size) { + static W8A8MatMulPrimitiveHandler::ClassMatmulCache cache(128); + assert(cache_size > 0); + return cache.get_or_create(key, [&]() { + return std::make_shared(cache_size); + }); +} + +W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args) + : DNNLMatMulPrimitiveHandler( + static_cast(args), + dnnl::memory::data_type::s8), + use_azp_(args.use_a_zero_point), + a_qs_(args.a_quantization_strategy), + b_qs_(args.b_quantization_strategy), + m_size_cache_(nullptr) { + assert(a_qs_ != QuantizationStrategy::PER_OUTPUT_CHANNEL); + assert(b_qs_ != QuantizationStrategy::PER_TOKEN); + if (a_qs_ == QuantizationStrategy::PER_TOKEN) { + assert(!use_azp_); + }; + prepack_weight(args.b_ptr, + create_primitive_desc( + MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL, + .use_bias = false, + .bias_type = dnnl::memory::data_type::undef}, + true) + .weights_desc()); + init_runtime_memory_cache(args); +} + +void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) { + auto&& [a_storage, a_mem_desc] = get_runtime_memory_ptr(0); + auto&& [c_storage, c_mem_desc] = get_runtime_memory_ptr(1); + a_storage->set_data_handle((void*)args.a_ptr); + a_mem_desc->dims[0] = args.a_m_size; + c_storage->set_data_handle((void*)args.c_ptr); + c_mem_desc->dims[0] = args.a_m_size; + + if (a_qs_ == QuantizationStrategy::PER_TENSOR) { + auto&& [a_scale_storage, a_scale_mem_desc] = get_runtime_memory_ptr(2); + a_scale_storage->set_data_handle((void*)args.a_scales_ptr); + } + if (use_azp_) { + auto&& [a_zero_point_storage, a_zero_point_mem_desc] = + get_runtime_memory_ptr(3); + a_zero_point_storage->set_data_handle((void*)args.a_zero_points_ptr); + } + + if (args.use_bias) { + auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(4); + bias_storage->set_data_handle((void*)args.bias_ptr); + } + + dnnl::matmul matmul = get_matmul_cache(args); + matmul.execute(default_stream(), memory_cache_); + default_stream().wait(); +} + +dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache( + const MSizeCacheKey& key) { + if (m_size_cache_.get() == nullptr) { + ClassMatmulCacheKey key = {.b_n_size = b_n_size_, + .b_k_size = b_k_size_, + .a_qs = a_qs_, + .b_qs = b_qs_, + .use_azp = use_azp_, + .c_type = c_type_}; + m_size_cache_ = get_w8a8_class_primitive_cache(key, primitive_cache_size_); + } + + return m_size_cache_->get_or_create(key, [&]() { + dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false); + return dnnl::matmul(desc); + }); +} + +void W8A8MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) { + memory_cache_[DNNL_ARG_SRC] = dnnl::memory({{1, b_k_size_}, + dnnl::memory::data_type::s8, + dnnl::memory::format_tag::ab}, + default_engine(), nullptr); + set_runtime_memory_ptr(0, memory_cache_[DNNL_ARG_SRC].get()); + memory_cache_[DNNL_ARG_DST] = + dnnl::memory({{1, b_n_size_}, c_type_, dnnl::memory::format_tag::ab}, + default_engine(), nullptr); + set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get()); + + // For PER_TOKEN, scales will be applied in outside epilogue + if (a_qs_ == QuantizationStrategy::PER_TENSOR) { + memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC] = dnnl::memory( + {{1}, dnnl::memory::data_type::f32, {1}}, default_engine(), nullptr); + set_runtime_memory_ptr( + 2, memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC].get()); + if (use_azp_) { + memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC] = dnnl::memory( + {{1}, dnnl::memory::data_type::s32, {1}}, default_engine(), nullptr); + set_runtime_memory_ptr( + 3, memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC].get()); + } + } + + if (b_qs_ == QuantizationStrategy::PER_TENSOR) { + memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = + dnnl::memory({{1}, dnnl::memory::data_type::f32, {1}}, default_engine(), + (void*)args.b_scales_ptr); + } else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) { + memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = + dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}}, + default_engine(), (void*)args.b_scales_ptr); + } + + memory_cache_[DNNL_ARG_BIAS] = + dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}}, + default_engine(), nullptr); + set_runtime_memory_ptr(4, memory_cache_[DNNL_ARG_BIAS].get()); +} + +dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc( + const MSizeCacheKey& key, bool first_time) { + dnnl::memory::desc a_md({key.a_m_size, b_k_size_}, + dnnl::memory::data_type::s8, + dnnl::memory::format_tag::ab); + dnnl::memory::desc b_md; + if (first_time) { + b_md = + dnnl::memory::desc({b_k_size_, b_n_size_}, dnnl::memory::data_type::s8, + dnnl::memory::format_tag::any); + } else { + b_md = b_target_mem_desc_; + } + dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_, + dnnl::memory::format_tag::ab); + + dnnl::primitive_attr attr; + // For PER_TOKEN, scales will be applied in outside epilogue + if (a_qs_ == QuantizationStrategy::PER_TENSOR) { + attr.set_scales_mask(DNNL_ARG_SRC, 0); + if (use_azp_) { + attr.set_zero_points_mask(DNNL_ARG_SRC, 0); + } + } + + if (b_qs_ == QuantizationStrategy::PER_TENSOR) { + attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); + } else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) { + attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2); + } + + if (key.use_bias) { + // For PER_TOKEN, bias will be applied in epilogue + assert(a_qs_ == QuantizationStrategy::PER_TENSOR); + dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1}); + return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md, + c_md, attr); + } else { + return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md, + attr); + } +} diff --git a/csrc/cpu/dnnl_helper.h b/csrc/cpu/dnnl_helper.h new file mode 100644 index 0000000000000..54ceefced9e98 --- /dev/null +++ b/csrc/cpu/dnnl_helper.h @@ -0,0 +1,169 @@ +#ifndef DNNL_HELPER_H +#define DNNL_HELPER_H + +#include +#include + +#include "oneapi/dnnl/dnnl.hpp" + +namespace c10 { +struct BFloat16; +struct Half; +} // namespace c10 + +namespace dnnl { +namespace impl { +struct memory_storage_t; +struct matmul_pd_t; +struct matmul_desc_t; +} // namespace impl +} // namespace dnnl +struct dnnl_memory_desc; + +template +class DNNLPrimitiveCache; + +template +struct DNNLType { + static constexpr dnnl::memory::data_type type = + dnnl::memory::data_type::undef; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16; +}; + +template +constexpr inline dnnl::memory::data_type get_dnnl_type() { + return DNNLType>::type; +} + +class DNNLMatMulPrimitiveHandler { + public: + virtual ~DNNLMatMulPrimitiveHandler() = default; + + protected: + struct Args { + dnnl_dim_t b_n_size; + dnnl_dim_t b_n_stride; + dnnl_dim_t b_k_size; + dnnl_dim_t b_k_stride; + void* b_ptr; + dnnl::memory::data_type c_type; + size_t primitive_cache_size; + }; + + protected: + DNNLMatMulPrimitiveHandler(const Args& args, dnnl::memory::data_type b_type); + + void prepack_weight(void* original_b_ptr, + dnnl::memory::desc b_target_mem_desc); + + void set_runtime_memory_ptr(size_t index, dnnl_memory* memory_ptr); + + std::pair + get_runtime_memory_ptr(size_t index); + + protected: + const dnnl_dim_t b_n_size_; + const dnnl_dim_t b_n_stride_; + const dnnl_dim_t b_k_size_; + const dnnl_dim_t b_k_stride_; + dnnl::memory::data_type b_type_; + dnnl::memory::data_type c_type_; + std::unordered_map memory_cache_; + std::vector> + runtime_memory_ptrs_; + dnnl::memory::desc b_target_mem_desc_; + int64_t primitive_cache_size_; +}; + +class W8A8MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler { + public: + enum class QuantizationStrategy { PER_TOKEN, PER_TENSOR, PER_OUTPUT_CHANNEL }; + + struct Args : public DNNLMatMulPrimitiveHandler::Args { + bool use_a_zero_point; + QuantizationStrategy a_quantization_strategy; + QuantizationStrategy b_quantization_strategy; + float* b_scales_ptr; + }; + + struct ClassMatmulCacheKey { + dnnl_dim_t b_n_size; + dnnl_dim_t b_k_size; + QuantizationStrategy a_qs; + QuantizationStrategy b_qs; + bool use_azp; + dnnl::memory::data_type c_type; + + friend bool operator==(const ClassMatmulCacheKey& l, + const ClassMatmulCacheKey& r); + }; + + struct MSizeCacheKey { + dnnl_dim_t a_m_size; + bool use_bias; + dnnl::memory::data_type bias_type; + + friend bool operator==(const MSizeCacheKey& l, const MSizeCacheKey& r); + }; + + using MSizeCache = DNNLPrimitiveCache; + using ClassMatmulCache = + DNNLPrimitiveCache>; + + struct ExecArgs : public MSizeCacheKey { + const int8_t* a_ptr; + const float* a_scales_ptr; + const int32_t* a_zero_points_ptr; + const void* bias_ptr; + void* c_ptr; + }; + + public: + W8A8MatMulPrimitiveHandler(const Args& args); + + QuantizationStrategy get_input_scale_strategy() const { return a_qs_; } + + bool get_input_use_zero_point() const { return use_azp_; } + + void execute(ExecArgs& args); + + private: + dnnl::matmul::primitive_desc create_primitive_desc(const MSizeCacheKey& key, + bool first_time); + + void init_runtime_memory_cache(const Args& args); + + dnnl::matmul get_matmul_cache(const MSizeCacheKey& key); + + private: + const bool use_azp_; + const QuantizationStrategy a_qs_; + const QuantizationStrategy b_qs_; + std::shared_ptr m_size_cache_; +}; + +#endif diff --git a/csrc/cpu/dnnl_helper.hpp b/csrc/cpu/dnnl_helper.hpp deleted file mode 100644 index 1cb8dc5b25a66..0000000000000 --- a/csrc/cpu/dnnl_helper.hpp +++ /dev/null @@ -1,206 +0,0 @@ -#ifndef DNNL_HELPER_HPP -#define DNNL_HELPER_HPP - -#include -#include - -#include "oneapi/dnnl/dnnl.hpp" - -namespace { -template -struct DNNLType { - static constexpr dnnl::memory::data_type type = - dnnl::memory::data_type::undef; -}; - -template <> -struct DNNLType { - static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8; -}; - -template <> -struct DNNLType { - static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32; -}; - -template <> -struct DNNLType { - static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32; -}; - -template <> -struct DNNLType { - static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16; -}; - -template <> -struct DNNLType { - static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16; -}; - -template -constexpr inline dnnl::memory::data_type get_dnnl_type() { - return DNNLType>::type; -} -}; // namespace - -template -class DNNLPrimitiveHelper { - public: - // I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias) - // A: [M, K], row-major - // B: [K, N], column-major - // C: [M, N], row-major - // bias: [N], row-major, optional - // a_scales: [MS] - // b_scales: [NS] - // Note: Due to the limitation of oneDNN - // (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is - // not supported. - - template - static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c, - const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N, - dnnl_dim_t K, const float* a_scales, - const float* b_scales, dnnl_dim_t MS, - dnnl_dim_t NS) { - auto&& OutputType = get_dnnl_type(); - auto&& BiasType = get_dnnl_type(); - - dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1}); - dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K}); - dnnl::memory::desc c_md({M, N}, OutputType, {N, 1}); - - dnnl::primitive_attr attr; - if constexpr (!InputNoScale) { - if (MS == 1) { - // per-tensor - attr.set_scales_mask(DNNL_ARG_SRC, 0); - } else { - // per-token - TORCH_CHECK(false, "per-token quantization is unsupported."); - } - } - - if (NS == 1) { - // per-tensor - attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); - } else { - // per-channel - attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2); - } - - dnnl::matmul::primitive_desc matmul_pd; -// Create memory descriptors with format_tag::any for the primitive. This -// enables the matmul primitive to choose memory layouts for an -// optimized primitive implementation, and these layouts may differ from the -// ones provided by the user. -#ifdef __aarch64__ - auto mat_src_md = dnnl::memory::desc({M, K}, dnnl::memory::data_type::s8, - dnnl::memory::format_tag::any); - auto mat_weights_md = dnnl::memory::desc( - {K, N}, dnnl::memory::data_type::s8, dnnl::memory::format_tag::any); - auto mat_dst_md = - dnnl::memory::desc({M, N}, OutputType, dnnl::memory::format_tag::any); - if (bias) { - dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1}); - matmul_pd = dnnl::matmul::primitive_desc(default_engine(), mat_src_md, - mat_weights_md, bias_md, - mat_dst_md, attr); - } else { - matmul_pd = dnnl::matmul::primitive_desc( - default_engine(), mat_src_md, mat_weights_md, mat_dst_md, attr); - } -#else - if (bias) { - dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1}); - matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, - bias_md, c_md, attr); - } else { - matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, - c_md, attr); - } -#endif - dnnl::matmul matmul(matmul_pd); - - auto& engine = default_engine(); - - dnnl::memory a_m(a_md, engine, (void*)a); - dnnl::memory b_m(b_md, engine, (void*)b); - dnnl::memory c_m(c_md, engine, (void*)c); - dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine, - (void*)a_scales); - dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine, - (void*)b_scales); - - auto& stream = default_stream(); - - auto mat_src_mem = a_m; - auto mat_weights_mem = b_m; - auto mat_dst_mem = c_m; -#ifdef __aarch64__ - if (matmul_pd.weights_desc() != b_m.get_desc()) { - mat_weights_mem = dnnl::memory(matmul_pd.weights_desc(), engine); - dnnl::reorder(b_m, mat_weights_mem).execute(stream, b_m, mat_weights_mem); - } -#endif - if constexpr (InputNoScale) { - if (bias) { - dnnl::memory::desc bias_md({N}, BiasType, {1}); - dnnl::memory bias_m(bias_md, engine, (void*)bias); - matmul.execute( - stream, { - {DNNL_ARG_SRC, mat_src_mem}, - {DNNL_ARG_WEIGHTS, mat_weights_mem}, - {DNNL_ARG_BIAS, bias_m}, - {DNNL_ARG_DST, mat_dst_mem}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, - }); - } else { - matmul.execute( - stream, { - {DNNL_ARG_SRC, mat_src_mem}, - {DNNL_ARG_WEIGHTS, mat_weights_mem}, - {DNNL_ARG_DST, mat_dst_mem}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, - }); - } - } else { - if (bias) { - dnnl::memory::desc bias_md({N}, BiasType, {1}); - dnnl::memory bias_m(bias_md, engine, (void*)bias); - matmul.execute( - stream, { - {DNNL_ARG_SRC, mat_src_mem}, - {DNNL_ARG_WEIGHTS, mat_weights_mem}, - {DNNL_ARG_BIAS, bias_m}, - {DNNL_ARG_DST, mat_dst_mem}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, - }); - } else { - matmul.execute( - stream, { - {DNNL_ARG_SRC, mat_src_mem}, - {DNNL_ARG_WEIGHTS, mat_weights_mem}, - {DNNL_ARG_DST, mat_dst_mem}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, - }); - } - } - stream.wait(); - } - - private: - static dnnl::engine& default_engine() { - static dnnl::engine engine(dnnl::engine::kind::cpu, 0); - return engine; - } - - static dnnl::stream& default_stream() { - static dnnl::stream stream(default_engine()); - return stream; - } -}; -#endif diff --git a/csrc/cpu/dnnl_kernels.cpp b/csrc/cpu/dnnl_kernels.cpp new file mode 100644 index 0000000000000..acc3b9ecde143 --- /dev/null +++ b/csrc/cpu/dnnl_kernels.cpp @@ -0,0 +1,494 @@ +#include "cpu_types.hpp" +#include "dnnl_helper.h" + +namespace { +template +struct KernelVecType { + using load_vec_type = void; + using cvt_vec_type = void; +}; + +template <> +struct KernelVecType { + using load_vec_type = vec_op::FP32Vec16; + using cvt_vec_type = vec_op::FP32Vec16; +}; + +#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT) +template <> +struct KernelVecType { + using load_vec_type = vec_op::BF16Vec16; + using cvt_vec_type = vec_op::FP32Vec16; +}; +#endif + +template <> +struct KernelVecType { +#if defined(__powerpc64__) || defined(__s390x__) + // Power architecture-specific vector type + using load_vec_type = vec_op::FP32Vec16; +#else + // Fallback for other architectures + using load_vec_type = vec_op::FP16Vec16; +#endif + using cvt_vec_type = vec_op::FP32Vec16; +}; + +template +void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + const float* scale, const int32_t* azp, + const int64_t num_tokens, + const int64_t input_stride, + const int64_t hidden_size) { + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int64_t vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + constexpr float i8_min = + static_cast(std::numeric_limits::min()); + constexpr float i8_max = + static_cast(std::numeric_limits::max()); + const cvt_vec_t inv_scale(1.0 / *scale); + const cvt_vec_t i8_min_vec(i8_min); + const cvt_vec_t i8_max_vec(i8_max); + + cvt_vec_t zp_vec; + if constexpr (AZP) { + zp_vec = cvt_vec_t(static_cast(*azp)); + } + +#pragma omp parallel for + for (int64_t i = 0; i < num_tokens; ++i) { + int64_t j = 0; + const scalar_t* input_ptr = input + i * input_stride; + int8_t* output_ptr = output + i * hidden_size; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = elems_fp32 * inv_scale; + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + zp_vec; + } + + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output_ptr + j); + } + + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = elems_fp32 * inv_scale; + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + zp_vec; + } + + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output_ptr + j, hidden_size - j); + } +} + +template +void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + float* scale, int32_t* azp, + const int64_t num_tokens, + const int64_t input_stride, + const int64_t hidden_size) { + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + constexpr float i8_min = + static_cast(std::numeric_limits::min()); + constexpr float i8_max = + static_cast(std::numeric_limits::max()); + const cvt_vec_t i8_min_vec(i8_min); + const cvt_vec_t i8_max_vec(i8_max); + +#pragma omp parallel for + for (int64_t i = 0; i < num_tokens; ++i) { + cvt_vec_t max_value(std::numeric_limits::lowest()); + cvt_vec_t min_value(std::numeric_limits::max()); + { + int64_t j = 0; + const scalar_t* input_ptr = input + i * input_stride; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + if constexpr (AZP) { + max_value = max_value.max(elems_fp32); + min_value = min_value.min(elems_fp32); + } else { + max_value = max_value.max(elems_fp32.abs()); + } + } + + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + + if (j + vec_elem_num == hidden_size) { + if constexpr (AZP) { + max_value = max_value.max(elems_fp32); + min_value = min_value.min(elems_fp32); + } else { + max_value = max_value.max(elems_fp32.abs()); + } + } else { + if constexpr (AZP) { + max_value = max_value.max(elems_fp32, hidden_size - j); + min_value = min_value.min(elems_fp32, hidden_size - j); + } else { + max_value = max_value.max(elems_fp32.abs(), hidden_size - j); + } + } + } + + float scale_val, azp_val; + if constexpr (AZP) { + float max_scalar = max_value.reduce_max(); + float min_scalar = min_value.reduce_min(); + scale_val = (max_scalar - min_scalar) / 255.0f; + azp_val = std::nearbyint(-128.0f - min_scalar / scale_val); + azp[i] = azp_val; + scale[i] = scale_val; + } else { + scale_val = max_value.reduce_max() / 127.0f; + scale[i] = scale_val; + } + + const cvt_vec_t inv_scale(1.0 / scale_val); + const cvt_vec_t azp_vec(azp_val); + + { + int64_t j = 0; + const scalar_t* input_ptr = input + i * input_stride; + int8_t* output_ptr = output + i * hidden_size; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale); + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + azp_vec; + } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output_ptr + j); + } + + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale); + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + azp_vec; + } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output_ptr + j, hidden_size - j); + } + } +} + +template +void dynamic_quant_epilogue(const float* input, scalar_t* output, + const float* a_scale, const int32_t* azp, + const float* azp_adj, const scalar_t* bias, + const int64_t num_tokens, + const int64_t hidden_size) { + CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue) + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + const int64_t thread_num = omp_get_max_threads(); + if (num_tokens > thread_num) { +#pragma omp parallel for + for (int64_t i = 0; i < num_tokens; ++i) { + const float* input_ptr = input + i * hidden_size; + scalar_t* output_ptr = output + i * hidden_size; + int64_t j = 0; + cvt_vec_t token_scale_vec(a_scale[i]); + cvt_vec_t token_zp_scale_vec; + if constexpr (AZP) { + float zp_scale_val = a_scale[i] * static_cast(azp[i]); + token_zp_scale_vec = cvt_vec_t(zp_scale_val); + } + for (; j < hidden_size - vec_elem_num; ++j) { + cvt_vec_t elems_fp32(input_ptr + j); + elems_fp32 = elems_fp32 * token_scale_vec; + if constexpr (AZP) { + cvt_vec_t azp_adj_fp32(azp_adj + j); + elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec; + } + if constexpr (Bias) { + load_vec_t bias_vec(bias + j); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + load_vec_t elems_out(elems_fp32); + elems_out.save(output_ptr + j); + } + cvt_vec_t elems_fp32(input_ptr + j); + elems_fp32 = elems_fp32 * token_scale_vec; + if constexpr (AZP) { + cvt_vec_t azp_adj_fp32(azp_adj + j); + elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec; + } + if constexpr (Bias) { + load_vec_t bias_vec(bias + j); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + load_vec_t elems_out(elems_fp32); + elems_out.save(output_ptr + j, hidden_size - j); + } + } else { + const int64_t vec_iteration = + (hidden_size + vec_elem_num - 1) / vec_elem_num; + const int64_t vec_iteration_per_thread = + (vec_iteration + thread_num - 1) / thread_num; + const int64_t elem_num_per_thread = vec_iteration_per_thread * vec_elem_num; +#pragma omp parallel for schedule(static, 1) + for (int64_t i = 0; i < thread_num; ++i) { + const int64_t start = elem_num_per_thread * i; + const int64_t end = std::min(hidden_size, elem_num_per_thread + start); + for (int64_t j = 0; j < num_tokens; ++j) { + cvt_vec_t token_scale_vec(a_scale[j]); + cvt_vec_t token_zp_scale_vec; + if constexpr (AZP) { + float zp_scale_val = a_scale[j] * static_cast(azp[j]); + token_zp_scale_vec = cvt_vec_t(zp_scale_val); + } + int64_t k = start; + const float* input_ptr = input + j * hidden_size; + scalar_t* output_ptr = output + j * hidden_size; + for (; k < end - vec_elem_num; k += vec_elem_num) { + cvt_vec_t elems_fp32(input_ptr + k); + elems_fp32 = elems_fp32 * token_scale_vec; + if constexpr (AZP) { + cvt_vec_t azp_adj_fp32(azp_adj + k); + elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec; + } + if constexpr (Bias) { + load_vec_t bias_vec(bias + k); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + load_vec_t elems_out(elems_fp32); + elems_out.save(output_ptr + k); + } + if (k < end) { + cvt_vec_t elems_fp32(input_ptr + k); + elems_fp32 = elems_fp32 * token_scale_vec; + if constexpr (AZP) { + cvt_vec_t azp_adj_fp32(azp_adj + k); + elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec; + } + if constexpr (Bias) { + load_vec_t bias_vec(bias + k); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + load_vec_t elems_out(elems_fp32); + elems_out.save(output_ptr + k, end - k); + } + } + } + } +} +} // namespace + +int64_t create_onednn_scaled_mm_handler( + const torch::Tensor& b, // [IC, OC], column-major + const torch::Tensor& b_scales, // [1] or [OC] + at::ScalarType output_type, bool dynamic_act_quant, bool use_azp, + int64_t primitive_cache_size) { + TORCH_CHECK(b.dim() == 2); + TORCH_CHECK(b.stride(0) == 1); // Column-major + TORCH_CHECK(b_scales.is_contiguous()); + + W8A8MatMulPrimitiveHandler::Args args; + args.primitive_cache_size = primitive_cache_size; + + if (b_scales.numel() == 1) { + args.b_quantization_strategy = + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR; + } else { + TORCH_CHECK_EQ(b_scales.numel(), b.size(1)); + args.b_quantization_strategy = + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_OUTPUT_CHANNEL; + } + args.b_scales_ptr = b_scales.data_ptr(); + args.b_k_size = b.size(0); + args.b_k_stride = b.stride(0); + args.b_n_size = b.size(1); + args.b_n_stride = b.stride(1); + args.b_ptr = b.data_ptr(); + + if (dynamic_act_quant) { + // dynamic per-token, bias, A scales and A zps will be applied in outside. + args.a_quantization_strategy = + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN; + args.use_a_zero_point = false; + } else { + // static per-tensor + args.a_quantization_strategy = + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR; + args.use_a_zero_point = use_azp; + } + + VLLM_DISPATCH_FLOATING_TYPES(output_type, "create_onednn_scaled_mm_handler", + [&] { + if (dynamic_act_quant) { + args.c_type = get_dnnl_type(); + } else { + args.c_type = get_dnnl_type(); + } + }); + + return reinterpret_cast(new W8A8MatMulPrimitiveHandler(args)); +} + +void onednn_scaled_mm( + torch::Tensor& c, // [M, OC], row-major + const torch::Tensor& a, // [M, IC], row-major + const torch::Tensor& a_scales, // [M] or [1] + const std::optional& azp, // [M] or [1] + const std::optional& azp_adj, // [M] or [1] + const std::optional& bias, // [N] + int64_t handler) { + CPU_KERNEL_GUARD_IN(onednn_scaled_mm) + TORCH_CHECK(a.dim() == 2); + TORCH_CHECK(a.is_contiguous()); + TORCH_CHECK(c.is_contiguous()); + W8A8MatMulPrimitiveHandler* ptr = + reinterpret_cast(handler); + const int32_t* azp_ptr = nullptr; + if (azp.has_value()) { + azp_ptr = azp->data_ptr(); + } + if (ptr->get_input_scale_strategy() == + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) { + TORCH_CHECK_EQ(a_scales.numel(), 1); + } + + W8A8MatMulPrimitiveHandler::ExecArgs exec_args; + exec_args.a_ptr = a.data_ptr(); + exec_args.a_m_size = a.size(0); + exec_args.bias_ptr = nullptr; + exec_args.use_bias = false; + exec_args.a_scales_ptr = nullptr; + exec_args.a_zero_points_ptr = nullptr; + + VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "onednn_scaled_mm", [&] { + if (ptr->get_input_scale_strategy() == + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) { + if (bias.has_value()) { + exec_args.bias_ptr = bias->data_ptr(); + exec_args.bias_type = get_dnnl_type(); + exec_args.use_bias = true; + } + exec_args.a_scales_ptr = a_scales.data_ptr(); + exec_args.a_zero_points_ptr = azp_ptr; + exec_args.c_ptr = c.data_ptr(); + ptr->execute(exec_args); + } else if (ptr->get_input_scale_strategy() == + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN) { + torch::Tensor tmp_fp32_out = + torch::empty_like(c, ::at::ScalarType::Float); + exec_args.c_ptr = tmp_fp32_out.data_ptr(); + ptr->execute(exec_args); + if (bias.has_value()) { + if (azp.has_value()) { + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), azp_ptr, azp_adj->data_ptr(), + bias->data_ptr(), c.size(0), c.size(1)); + } else { + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), azp_ptr, nullptr, + bias->data_ptr(), c.size(0), c.size(1)); + } + } else { + if (azp.has_value()) { + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), azp_ptr, azp_adj->data_ptr(), + (scalar_t*)nullptr, c.size(0), c.size(1)); + } else { + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), azp_ptr, nullptr, (scalar_t*)nullptr, + c.size(0), c.size(1)); + } + } + } else { + TORCH_CHECK(false, "invalid act quant type."); + } + }); +} + +// static-per-tensor quantization. +void static_scaled_int8_quant( + torch::Tensor& out, // [batch, hidden_size] + const torch::Tensor& input, // [batch, hidden_size] + const torch::Tensor& scale, std::optional const& azp) { + CPU_KERNEL_GUARD_IN(static_scaled_int8_quant) + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK_EQ(input.dim(), 2); + TORCH_CHECK_EQ(input.stride(1), 1); + TORCH_CHECK(scale.numel() == 1); + TORCH_CHECK(!azp.has_value() || azp->numel() == 1); + + const int64_t stride = input.stride(0); + const int64_t hidden_size = input.size(1); + const int64_t num_tokens = input.size(0); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "static_scaled_int8_quant_impl", [&] { + if (azp.has_value()) { + static_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), azp->data_ptr(), num_tokens, + stride, hidden_size); + } else { + static_scaled_int8_quant_impl(input.data_ptr(), + out.data_ptr(), + scale.data_ptr(), nullptr, + num_tokens, stride, hidden_size); + } + }); +} + +// dynamic-per-token quantization. +void dynamic_scaled_int8_quant( + torch::Tensor& out, // [batch, hidden_size] + const torch::Tensor& input, // [batch, hidden_size] + torch::Tensor& scale, // [batch, 1] + std::optional const& azp) { + CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant) + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK_EQ(input.dim(), 2); + TORCH_CHECK_EQ(input.stride(1), 1); + + const int64_t hidden_size = input.size(1); + const int64_t num_tokens = input.size(0); + const int64_t stride = input.stride(0); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] { + if (azp.has_value()) { + dynamic_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), azp->data_ptr(), num_tokens, + stride, hidden_size); + } else { + dynamic_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), nullptr, num_tokens, stride, + hidden_size); + } + }); +} diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp deleted file mode 100644 index 6e120b8d20a7e..0000000000000 --- a/csrc/cpu/quant.cpp +++ /dev/null @@ -1,951 +0,0 @@ -#include "cpu_types.hpp" -#include "dnnl_helper.hpp" - -namespace { -template -struct KernelVecType { - using load_vec_type = void; - using azp_adj_load_vec_type = void; - using cvt_vec_type = void; -}; - -template <> -struct KernelVecType { - using load_vec_type = vec_op::FP32Vec16; - using azp_adj_load_vec_type = vec_op::INT32Vec16; - using cvt_vec_type = vec_op::FP32Vec16; -}; - -#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT) -template <> -struct KernelVecType { - using load_vec_type = vec_op::BF16Vec16; - using azp_adj_load_vec_type = vec_op::INT32Vec16; - using cvt_vec_type = vec_op::FP32Vec16; -}; -#endif - -template <> -struct KernelVecType { -#if defined(__powerpc64__) || defined(__s390x__) - // Power architecture-specific vector type - using load_vec_type = vec_op::FP32Vec16; -#else - // Fallback for other architectures - using load_vec_type = vec_op::FP16Vec16; -#endif - using azp_adj_load_vec_type = vec_op::INT32Vec16; - using cvt_vec_type = vec_op::FP32Vec16; -}; - -#if defined(__AVX512F__) || defined(__aarch64__) -template -void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - const float* scale, const int32_t* azp, - const int num_tokens, - const int hidden_size) { - using load_vec_t = typename KernelVecType::load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - constexpr float i8_min = - static_cast(std::numeric_limits::min()); - constexpr float i8_max = - static_cast(std::numeric_limits::max()); - const cvt_vec_t inv_scale(1.0 / *scale); - const cvt_vec_t i8_min_vec(i8_min); - const cvt_vec_t i8_max_vec(i8_max); - - cvt_vec_t zp_vec; - if constexpr (AZP) { - zp_vec = cvt_vec_t(static_cast(*azp)); - } - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = elems_fp32 * inv_scale; - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + zp_vec; - } - - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j); - } - - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = elems_fp32 * inv_scale; - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + zp_vec; - } - - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j, hidden_size - j); - } -} - -template -void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - float* scale, int32_t* azp, - const int num_tokens, - const int hidden_size) { - using load_vec_t = typename KernelVecType::load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - constexpr float i8_min = - static_cast(std::numeric_limits::min()); - constexpr float i8_max = - static_cast(std::numeric_limits::max()); - const cvt_vec_t i8_min_vec(i8_min); - const cvt_vec_t i8_max_vec(i8_max); - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - cvt_vec_t max_value(std::numeric_limits::lowest()); - cvt_vec_t min_value(std::numeric_limits::max()); - { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - if constexpr (AZP) { - max_value = max_value.max(elems_fp32); - min_value = min_value.min(elems_fp32); - } else { - max_value = max_value.max(elems_fp32.abs()); - } - } - - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - - if (j + vec_elem_num == hidden_size) { - if constexpr (AZP) { - max_value = max_value.max(elems_fp32); - min_value = min_value.min(elems_fp32); - } else { - max_value = max_value.max(elems_fp32.abs()); - } - } else { - if constexpr (AZP) { - max_value = max_value.max(elems_fp32, hidden_size - j); - min_value = min_value.min(elems_fp32, hidden_size - j); - } else { - max_value = max_value.max(elems_fp32.abs(), hidden_size - j); - } - } - } - - float scale_val, azp_val; - if constexpr (AZP) { - float max_scalar = max_value.reduce_max(); - float min_scalar = min_value.reduce_min(); - scale_val = (max_scalar - min_scalar) / 255.0f; - azp_val = std::nearbyint(-128.0f - min_scalar / scale_val); - azp[i] = static_cast(azp_val); - scale[i] = scale_val; - } else { - scale_val = max_value.reduce_max() / 127.0f; - scale[i] = scale_val; - } - - const cvt_vec_t inv_scale(1.0 / scale_val); - const cvt_vec_t azp_vec(azp_val); - - { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = (elems_fp32 * inv_scale); - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + azp_vec; - } - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j); - } - - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = (elems_fp32 * inv_scale); - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + azp_vec; - } - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j, hidden_size - j); - } - } -} - -template -void static_quant_epilogue(const float* input, scalar_t* output, - const float a_scale, const float* b_scale, - const int32_t* azp_with_adj, const int num_tokens, - const int hidden_size) { - CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl) - using load_vec_t = typename KernelVecType::load_vec_type; - using azp_adj_load_vec_t = - typename KernelVecType::azp_adj_load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - cvt_vec_t a_scale_vec(a_scale); - cvt_vec_t b_scale_vec(*b_scale); - cvt_vec_t scale_vec = a_scale_vec * b_scale_vec; - - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - cvt_vec_t elems_fp32(input + i * hidden_size + j); - azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - - if constexpr (PerChannel) { - b_scale_vec = cvt_vec_t(b_scale + j); - scale_vec = b_scale_vec * a_scale_vec; - } - - elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j); - } - - cvt_vec_t elems_fp32(input + i * hidden_size + j); - azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - - if constexpr (PerChannel) { - b_scale_vec = cvt_vec_t(b_scale + j); - scale_vec = b_scale_vec * a_scale_vec; - } - - elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j, hidden_size - j); - } -} - -template -void dynamic_quant_epilogue(const float* input, scalar_t* output, - const float* a_scale, const float* b_scale, - const int32_t* azp, const int32_t* azp_adj, - const scalar_t* bias, const int num_tokens, - const int hidden_size) { - CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue) - using load_vec_t = typename KernelVecType::load_vec_type; - using azp_adj_load_vec_t = - typename KernelVecType::azp_adj_load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - int j = 0; - cvt_vec_t token_scale_vec(a_scale[i]); - cvt_vec_t token_zp_scale_vec; - if constexpr (AZP) { - float zp_scale_val = a_scale[i] * static_cast(azp[i]); - if constexpr (!PerChannel) { - zp_scale_val *= *b_scale; - } - token_zp_scale_vec = cvt_vec_t(zp_scale_val); - } - - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - cvt_vec_t elems_fp32(input + i * hidden_size + j); - elems_fp32 = elems_fp32 * token_scale_vec; - - if constexpr (AZP) { - azp_adj_load_vec_t azp_adj_vec(azp_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; - - if constexpr (PerChannel) { - cvt_vec_t b_scale_vec(b_scale + j); - azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; - } - - elems_fp32 = elems_fp32 - azp_adj_fp32; - } - - if constexpr (Bias) { - load_vec_t bias_vec(bias + j); - cvt_vec_t bias_vec_fp32(bias_vec); - elems_fp32 = elems_fp32 + bias_vec_fp32; - } - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j); - } - - cvt_vec_t elems_fp32(input + i * hidden_size + j); - elems_fp32 = elems_fp32 * token_scale_vec; - - if constexpr (AZP) { - azp_adj_load_vec_t azp_adj_vec(azp_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; - - if constexpr (PerChannel) { - cvt_vec_t b_scale_vec(b_scale + j); - azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; - } - - elems_fp32 = elems_fp32 - azp_adj_fp32; - } - - if constexpr (Bias) { - load_vec_t bias_vec(bias + j); - cvt_vec_t bias_vec_fp32(bias_vec); - elems_fp32 = elems_fp32 + bias_vec_fp32; - } - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j, hidden_size - j); - } -} -#elif defined(__powerpc64__) -template -void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - const float* scale, const int32_t* azp, - const int num_tokens, - const int hidden_size) { - using load_vec_t = typename KernelVecType::load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - constexpr float i8_min = - static_cast(std::numeric_limits::min()); - constexpr float i8_max = - static_cast(std::numeric_limits::max()); - - const cvt_vec_t inv_scale(1.0 / *scale); - const cvt_vec_t i8_min_vec(i8_min); - const cvt_vec_t i8_max_vec(i8_max); - - cvt_vec_t zp_vec; - if constexpr (AZP) { - zp_vec = cvt_vec_t(static_cast(*azp)); - } - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = elems_fp32 * inv_scale; - if constexpr (AZP) { - elems_fp32 = elems_fp32 + zp_vec; - } - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j); - } - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = elems_fp32 * inv_scale; - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + zp_vec; - } - - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j, hidden_size - j); - } -} -template -void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - float* scale, int32_t* azp, - const int num_tokens, - const int hidden_size) { - using load_vec_t = typename KernelVecType::load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - constexpr float i8_min = - static_cast(std::numeric_limits::min()); - constexpr float i8_max = - static_cast(std::numeric_limits::max()); - const cvt_vec_t i8_min_vec(i8_min); - const cvt_vec_t i8_max_vec(i8_max); - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - cvt_vec_t max_value(std::numeric_limits::lowest()); - cvt_vec_t min_value(std::numeric_limits::max()); - { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - if constexpr (AZP) { - max_value = max_value.max(elems_fp32); - min_value = min_value.min(elems_fp32); - } else { - max_value = max_value.max(elems_fp32.abs()); - } - } - - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - - if (j + vec_elem_num == hidden_size) { - if constexpr (AZP) { - max_value = max_value.max(elems_fp32); - min_value = min_value.min(elems_fp32); - } else { - max_value = max_value.max(elems_fp32.abs()); - } - } else { - if constexpr (AZP) { - max_value = max_value.max(elems_fp32, hidden_size - j); - min_value = min_value.min(elems_fp32, hidden_size - j); - } else { - max_value = max_value.max(elems_fp32.abs(), hidden_size - j); - } - } - } - - float scale_val, azp_val; - if constexpr (AZP) { - float max_scalar = max_value.reduce_max(); - float min_scalar = min_value.reduce_min(); - scale_val = (max_scalar - min_scalar) / 255.0f; - azp_val = std::nearbyint(-128.0f - min_scalar / scale_val); - azp[i] = static_cast(azp_val); - scale[i] = scale_val; - } else { - scale_val = max_value.reduce_max() / 127.0f; - scale[i] = scale_val; - } - - const cvt_vec_t inv_scale(1.0 / scale_val); - const cvt_vec_t azp_vec(azp_val); - - { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = (elems_fp32 * inv_scale); - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + azp_vec; - } - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j); - } - - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = (elems_fp32 * inv_scale); - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + azp_vec; - } - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j, hidden_size - j); - } - } -} -template -void static_quant_epilogue(const float* input, scalar_t* output, - const float a_scale, const float* b_scale, - const int32_t* azp_with_adj, const int num_tokens, - const int hidden_size) { - CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl) - using load_vec_t = typename KernelVecType::load_vec_type; - using azp_adj_load_vec_t = - typename KernelVecType::azp_adj_load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - cvt_vec_t a_scale_vec(a_scale); - cvt_vec_t b_scale_vec(*b_scale); - cvt_vec_t scale_vec = a_scale_vec * b_scale_vec; - - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - cvt_vec_t elems_fp32(input + i * hidden_size + j); - azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - - if constexpr (PerChannel) { - b_scale_vec = cvt_vec_t(b_scale + j); - scale_vec = b_scale_vec * a_scale_vec; - } - elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j); - } - - cvt_vec_t elems_fp32(input + i * hidden_size + j); - azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - - if constexpr (PerChannel) { - b_scale_vec = cvt_vec_t(b_scale + j); - scale_vec = b_scale_vec * a_scale_vec; - } - - elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j, hidden_size - j); - } -} -template -void dynamic_quant_epilogue(const float* input, scalar_t* output, - const float* a_scale, const float* b_scale, - const int32_t* azp, const int32_t* azp_adj, - const scalar_t* bias, const int num_tokens, - const int hidden_size) { - CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue) - using load_vec_t = typename KernelVecType::load_vec_type; - using azp_adj_load_vec_t = - typename KernelVecType::azp_adj_load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - int j = 0; - cvt_vec_t token_scale_vec(a_scale[i]); - cvt_vec_t token_zp_scale_vec; - if constexpr (AZP) { - float zp_scale_val = a_scale[i] * static_cast(azp[i]); - if constexpr (!PerChannel) { - zp_scale_val *= *b_scale; - } - token_zp_scale_vec = cvt_vec_t(zp_scale_val); - } - - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - cvt_vec_t elems_fp32(input + i * hidden_size + j); - elems_fp32 = elems_fp32 * token_scale_vec; - - if constexpr (AZP) { - azp_adj_load_vec_t azp_adj_vec(azp_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; - - if constexpr (PerChannel) { - cvt_vec_t b_scale_vec(b_scale + j); - azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; - } - - elems_fp32 = elems_fp32 - azp_adj_fp32; - } - - if constexpr (Bias) { - load_vec_t bias_vec(bias + j); - cvt_vec_t bias_vec_fp32(bias_vec); - elems_fp32 = elems_fp32 + bias_vec_fp32; - } - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j); - } - - cvt_vec_t elems_fp32(input + i * hidden_size + j); - elems_fp32 = elems_fp32 * token_scale_vec; - - if constexpr (AZP) { - azp_adj_load_vec_t azp_adj_vec(azp_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; - - if constexpr (PerChannel) { - cvt_vec_t b_scale_vec(b_scale + j); - azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; - } - - elems_fp32 = elems_fp32 - azp_adj_fp32; - } - - if constexpr (Bias) { - load_vec_t bias_vec(bias + j); - cvt_vec_t bias_vec_fp32(bias_vec); - elems_fp32 = elems_fp32 + bias_vec_fp32; - } - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j, hidden_size - j); - } -} -#else -template -void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - const float* scale, const int32_t* azp, - const int num_tokens, - const int hidden_size) { - TORCH_CHECK(false, - "static_scaled_int8_quant_impl requires AVX512/powerpc64/AArch64 " - "support.") -} - -template -void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - float* scale, int32_t* azp, - const int num_tokens, - const int hidden_size) { - TORCH_CHECK(false, - "dynamic_scaled_int8_quant_impl requires " - "AVX512/powerpc64/AArch64 support.") -} - -template -void static_quant_epilogue(const float* input, scalar_t* output, - const float a_scale, const float* b_scale, - const int32_t* azp_with_adj, const int num_tokens, - const int hidden_size) { - TORCH_CHECK( - false, "static_quant_epilogue requires AVX512/powerpc64/AArch64 support.") -} - -template -void dynamic_quant_epilogue(const float* input, scalar_t* output, - const float* a_scale, const float* b_scale, - const int32_t* azp, const int32_t* azp_with_adj, - const scalar_t* bias, const int num_tokens, - const int hidden_size) { - TORCH_CHECK( - false, - "dynamic_quant_epilogue requires AVX512/powerpc64/AArch64 support.") -} -#endif -} // namespace - -void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major - const torch::Tensor& a, // [M, IC], row-major - const torch::Tensor& b, // [IC, OC], column-major - const torch::Tensor& a_scales, // [1] or [M] - const torch::Tensor& b_scales, // [1] or [OC] - const std::optional& bias // [OC] -) { - CPU_KERNEL_GUARD_IN(cutlass_scaled_mm) - // Checks for conformality - TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, - "int8_scaled_mm only supports INT8 inputs.") - TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); - TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && - b.size(1) == c.size(1)); - TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); - TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); - - // Check for strides and alignment - TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major - TORCH_CHECK(b.stride(0) == 1); // Column-major - TORCH_CHECK(c.stride(0) % 16 == 0 && - b.stride(1) % 16 == 0); // 16 Byte Alignment - TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - - if (bias) { - TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() && - bias->dim() == 1); - } - - VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm", [&] { - if (a_scales.numel() != 1) { - // per-token - // Note: oneDNN doesn't support per-token activation quantization - // Ideally we want to fuse the GEMM and the scale procedure with oneDNN - // JIT, the intermediate data is cached in registers or L1. But for now - // the oneDNN GEMM code generation only supports two quantization - // patterns: per-tensor or per-output-channel of weight. - // So we have to apply the per-token scale with a 'epilogue'. In C=s_a * - // s_b * (A@B) + bias, the C_inter = s_b * (A@B) is computed by oneDNN - // GEMM, then the per-token scale (and bias) is applied with the epilogue - // C=s_a * C_inter + bias. - torch::Tensor tmp_fp32_out = - torch::empty_like(c, ::at::ScalarType::Float); - // Compute C_inter=s_b * (A@B) - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), - tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), - a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel()); - if (bias.has_value()) { - // Compute C=s_a * C_inter + bias - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), nullptr, nullptr, nullptr, - bias->data_ptr(), c.size(0), c.size(1)); - } else { - // Compute C=s_a * C_inter - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), nullptr, nullptr, nullptr, nullptr, - c.size(0), c.size(1)); - } - } else { - // per-tensor - if (bias.has_value()) { - // Compute C=s_a * s_b * (A@B) + bias - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), c.data_ptr(), - bias->data_ptr(), a.size(0), b.size(1), a.size(1), - a_scales.data_ptr(), b_scales.data_ptr(), - a_scales.numel(), b_scales.numel()); - } else { - // Compute C=s_a * s_b * (A@B) - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), c.data_ptr(), - nullptr, a.size(0), b.size(1), a.size(1), - a_scales.data_ptr(), b_scales.data_ptr(), - a_scales.numel(), b_scales.numel()); - } - } - }); -} - -void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major - const torch::Tensor& a, // [M, IC], row-major - const torch::Tensor& b, // [IC, OC], column-major - const torch::Tensor& a_scales, // [1] or [M] - const torch::Tensor& b_scales, // [1] or [OC] - const torch::Tensor& azp_adj, // [OC] - const std::optional& azp, // [1] or [M] - const std::optional& bias // [OC] -) { - CPU_KERNEL_GUARD_IN(cutlass_scaled_mm_azp) - // Checks for conformality - TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, - "int8_scaled_mm_azp only supports INT8 inputs.") - TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); - TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && - b.size(1) == c.size(1)); - TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); - TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); - - // Check for strides and alignment - TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major - TORCH_CHECK(b.stride(0) == 1); // Column-major - TORCH_CHECK(c.stride(0) % 16 == 0 && - b.stride(1) % 16 == 0); // 16 Byte Alignment - TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - - if (bias) { - TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous()); - } - if (azp) { - TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous()); - } - TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous()); - - // azp & bias types - TORCH_CHECK(azp_adj.dtype() == torch::kInt32); - TORCH_CHECK(!azp || azp->dtype() == torch::kInt32); - TORCH_CHECK(!bias || bias->dtype() == c.dtype(), - "currently bias dtype must match output dtype ", c.dtype()); - - VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_azp", [&] { - torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float); - if (a_scales.numel() != 1) { - // per-token - // Note: oneDNN doesn't support per-token activation quantization - // Compute C_inter=s_b * (A@B) - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), - tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), - a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel()); - if (bias.has_value()) { - // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj + bias - if (b_scales.numel() != 1) { - // Per-Channel - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), b_scales.data_ptr(), - azp->data_ptr(), azp_adj.data_ptr(), - bias->data_ptr(), c.size(0), c.size(1)); - } else { - // Per-Tensor - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), b_scales.data_ptr(), - azp->data_ptr(), azp_adj.data_ptr(), - bias->data_ptr(), c.size(0), c.size(1)); - } - } else { - // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj - if (b_scales.numel() != 1) { - // Per-Channel - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), b_scales.data_ptr(), - azp->data_ptr(), azp_adj.data_ptr(), nullptr, - c.size(0), c.size(1)); - } else { - // Per-Tensor - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), b_scales.data_ptr(), - azp->data_ptr(), azp_adj.data_ptr(), nullptr, - c.size(0), c.size(1)); - } - } - } else { - // per-tensor - if (bias.has_value()) { - // Compute C_inter=s_a * s_b * (A@B) + bias - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), - tmp_fp32_out.data_ptr(), bias->data_ptr(), - a.size(0), b.size(1), a.size(1), a_scales.data_ptr(), - b_scales.data_ptr(), a_scales.numel(), b_scales.numel()); - } else { - // Compute C_inter=s_a * s_b * (A@B) - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), - tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), - a.size(1), a_scales.data_ptr(), b_scales.data_ptr(), - a_scales.numel(), b_scales.numel()); - } - - // Compute C=C_inter - s_a * s_b * azp_adj - if (b_scales.numel() != 1) { - // Per-Channel - static_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - *a_scales.data_ptr(), b_scales.data_ptr(), - azp_adj.data_ptr(), a.size(0), b.size(1)); - } else { - // Per-Tensor - static_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - *a_scales.data_ptr(), b_scales.data_ptr(), - azp_adj.data_ptr(), a.size(0), b.size(1)); - } - } - }); -} - -// static-per-tensor quantization. -void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] - const torch::Tensor& input, // [..., hidden_size] - const torch::Tensor& scale, - std::optional const& azp) { - CPU_KERNEL_GUARD_IN(static_scaled_int8_quant) - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(out.is_contiguous()); - TORCH_CHECK(scale.numel() == 1); - TORCH_CHECK(!azp.has_value() || azp->numel() == 1); - - const int hidden_size = input.size(-1); - const int num_tokens = input.numel() / hidden_size; - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "static_scaled_int8_quant_impl", [&] { - if (azp.has_value()) { - static_scaled_int8_quant_impl( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), azp->data_ptr(), num_tokens, - hidden_size); - } else { - static_scaled_int8_quant_impl( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), nullptr, num_tokens, hidden_size); - } - }); -} - -// dynamic-per-token quantization. -void dynamic_scaled_int8_quant( - torch::Tensor& out, // [..., hidden_size] - const torch::Tensor& input, // [..., hidden_size] - torch::Tensor& scale, // [..., 1] - std::optional const& azp) { - CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant) - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(out.is_contiguous()); - - int const hidden_size = input.size(-1); - int const num_tokens = input.numel() / hidden_size; - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] { - if (azp.has_value()) { - dynamic_scaled_int8_quant_impl( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), azp->data_ptr(), num_tokens, - hidden_size); - } else { - dynamic_scaled_int8_quant_impl( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), nullptr, num_tokens, hidden_size); - } - }); -} - -#if defined(__powerpc64__) -void int8_scaled_mm_ppc64le(torch::Tensor& c, // [M, OC], row-major - const torch::Tensor& a, // [M, IC], row-major - const torch::Tensor& b, // [IC, OC], column-major - const torch::Tensor& a_scales, - const torch::Tensor& b_scales, - const std::optional& bias // [OC] -) { - CPU_KERNEL_GUARD_IN(cutlass_scaled_mm) - // Checks for conformality - TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, - "int8_scaled_mm_ppc64le only supports INT8 inputs."); - TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); - TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && - b.size(1) == c.size(1)); - // We dont need this - TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); - TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); - - // Check for strides and alignment - TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major - TORCH_CHECK(b.stride(0) == 1); // Column-major - TORCH_CHECK(c.stride(0) % 16 == 0 && - b.stride(1) % 16 == 0); // 16 Byte Alignment - TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - - if (bias) { - TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() && - bias->dim() == 1); - } - VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_ppc64le", [&] { - torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float); - // Compute C_inter=s_b * (A@B) - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), - tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), - a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel()); - if (bias.has_value()) { - // Compute C=s_a * C_inter + bias - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), nullptr, nullptr, nullptr, - bias->data_ptr(), c.size(0), c.size(1)); - } else { - // Compute C=s_a * C_inter - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), nullptr, nullptr, nullptr, nullptr, - c.size(0), c.size(1)); - } - }); -} - -#endif diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index b20a054648428..c9f426bdf618a 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -6,25 +6,20 @@ std::string init_cpu_threads_env(const std::string& cpu_ids); -void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a, - const torch::Tensor& b, const torch::Tensor& a_scales, - const torch::Tensor& b_scales, - const std::optional& bias); +void release_dnnl_matmul_handler(int64_t handler); -void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a, - const torch::Tensor& b, const torch::Tensor& a_scales, - const torch::Tensor& b_scales, - const torch::Tensor& azp_adj, - const std::optional& azp, - const std::optional& bias); +int64_t create_onednn_scaled_mm_handler(const torch::Tensor& b, + const torch::Tensor& b_scales, + at::ScalarType output_type, + bool dynamic_act_quant, bool use_azp, + int64_t primitive_cache_size); -#if defined(__powerpc64__) -void int8_scaled_mm_ppc64le(torch::Tensor& c, const torch::Tensor& a, - const torch::Tensor& b, - const torch::Tensor& a_scales, - const torch::Tensor& b_scales, - const std::optional& bias); -#endif +void onednn_scaled_mm(torch::Tensor& c, const torch::Tensor& a, + const torch::Tensor& a_scales, + const std::optional& azp, + const std::optional& azp_adj, + const std::optional& bias, + int64_t handler); void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query, torch::Tensor& kv_cache, double scale, @@ -151,8 +146,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); // Quantization -#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) +#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) || \ + defined(__powerpc64__) at::Tag stride_tag = at::Tag::needs_fixed_stride_order; + // Helper function to release oneDNN handlers + ops.def("release_dnnl_matmul_handler(int handler) -> ()", + &release_dnnl_matmul_handler); + + // Create oneDNN W8A8 handler + ops.def( + "create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType " + "output_type, bool dynamic_act_quant, bool use_azp, int " + "primitive_cache_size) -> int", + &create_onednn_scaled_mm_handler); + + // oneDNN scaled_mm for W8A8 with static per-tensor activation quantization + ops.def( + "onednn_scaled_mm(Tensor! c, Tensor a, Tensor a_scales, Tensor? azp, " + "Tensor? azp_adj, Tensor? bias, int handler) -> ()"); + ops.impl("onednn_scaled_mm", torch::kCPU, &onednn_scaled_mm); // Compute int8 quantized tensor for given scaling factor. ops.def( @@ -168,50 +180,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { {stride_tag}); ops.impl("dynamic_scaled_int8_quant", torch::kCPU, &dynamic_scaled_int8_quant); - // W8A8 GEMM, supporting symmetric per-tensor or per-row/column - // quantization. - ops.def( - "cutlass_scaled_mm(Tensor! out, Tensor a," - " Tensor b, Tensor a_scales," - " Tensor b_scales, Tensor? bias) -> ()", - {stride_tag}); - ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); - // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column - // quantization. - ops.def( - "cutlass_scaled_mm_azp(Tensor! out, Tensor a," - " Tensor b, Tensor a_scales," - " Tensor b_scales, Tensor azp_adj," - " Tensor? azp, Tensor? bias) -> ()", - {stride_tag}); - ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); -#elif defined(__powerpc64__) - // Compute int8 quantized tensor for given scaling factor. - ops.def( - "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," - "Tensor? azp) -> ()"); - ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant); - - // Compute int8 quantized tensor and scaling factor - ops.def( - "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " - "Tensor!? azp) -> ()"); - ops.impl("dynamic_scaled_int8_quant", torch::kCPU, - &dynamic_scaled_int8_quant); - // W8A8 GEMM, supporting symmetric quantization. - ops.def( - "cutlass_scaled_mm(Tensor! out, Tensor a," - " Tensor b, Tensor a_scales," - " Tensor b_scales, Tensor? bias) -> ()"); - ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm_ppc64le); - // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column - // quantization. - ops.def( - "cutlass_scaled_mm_azp(Tensor! out, Tensor a," - " Tensor b, Tensor a_scales," - " Tensor b_scales, Tensor azp_adj," - " Tensor? azp, Tensor? bias) -> ()"); - ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); #endif // SHM CCL diff --git a/tests/kernels/test_onednn.py b/tests/kernels/test_onednn.py new file mode 100644 index 0000000000000..17692384ac9a9 --- /dev/null +++ b/tests/kernels/test_onednn.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Integration tests for FlexAttention backend vs default backend""" + +from typing import Optional + +import pytest +import torch + +from tests.kernels.utils import to_int8 +from vllm import _custom_ops as ops +from vllm.platforms import current_platform + +if not current_platform.is_cpu(): + pytest.skip("skipping CPU-only tests", allow_module_level=True) + +NK_FACTORS = [ + (256, 128), + (4096, 4096), + (16384, 4096), + (1023, 491), + (1001, 15), +] +M_FACTORS = [ + (16, 1, 32, 128, 64), + (1, 17, 1, 31, 17), +] +CACHE_SIZES = [2] +DTYPE = [torch.bfloat16] + + +def rand_int8(shape: tuple, device: str = "cpu"): + return to_int8(torch.rand(shape, device=device) * 255 - 128) + + +def ref_int8_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + azp: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + output_type: torch.dtype, +): + if azp is not None: + a = a.to(dtype=torch.float32) - azp.to(dtype=torch.float32) + output = torch.mm((scale_a * a.to(dtype=torch.float32)), + (scale_b * b.to(dtype=torch.float32))) + if bias is not None: + output += bias.float() + + return output.to(dtype=output_type) + + +def onednn_int8_gemm_test_helper(primitive_cache_size: int, + m: int, + n: int, + k: int, + per_tensor_a_quant: bool, + per_tensor_b_quant: bool, + use_azp: bool, + use_bias: bool, + out_dtype: torch.dtype = torch.bfloat16, + device: str = "cpu"): + # Test for a oneDNN kernel with per-tensor / per-token activation + # quantization and per-tensor / per-output channel weight quantization. + a = to_int8(torch.randn((m, k), device=device) * 5) + b = to_int8(torch.randn((n, k), device=device).t() * 5) + + a_scales_shape = (1, 1) if per_tensor_a_quant else (m, 1) + b_scales_shape = (1, 1) if per_tensor_b_quant else (1, n) + + scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32)) + scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32)) + + if use_azp: + azp = torch.rand(a_scales_shape, dtype=torch.float32) * 10 + 1.5 + azp = (azp / scale_a).round().to(dtype=torch.int32) + azp_adj = scale_b * b.sum(dim=0, keepdim=True, dtype=torch.float32) + else: + azp = None + azp_adj = None + + if use_bias: + bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10 + else: + bias = None + + handler = ops.create_onednn_scaled_mm( + b, + scale_b, + out_dtype, + not per_tensor_a_quant, + use_azp, + primitive_cache_size, + ) + + out = torch.zeros((m, n), dtype=out_dtype) + ops.onednn_scaled_mm(handler, a, out, scale_a, azp, azp_adj, bias) + baseline = ref_int8_scaled_mm(a, b, scale_a, scale_b, azp, bias, out_dtype) + + torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) + + if use_bias: + # To test runtime bias setting + out = torch.zeros((m, n), dtype=out_dtype) + ops.onednn_scaled_mm(handler, a, out, scale_a, azp, azp_adj, None) + baseline = ref_int8_scaled_mm(a, b, scale_a, scale_b, azp, None, + out_dtype) + + torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) + + +@pytest.mark.parametrize("n,k", NK_FACTORS) +@pytest.mark.parametrize("m_list", M_FACTORS) +@pytest.mark.parametrize("per_tensor_a_scale", [True, False]) +@pytest.mark.parametrize("per_tensor_b_scale", [True, False]) +@pytest.mark.parametrize("use_bias", [True, False]) +@pytest.mark.parametrize("use_azp", [True, False]) +@pytest.mark.parametrize("output_type", DTYPE) +@pytest.mark.parametrize("primitive_cache_size", CACHE_SIZES) +def test_onednn_int8_scaled_gemm( + n: int, + k: int, + m_list: tuple[int], + per_tensor_a_scale: bool, + per_tensor_b_scale: bool, + use_bias: bool, + use_azp: bool, + output_type: torch.dtype, + primitive_cache_size: int, +): + for m in m_list: + onednn_int8_gemm_test_helper( + primitive_cache_size=primitive_cache_size, + m=m, + n=n, + k=k, + per_tensor_a_quant=per_tensor_a_scale, + per_tensor_b_quant=per_tensor_b_scale, + use_bias=use_bias, + use_azp=use_azp, + out_dtype=output_type, + ) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 59f2d7737f19d..3081aff114fc1 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1827,3 +1827,86 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"): M = mat1.size(0) N = mat2.size(0) return torch.empty((M, N), dtype=out_dtype) + + +class CPUDNNLGEMMHandler: + + def __init__(self) -> None: + self.handler: Optional[int] = None + self.n = -1 + self.k = -1 + + def __del__(self): + if self.handler is not None: + torch.ops._C.release_dnnl_matmul_handler(self.handler) + + +def create_onednn_scaled_mm( + weight: torch.Tensor, # [K, N] + weight_scales: torch.Tensor, + output_type: torch.dtype, + dynamic_quant: bool, + use_azp: bool, + primitive_cache_size: int = 128, +) -> CPUDNNLGEMMHandler: + handler = CPUDNNLGEMMHandler() + handler.k, handler.n = weight.size() + handler.handler = torch.ops._C.create_onednn_scaled_mm_handler( + weight, weight_scales, output_type, dynamic_quant, use_azp, + primitive_cache_size) + return handler + + +def onednn_scaled_int8_quant(input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True): + """ + Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. + + Args: + input: The input tensor to be quantized to int8. + scale: Optional scaling factor for the int8 quantization. + When not provided, we invoke dynamic-per-token quantization. + azp: Optional zero-point for the int8 quantization. + Must be provided for asymmetric quantization if `scale` is provided. + symmetric: Whether to use symmetric quantization (scale only, azp ignored). + + Returns: + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + """ + output = torch.empty_like(input, dtype=torch.int8) + token_num = input.numel() // input.shape[-1] + input = input.view((token_num, input.shape[-1])) + if scale is not None: + # static-per-tensor quantization. + assert symmetric == ( + azp + is None), "azp must only be provided for asymmetric quantization." + torch.ops._C.static_scaled_int8_quant(output, input, scale, azp) + return output, scale, azp + + # dynamic-per-token quantization. + input_scales = torch.empty((token_num, 1), + device=input.device, + dtype=torch.float32) + input_azp = None if symmetric else torch.empty_like(input_scales, + dtype=torch.int32) + torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales, + input_azp) + return output, input_scales, input_azp + + +def onednn_scaled_mm( + dnnl_handler: CPUDNNLGEMMHandler, + x: torch.Tensor, + output: torch.Tensor, + input_scale: Optional[torch.Tensor], + input_zp: Optional[torch.Tensor], + input_zp_adj: Optional[torch.Tensor], + bias: Optional[torch.Tensor], +) -> torch.Tensor: + torch.ops._C.onednn_scaled_mm(output, x, input_scale, input_zp, + input_zp_adj, bias, dnnl_handler.handler) + + return output diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index b16c21b7013a0..fcc6987d26bb2 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -360,10 +360,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): elif current_platform.is_cpu(): if current_platform.get_cpu_architecture() == CpuArchEnum.X86: from vllm.model_executor.layers.fused_moe import cpu_fused_moe - dtype = layer.w13_weight.dtype + from vllm.model_executor.layers.utils import ( + check_cpu_sgl_kernel) + dtype_w13 = layer.w13_weight.dtype + _, n_w13, k_w13 = layer.w13_weight.size() + dtype_w2 = layer.w2_weight.dtype + _, n_w2, k_w2 = layer.w2_weight.size() if (envs.VLLM_CPU_SGL_KERNEL - and torch._C._cpu._is_amx_tile_supported() - and dtype == torch.bfloat16): + and check_cpu_sgl_kernel(n_w13, k_w13, dtype_w13) + and check_cpu_sgl_kernel(n_w2, k_w2, dtype_w2)): packed_w13_weight = torch.ops._C.convert_weight_packed( layer.w13_weight) assert packed_w13_weight.size() == layer.w13_weight.size() diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 654e2ec7b2fa0..9b1ab7af0ac84 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -199,11 +199,10 @@ class UnquantizedLinearMethod(LinearMethodBase): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if current_platform.is_cpu() and envs.VLLM_CPU_SGL_KERNEL: + from vllm.model_executor.layers.utils import check_cpu_sgl_kernel N, K = layer.weight.size() dtype = layer.weight.dtype - if (torch._C._cpu._is_amx_tile_supported() - and dtype == torch.bfloat16 and N % 32 == 0 - and K % 32 == 0): + if check_cpu_sgl_kernel(N, K, dtype): packed_weight = torch.ops._C.convert_weight_packed( layer.weight) assert packed_weight.size() == layer.weight.size() @@ -215,7 +214,8 @@ class UnquantizedLinearMethod(LinearMethodBase): else: logger.warning( "CPU SGL kernels require Intel AMX support," - " bfloat16 weight, IC and OC are divisible by 32.") + " bf16/fp16/int8 weight, IC and OC are divisible by " + "32 and 16.") layer.use_cpu_sgl = False def apply(self, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 18f5ce04fd355..2bc68ab3ebd18 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -6,6 +6,8 @@ from typing import Optional from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import ( AiterScaledMMLinearKernel) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import ( + CPUScaledMMLinearKernel) from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( CutlassScaledMMLinearKernel) from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 @@ -18,7 +20,7 @@ from vllm.platforms import PlatformEnum, current_platform # in priority/performance order (when available) _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { - PlatformEnum.CPU: [CutlassScaledMMLinearKernel], + PlatformEnum.CPU: [CPUScaledMMLinearKernel], PlatformEnum.CUDA: [CutlassScaledMMLinearKernel], PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel], PlatformEnum.TPU: [XLAScaledMMLinearKernel], diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py new file mode 100644 index 0000000000000..59d2b5bce962e --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py @@ -0,0 +1,206 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from vllm import _custom_ops as ops +from vllm import envs +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + convert_to_channelwise) +from vllm.model_executor.layers.utils import check_cpu_sgl_kernel +from vllm.platforms import current_platform +from vllm.platforms.interface import CpuArchEnum + +from .ScaledMMLinearKernel import (ScaledMMLinearKernel, + ScaledMMLinearLayerConfig) + + +class CPUScaledMMLinearKernel(ScaledMMLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 75 + + @classmethod + def can_implement( + cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: + if not current_platform.is_cpu(): + return False, "CPUScaledMM requires running on CPU." + + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + weight = getattr(layer, self.w_q_name) + dtype = weight.dtype + N, K = weight.size() + if (current_platform.get_cpu_architecture() == CpuArchEnum.X86 + and envs.VLLM_CPU_SGL_KERNEL and self.config.input_symmetric + and check_cpu_sgl_kernel(N, K, dtype)): + self.linear_method = self._apply_weights_sgl + self.process_weights_for_sgl(layer) + else: + self.linear_method = self._apply_weights_onednn + self.process_weights_for_onednn(layer) + + def process_weights_for_onednn(self, layer: torch.nn.Module) -> None: + # WEIGHT + # Transpose to [K, N] for convenience + weight = getattr(layer, self.w_q_name) + replace_parameter( + layer, self.w_q_name, + torch.nn.Parameter(weight.t().data, requires_grad=False)) + + # WEIGHT SCALE + # oneDNN kernels support only per-tensor and per-channel. + # If we have a fused module (QKV, MLP) with per tensor scales (thus N + # scales being passed to the kernel), convert to the per-channel case. + is_fused_module = len(layer.logical_widths) > 1 + weight_scale = getattr(layer, self.w_s_name) + if is_fused_module and not self.config.is_channelwise: + weight_scale = convert_to_channelwise(weight_scale, + layer.logical_widths) + replace_parameter( + layer, self.w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False)) + + # INPUT SCALE + if self.config.is_static_input_scheme: + input_scale = getattr(layer, self.i_s_name) + + if self.config.input_symmetric: + replace_parameter( + layer, self.i_s_name, + torch.nn.Parameter(input_scale.max(), requires_grad=False)) + setattr(layer, self.i_zp_name, None) + else: + input_zero_point = getattr(layer, self.i_zp_name) + + # reconstruct the ranges + int8_traits = torch.iinfo(torch.int8) + azps = input_zero_point.to(dtype=torch.int32) + range_max = (input_scale * (int8_traits.max - azps)).max() + range_min = (input_scale * (int8_traits.min - azps)).min() + + scale = (range_max - range_min) / (int8_traits.max - + int8_traits.min) + replace_parameter( + layer, self.i_s_name, + torch.nn.Parameter(scale, requires_grad=False)) + + azp = (int8_traits.min - + range_min / scale).round().to(dtype=torch.int32) + replace_parameter(layer, self.i_zp_name, + torch.nn.Parameter(azp, requires_grad=False)) + + else: + setattr(layer, self.i_s_name, None) + setattr(layer, self.i_zp_name, None) + + # Different from cutlass, oneDNN kernels only need the AZP adjustment + # term for dynamic quantization. And s_b should be folded into the + # term. Such as: + # s_a * s_b * [(A - zp_a)B] + bias = + # s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias = + # s_a * GEMM_output - s_a * zp_a * adj + bias + if not (self.config.input_symmetric + and self.config.is_static_input_scheme): + weight = getattr(layer, self.w_q_name) + weight_scale = getattr(layer, self.w_s_name) + azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.float32) + azp_adj = azp_adj * weight_scale.squeeze() + setattr(layer, self.azp_adj_name, + torch.nn.Parameter(azp_adj, requires_grad=False)) + else: + setattr(layer, self.azp_adj_name, None) + + weight = getattr(layer, self.w_q_name) + self.dnnl_handler = ops.create_onednn_scaled_mm( + weight, + getattr(layer, self.w_s_name), + torch.get_default_dtype(), + getattr(layer, self.i_s_name) is None, + not self.config.input_symmetric, + 32, + ) + # weight is prepacked and maintained by the dnnl_handler, + # release the original weight + setattr(layer, self.w_q_name, None) + del weight + + def process_weights_for_sgl(self, layer: torch.nn.Module) -> None: + # WEIGHT + weight = getattr(layer, self.w_q_name) + packed_weight = torch.ops._C.convert_weight_packed(weight) + replace_parameter( + layer, self.w_q_name, + torch.nn.Parameter(packed_weight, requires_grad=False)) + + if layer.bias is not None: + bias = layer.bias + layer.register_parameter( + "bias_fp32", + torch.nn.Parameter(bias.float().data, requires_grad=False)) + + # WEIGHT SCALE + # CPU SGL kernels only support per-channel. + # For per-tensor quant, convert to the per-channel case. + weight_scale = getattr(layer, self.w_s_name) + if not self.config.is_channelwise: + weight_scale = convert_to_channelwise(weight_scale, + layer.logical_widths) + replace_parameter( + layer, self.w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False)) + + setattr(layer, self.i_s_name, None) + setattr(layer, self.i_zp_name, None) + setattr(layer, self.azp_adj_name, None) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return self.linear_method( + layer, + x, + bias, + ) + + def _apply_weights_onednn( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) + + # ops.scaled_int8_quant supports both dynamic and static quant: + # * dynamic, i_s is None and x_s computed from x. + # * static, i_s is scalar and x_s is i_s. + x_q, x_s, x_zp = ops.onednn_scaled_int8_quant( + x, i_s, i_zp, self.config.input_symmetric) + + m = x.size(0) + n = self.dnnl_handler.n + out = torch.empty((m, n), dtype=x.dtype) + ops.onednn_scaled_mm(self.dnnl_handler, x_q, out, x_s, x_zp, azp_adj, + bias) + + return out + + def _apply_weights_sgl( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + w_q, w_s, _, _, _ = self._get_weight_params(layer) + return torch.ops._C.int8_scaled_mm_with_quant( + x, + w_q, + w_s, + layer.bias_fp32 if bias is not None else None, + x.dtype, + True, + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index 6ddd4a9ec4233..2f982f96b0d04 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -25,8 +25,8 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): def can_implement( cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: - if (not current_platform.is_cuda() and not current_platform.is_cpu()): - return False, "CutlassScaledMM requires running on CUDA or CPU." + if not current_platform.is_cuda(): + return False, "CutlassScaledMM requires running on CUDA." return True, None diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 48a347a8f5611..2897f75b3129e 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -142,6 +142,12 @@ direct_register_custom_op( ) +def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype): + return (torch._C._cpu._is_amx_tile_supported() + and (dtype in (torch.bfloat16, torch.int8)) and k % 32 == 0 + and n % 16 == 0) + + def cpu_unquantized_gemm(layer: torch.nn.Module, x: torch.Tensor, weight: torch.Tensor, From 2461d9e562e5852555c76e0dbed06979f9c6c688 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 21 Aug 2025 11:05:20 +0800 Subject: [PATCH 26/57] [CI/Build] Split out mm processor tests (#23260) Signed-off-by: DarkLight1337 --- .buildkite/test-pipeline.yaml | 15 +++++++++++---- .../{ => processing}/test_tensor_schema.py | 7 +++---- vllm/model_executor/models/cohere2_vision.py | 2 ++ 3 files changed, 16 insertions(+), 8 deletions(-) rename tests/models/multimodal/{ => processing}/test_tensor_schema.py (98%) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 745420664010a..5869ae21d5c7e 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -545,6 +545,15 @@ steps: commands: - pytest -v -s models/language/pooling -m 'not core_model' +- label: Multi-Modal Processor Test + source_file_dependencies: + - vllm/ + - tests/models/multimodal + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py + - pytest -v -s models/multimodal/processing/test_tensor_schema.py + - label: Multi-Modal Models Test (Standard) mirror_hardwares: [amdexperimental] torch_nightly: true @@ -554,9 +563,7 @@ steps: commands: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pip freeze | grep -E 'torch' - - pytest -v -s models/multimodal/processing - - pytest -v -s --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/test_tensor_schema.py models/multimodal -m core_model - - pytest -v -s models/multimodal/test_tensor_schema.py -m core_model # Needs mp_method="spawn" + - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing - cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work - label: Multi-Modal Models Test (Extended) 1 @@ -567,7 +574,7 @@ steps: - tests/models/multimodal commands: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - - pytest -v -s --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing models/multimodal -m 'not core_model' + - pytest -v -s models/multimodal -m 'not core_model' --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing - label: Multi-Modal Models Test (Extended) 2 mirror_hardwares: [amdexperimental] diff --git a/tests/models/multimodal/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py similarity index 98% rename from tests/models/multimodal/test_tensor_schema.py rename to tests/models/multimodal/processing/test_tensor_schema.py index 143b4c8fc8c49..79164f02c3398 100644 --- a/tests/models/multimodal/test_tensor_schema.py +++ b/tests/models/multimodal/processing/test_tensor_schema.py @@ -24,9 +24,9 @@ from vllm.utils import GiB_bytes, is_list_of, set_default_torch_num_threads from vllm.v1.core.kv_cache_utils import get_kv_cache_config from vllm.v1.engine.core import EngineCore as V1EngineCore -from ...conftest import VllmRunner -from ..registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS -from ..utils import dummy_hf_overrides +from ....conftest import VllmRunner +from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS +from ...utils import dummy_hf_overrides ARCH_TO_SKIP = { "MolmoForCausalLM": "incompatible requirements", @@ -147,7 +147,6 @@ def get_model_id_to_test( return filtered_results -@pytest.mark.core_model @pytest.mark.parametrize( "model_arch, model_id", get_model_id_to_test(_MULTIMODAL_EXAMPLE_MODELS.keys())) diff --git a/vllm/model_executor/models/cohere2_vision.py b/vllm/model_executor/models/cohere2_vision.py index fca1aee835b89..179cc2af8eb3f 100644 --- a/vllm/model_executor/models/cohere2_vision.py +++ b/vllm/model_executor/models/cohere2_vision.py @@ -170,6 +170,8 @@ class Cohere2VisionProcessingInfo(BaseProcessingInfo): # The current implementation of get_number_of_image_patches # is incorrect, so we patch it here. + # TODO: Revert once + # https://github.com/huggingface/transformers/pull/40312 is released. # return image_processor.get_number_of_image_patches(image_height, # image_width, {}) From 3663870c72da246d81d8bd8f5c059890fb3f3f5d Mon Sep 17 00:00:00 2001 From: Asaf Joseph Gardin <39553475+Josephasafg@users.noreply.github.com> Date: Thu, 21 Aug 2025 06:08:51 +0300 Subject: [PATCH 27/57] [V1][Mamba1] - Full CUDA and Piecewise CUDA Graphs Support (#23035) Signed-off-by: asafg Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com> Co-authored-by: asafg --- docs/usage/v1_guide.md | 2 +- .../models/language/generation/test_hybrid.py | 20 ++---- vllm/config/compilation.py | 1 + .../layers/mamba/mamba_mixer.py | 66 ++++++++++++++++--- vllm/model_executor/models/jamba.py | 8 ++- vllm/model_executor/models/mamba.py | 7 +- vllm/v1/attention/backends/mamba1_attn.py | 37 +++++------ vllm/v1/attention/backends/mamba2_attn.py | 45 ++----------- vllm/v1/attention/backends/mamba_attn.py | 55 ++++++++++++++++ 9 files changed, 154 insertions(+), 87 deletions(-) create mode 100644 vllm/v1/attention/backends/mamba_attn.py diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index 54af970ea842d..9bf0c5842c6be 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -107,7 +107,7 @@ to enable simultaneous generation and embedding using the same engine instance i #### Mamba Models Models using selective state-space mechanisms instead of standard transformer attention are supported. -Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1. Additionally, Mamba-1 models require `enforce_eager=True`. +Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1. Models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`, `Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`). Please note that diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index aee0a50336c09..f8c0eaa8cf3a2 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -54,16 +54,14 @@ V1_SUPPORTED_MODELS = [ "tiiuae/Falcon-H1-0.5B-Base", ] +FULL_CUDA_GRAPH_MODELS = [ + "ai21labs/Jamba-tiny-dev", + "Zyphra/Zamba2-1.2B-instruct", +] + # Avoid OOM MAX_NUM_SEQS = 4 -# Once we add support for FCG in Mamba1, this list will be removed and tests -# all test cases will use enforce_eager=False -ENFORCE_EAGER_MODELS_V1 = [ - "state-spaces/mamba-130m-hf", - "ai21labs/Jamba-tiny-dev", -] - @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) @pytest.mark.parametrize("max_tokens", [64]) @@ -101,19 +99,13 @@ def test_models( example_prompts, max_tokens, num_logprobs) if model in V1_SUPPORTED_MODELS: - enforce_eager = False with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") if model in HYBRID_MODELS: # required due to reorder_batch behaviour m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") - - if model in ENFORCE_EAGER_MODELS_V1: - enforce_eager = True - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS, - enforce_eager=enforce_eager, enable_prefix_caching=False) as vllm_model: vllm_v1_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) @@ -373,7 +365,7 @@ def test_distributed_correctness( ) -@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"]) +@pytest.mark.parametrize("model", FULL_CUDA_GRAPH_MODELS) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) def test_full_cuda_graph( diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 56a2183f8e2c1..c654485f4fe9c 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -336,6 +336,7 @@ class CompilationConfig: "vllm.unified_attention", "vllm.unified_attention_with_output", "vllm.mamba_mixer2", + "vllm.mamba_mixer", ] def compute_hash(self) -> str: diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 3c7322260df43..a24e72778b34b 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -27,6 +27,8 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update) from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata @@ -183,22 +185,26 @@ class MambaMixer(MambaBase, CustomOp): def forward(self, hidden_states: torch.Tensor, + output: torch.Tensor, mamba_cache_params: Optional[MambaCacheParams] = None): if not envs.VLLM_USE_V1: - return CustomOp.forward(self, hidden_states, mamba_cache_params) + CustomOp.forward(self, hidden_states, output, mamba_cache_params) else: - return self.forward_cuda( + torch.ops.vllm.mamba_mixer( hidden_states, - mamba_cache_params, + output, + self.prefix, ) def forward_native(self, hidden_states: torch.Tensor, + output: torch.Tensor, mamba_cache_params: Optional[MambaCacheParams] = None): pass def forward_cuda(self, hidden_states: torch.Tensor, + output: torch.Tensor, mamba_cache_params: Optional[MambaCacheParams] = None): """ Run the Mamba-1 SSM pipeline. @@ -237,6 +243,7 @@ class MambaMixer(MambaBase, CustomOp): conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] has_initial_states = mamba1_metadata.has_initial_states + num_padded_decodes = mamba1_metadata.num_padded_decodes else: assert isinstance(attn_metadata, AttentionMetadata) assert mamba_cache_params is not None @@ -248,6 +255,7 @@ class MambaMixer(MambaBase, CustomOp): has_initial_states = None if context_lens_tensor is not None: has_initial_states = context_lens_tensor > 0 + num_padded_decodes = attn_metadata.num_decode_tokens # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) @@ -267,6 +275,7 @@ class MambaMixer(MambaBase, CustomOp): num_decodes = attn_metadata.num_decode_tokens # token count (=request) has_prefill = num_prefill_tokens > 0 has_decode = num_decode_tokens > 0 + num_actual_tokens = num_prefill_tokens + num_decode_tokens prefill_decode_split = split_batch_to_prefill_and_decode( hidden_states_BC, @@ -278,6 +287,7 @@ class MambaMixer(MambaBase, CustomOp): num_decode_tokens, num_prefills, num_decodes, + num_padded_decodes, ) hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d @@ -371,7 +381,7 @@ class MambaMixer(MambaBase, CustomOp): else: out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0] - return out + output[:num_actual_tokens] = out def get_state_dtype(self) -> tuple[torch.dtype]: assert self.model_config is not None @@ -421,18 +431,27 @@ def split_batch_to_prefill_and_decode( num_decode_tokens: int, num_prefills: int, num_decodes: int, + num_padded_decodes: int, ) -> PrefillDecodeSplit: + num_actual_tokens = num_prefill_tokens + num_padded_decodes + if envs.VLLM_USE_V1: # In v1, decode tokens come first, then prefill tokens. hidden_states_BC_d, hidden_states_BC_p = torch.split( - hidden_states_BC, [num_decode_tokens, num_prefill_tokens], dim=-1) - gate_d, gate_p = torch.split(gate, - [num_decode_tokens, num_prefill_tokens], + hidden_states_BC[..., :num_actual_tokens], + [num_padded_decodes, num_prefill_tokens], + dim=-1) + gate_d, gate_p = torch.split(gate[..., :num_actual_tokens], + [num_padded_decodes, num_prefill_tokens], dim=-1) + + # num_padded_decodes accounts for CUDA graph padding when applicable state_indices_tensor_d, state_indices_tensor_p = torch.split( - state_indices_tensor, [num_decodes, num_prefills], dim=0) + state_indices_tensor[:num_padded_decodes + num_prefills], + [num_padded_decodes, num_prefills], + dim=0) query_start_loc_p = (query_start_loc[-num_prefills - 1:] - - num_decodes if num_prefills > 0 else None) + num_padded_decodes if num_prefills > 0 else None) has_initial_states_p = has_initial_states[-num_prefills:] if ( has_initial_states is not None and num_prefills > 0) else None else: @@ -459,3 +478,32 @@ def split_batch_to_prefill_and_decode( query_start_loc_p=query_start_loc_p, has_initial_states_p=has_initial_states_p, ) + + +def mamba_mixer( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self.forward_cuda(hidden_states=hidden_states, + output=output, + mamba_cache_params=None) + + +def mamba_mixer_fake( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="mamba_mixer", + op_func=mamba_mixer, + mutates_args=["output"], + fake_impl=mamba_mixer_fake, + dispatch_key=current_platform.dispatch_key, +) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 0b32d6f256590..3c1a0b68df56e 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -10,6 +10,7 @@ from transformers import JambaConfig from vllm import envs from vllm.attention.layer import Attention +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group @@ -154,10 +155,10 @@ class JambaMambaDecoderLayer(nn.Module): hidden_states, residual = self.input_layernorm( hidden_states, residual) - hidden_states = self.mamba(hidden_states, mamba_cache_params) + output = torch.empty_like(hidden_states) + self.mamba(hidden_states, output, mamba_cache_params) # Fully Connected - hidden_states, residual = self.pre_ff_layernorm( - hidden_states, residual) + hidden_states, residual = self.pre_ff_layernorm(output, residual) hidden_states = self.feed_forward(hidden_states) return hidden_states, residual @@ -278,6 +279,7 @@ ALL_DECODER_LAYER_TYPES = { } +@support_torch_compile class JambaModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index f4aaf0c6f467c..f02499a4f96b5 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -9,6 +9,7 @@ from torch import nn from transformers import MambaConfig from vllm import envs +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.layernorm import RMSNorm @@ -81,10 +82,12 @@ class MambaDecoderLayer(nn.Module): else: hidden_states, residual = self.norm(hidden_states, residual) - hidden_states = self.mixer(hidden_states, mamba_cache_params) - return hidden_states, residual + output = torch.empty_like(hidden_states) + self.mixer(hidden_states, output, mamba_cache_params) + return output, residual +@support_torch_compile class MambaModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index 6cdc509083ae9..97a1aa86dda0d 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -2,16 +2,16 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import Optional import torch from vllm.attention.backends.abstract import AttentionBackend -from vllm.config import VllmConfig -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata, +from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.v1.attention.backends.mamba_attn import ( + BaseMambaAttentionMetadataBuilder) +from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, split_decodes_and_prefills) -from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec class Mamba1AttentionBackend(AttentionBackend): @@ -31,24 +31,11 @@ class Mamba1AttentionMetadata: num_prefill_tokens: int num_decodes: int num_decode_tokens: int + num_padded_decodes: int class Mamba1AttentionMetadataBuilder( - AttentionMetadataBuilder[Mamba1AttentionMetadata]): - reorder_batch_threshold: ClassVar[int] = 1 - - def __init__( - self, - kv_cache_spec: AttentionSpec, - vllm_config: VllmConfig, - device: torch.device, - layer_names: list[str], - ): - assert isinstance(kv_cache_spec, MambaSpec) - self.kv_cache_spec = kv_cache_spec - self.device = device - self.vllm_config = vllm_config - self.layer_names = layer_names + BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]): def build( self, @@ -67,9 +54,18 @@ class Mamba1AttentionMetadataBuilder( decode_threshold=1)) has_initial_states = None + padded_decodes = num_decodes if num_prefills > 0: has_initial_states = context_lens_tensor > 0 + elif (num_decodes > 0 and num_decodes <= self.decode_cudagraph_max_bs + and self.compilation_config.full_cuda_graph): + state_indices_for_decode = state_indices_tensor[:num_decodes] + padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes) + self.state_indices_tensor[:num_decodes].copy_( + state_indices_for_decode, non_blocking=True) + state_indices_tensor = self.state_indices_tensor[:padded_decodes] + state_indices_tensor[num_decodes:] = PAD_SLOT_ID return Mamba1AttentionMetadata( query_start_loc=query_start_loc, @@ -80,4 +76,5 @@ class Mamba1AttentionMetadataBuilder( num_prefill_tokens=num_prefill_tokens, num_decodes=num_decodes, num_decode_tokens=num_decode_tokens, + num_padded_decodes=padded_decodes, ) diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index ace078e2b27c6..ed30884fdbc94 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -2,18 +2,18 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import Optional import torch from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata, +from vllm.v1.attention.backends.mamba_attn import ( + BaseMambaAttentionMetadataBuilder) +from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, split_decodes_and_prefills) -from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec +from vllm.v1.kv_cache_interface import AttentionSpec def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor, @@ -88,29 +88,14 @@ class Mamba2AttentionMetadata: class Mamba2AttentionMetadataBuilder( - AttentionMetadataBuilder[Mamba2AttentionMetadata]): - cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE - - reorder_batch_threshold: ClassVar[int] = 1 + BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]): def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): - assert isinstance(kv_cache_spec, MambaSpec) - self.kv_cache_spec = kv_cache_spec + super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.chunk_size = vllm_config.model_config.get_mamba_chunk_size() - self.vllm_config = vllm_config - self.compilation_config = vllm_config.compilation_config assert self.chunk_size is not None, ( "chunk_size needs to be set in the model config for Mamba2 models") - self.decode_cudagraph_max_bs = min( - self.vllm_config.scheduler_config.max_num_seqs, - self.compilation_config.max_capture_size) - self.state_indices_tensor = torch.empty( - (self.decode_cudagraph_max_bs, ), - dtype=torch.int32, - device=device, - ) def build(self, common_prefix_len: int, @@ -187,19 +172,3 @@ class Mamba2AttentionMetadataBuilder( state_indices_tensor=state_indices_tensor, ) return attn_metadata - - def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata): - """ - This method builds the metadata for full cudagraph capture. - Currently, only decode is supported for full cudagraphs with Mamba. - """ - m = common_attn_metadata - - assert m.num_reqs == m.num_actual_tokens, \ - "Mamba only supports decode-only full CUDAGraph capture. " \ - "Make sure all cudagraph capture sizes <= max_num_seq." - - m.max_query_len = 1 # decode-only - - return self.build(0, m) diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py new file mode 100644 index 0000000000000..07ef7cb69a160 --- /dev/null +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import abc +from typing import ClassVar, TypeVar + +import torch + +from vllm.config import VllmConfig +from vllm.v1.attention.backends.utils import (AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata) +from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec + +M = TypeVar("M") + + +class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): + reorder_batch_threshold: ClassVar[int] = 1 + cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], + vllm_config: VllmConfig, device: torch.device): + assert isinstance(kv_cache_spec, MambaSpec) + self.kv_cache_spec = kv_cache_spec + self.device = device + self.vllm_config = vllm_config + self.layer_names = layer_names + + self.compilation_config = vllm_config.compilation_config + self.decode_cudagraph_max_bs = min( + self.vllm_config.scheduler_config.max_num_seqs, + self.compilation_config.max_capture_size) + self.state_indices_tensor = torch.empty( + (self.decode_cudagraph_max_bs, ), + dtype=torch.int32, + device=device, + ) + + def build_for_cudagraph_capture( + self, common_attn_metadata: CommonAttentionMetadata) -> M: + """ + This method builds the metadata for full cudagraph capture. + Currently, only decode is supported for full cudagraphs with Mamba. + """ + m = common_attn_metadata + + assert m.num_reqs == m.num_actual_tokens, \ + "Mamba only supports decode-only full CUDAGraph capture. " \ + "Make sure all cudagraph capture sizes <= max_num_seq." + + m.max_query_len = 1 # decode-only + + return self.build(0, m) \ No newline at end of file From f94bf9b924afe2e720b864590c9798b911e77e66 Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Wed, 20 Aug 2025 23:09:39 -0400 Subject: [PATCH 28/57] [Compile] Fix Compile Warning SM100 Cutlass MLA (#23287) Signed-off-by: yewentao256 --- csrc/attention/mla/sm100_cutlass_mla_kernel.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu index e0e95d06290df..6dd6f269f3dc9 100644 --- a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu +++ b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu @@ -167,7 +167,7 @@ typename T::Fmha::Arguments args_from_options( // TODO(trevor-m): Change split_kv back to -1 when // https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will // perform worse with larger context length and smaller batch sizes. - num_kv_splits, // split_kv + static_cast(num_kv_splits), // split_kv nullptr, // is_var_split_kv }; // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute @@ -264,7 +264,7 @@ int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_ba // Assumes device 0 when getting sm_count. arguments.hw_info.sm_count = sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count; - arguments.split_kv = num_kv_splits; + arguments.split_kv = static_cast(num_kv_splits); MlaSm100Type::Fmha::set_split_kv(arguments); return MlaSm100Type::Fmha::get_workspace_size(arguments); From 655a09f6538e6b09af23771dcc4fcebd72a15b23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=A5=87=28yann=20qi=29?= <51905299+yannqi@users.noreply.github.com> Date: Thu, 21 Aug 2025 12:08:52 +0800 Subject: [PATCH 29/57] [Model][VLM] Support R-4B Model (#23246) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: yannqi Signed-off-by: ๆจๅฅ‡(yann qi) <51905299+yannqi@users.noreply.github.com> Signed-off-by: Cyrus Leung Co-authored-by: yannqiyang Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Cyrus Leung --- docs/models/supported_models.md | 1 + examples/offline_inference/vision_language.py | 23 ++++ .../vision_language_multi_image.py | 34 ++++++ .../multimodal/processing/test_common.py | 1 + tests/models/registry.py | 2 + vllm/model_executor/models/registry.py | 1 + vllm/model_executor/models/rvl.py | 103 ++++++++++++++++++ 7 files changed, 165 insertions(+) create mode 100644 vllm/model_executor/models/rvl.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 7308d0010690a..831bfb1e939e6 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -652,6 +652,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + IE+ + VE+ | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | โœ…๏ธŽ | | `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + IE+ + VE+ | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | โœ…๏ธŽ | โœ…๏ธŽ | โœ…๏ธŽ | | `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + IE+ + VE+ + A+ | `Qwen/Qwen2.5-Omni-7B` | | โœ…๏ธŽ | โœ…๏ธŽ | +| `RForConditionalGeneration` | R-VL-4B | T + IE+ | `YannQi/R-4B` | | โœ…๏ธŽ | โœ…๏ธŽ | | `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | โœ…๏ธŽ | โœ…๏ธŽ | | `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | โœ…๏ธŽ | | โœ…๏ธŽ | | `Step3VLForConditionalGeneration` | Step3-VL | T + I+ | `stepfun-ai/step3` | | โœ…๏ธŽ | โœ…๏ธŽ | diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 88bbbfdfbd188..e7a7a30dd31a6 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -1436,6 +1436,28 @@ def run_qwen2_5_omni(questions: list[str], modality: str): ) +# R-4B +def run_r_vl(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + model_name = "YannQi/R-4B" + + prompts = [ + f"<|im_start|>user \n{question}<|im_end|><|im_start|>assistant\n" + for question in questions + ] + + engine_args = EngineArgs( + model=model_name, + max_model_len=16384, + limit_mm_per_prompt={modality: 1}, + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # SkyworkR1V def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1622,6 +1644,7 @@ model_example_map = { "qwen2_vl": run_qwen2_vl, "qwen2_5_vl": run_qwen2_5_vl, "qwen2_5_omni": run_qwen2_5_omni, + "rvl": run_r_vl, "skywork_chat": run_skyworkr1v, "smolvlm": run_smolvlm, "step3": run_step3, diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index eabd9453f3c51..d9242efa85470 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -992,6 +992,39 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData: ) +def load_r_vl(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "YannQi/R-4B" + engine_args = EngineArgs( + model=model_name, + max_model_len=16384, + max_num_seqs=16, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + } + ] + + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=[fetch_image(url) for url in image_urls], + ) + + def load_smolvlm(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "HuggingFaceTB/SmolVLM2-2.2B-Instruct" @@ -1193,6 +1226,7 @@ model_example_map = { "qwen_vl_chat": load_qwen_vl_chat, "qwen2_vl": load_qwen2_vl, "qwen2_5_vl": load_qwen2_5_vl, + "rvl": load_r_vl, "smolvlm": load_smolvlm, "step3": load_step3, "tarsier": load_tarsier, diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 02aecfad8281d..adc8b2510d677 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -316,6 +316,7 @@ def _test_processing_correctness_one( "Qwen/Qwen2.5-VL-3B-Instruct", "Qwen/Qwen2-Audio-7B-Instruct", "Qwen/Qwen2.5-Omni-3B", + "YannQi/R-4B", "Skywork/Skywork-R1V-38B", "HuggingFaceTB/SmolVLM2-2.2B-Instruct", "stepfun-ai/step3", diff --git a/tests/models/registry.py b/tests/models/registry.py index 6e6acfb8cd228..4f69f90b6aae1 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -489,6 +489,8 @@ _MULTIMODAL_EXAMPLE_MODELS = { max_model_len=4096), "Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B"), "Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501 + "RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B", + trust_remote_code=True), "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B", trust_remote_code=True), "SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct", # noqa: E501 diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 78ef270598b8e..39a3e425a46df 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -217,6 +217,7 @@ _MULTIMODAL_MODELS = { "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"), "SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501 "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"), + "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"), "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501 "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"), "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), diff --git a/vllm/model_executor/models/rvl.py b/vllm/model_executor/models/rvl.py new file mode 100644 index 0000000000000..efdb010046634 --- /dev/null +++ b/vllm/model_executor/models/rvl.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Mapping + +import torch +import torch.nn as nn +from transformers.activations import GELUActivation + +from vllm.config import VllmConfig +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalDataDict + +from .llava_next import (LlavaDummyInputsBuilder, LlavaNextMultiModalProcessor, + LlavaNextProcessingInfo) +from .llava_onevision import LlavaOnevisionForConditionalGeneration +from .utils import WeightsMapper + + +class RVLProcessingInfo(LlavaNextProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(**kwargs) + + +class RVLDummyInputsBuilder(LlavaDummyInputsBuilder[RVLProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + image_token = "" + + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = ( + self.info.get_image_size_with_most_features()) + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + } + + +class RVLMultiModalProjector(nn.Module): + + def __init__(self, config): + super().__init__() + self.pre_norm = nn.LayerNorm(config.vision_config.hidden_size, + eps=1e-06) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size, + config.text_config.hidden_size, + bias=True, + ) + self.act = GELUActivation() + self.linear_2 = nn.Linear( + config.text_config.hidden_size, + config.text_config.hidden_size, + bias=True, + ) + + def forward(self, image_feature: torch.Tensor) -> torch.Tensor: + image_feature = self.pre_norm(image_feature) + hidden_states = self.linear_1(image_feature) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + + return hidden_states + + +@MULTIMODAL_REGISTRY.register_processor( + LlavaNextMultiModalProcessor, + info=RVLProcessingInfo, + dummy_inputs=RVLDummyInputsBuilder, +) +class RForConditionalGeneration(LlavaOnevisionForConditionalGeneration): + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # mapping for new names in checkpoint saved after transformers + # v4.52 + "model.language_model.": "language_model.model.", + "model.vision_tower.": "vision_tower.", + "model.multi_modal_projector.": "multi_modal_projector.", + "model.image_newline": "image_newline", + "lm_head.": "language_model.lm_head.", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__(vllm_config=vllm_config, prefix=prefix) + config = vllm_config.model_config.hf_config + self.multi_modal_projector = RVLMultiModalProjector(config) From 8993073dc1a7e2d31eda85812b76789046ae7c28 Mon Sep 17 00:00:00 2001 From: QiliangCui Date: Thu, 21 Aug 2025 04:15:20 +0000 Subject: [PATCH 30/57] [CI] Delete images older than 24h. (#23291) Signed-off-by: Qiliang Cui --- .buildkite/scripts/tpu/cleanup_docker.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/scripts/tpu/cleanup_docker.sh b/.buildkite/scripts/tpu/cleanup_docker.sh index 209d9c4341cdd..740d81fb39bb0 100755 --- a/.buildkite/scripts/tpu/cleanup_docker.sh +++ b/.buildkite/scripts/tpu/cleanup_docker.sh @@ -17,7 +17,7 @@ if [ "$disk_usage" -gt "$threshold" ]; then # Remove dangling images (those that are not tagged and not used by any container) docker image prune -f # Remove unused volumes / force the system prune for old images as well. - docker volume prune -f && docker system prune --force --filter "until=72h" --all + docker volume prune -f && docker system prune --force --filter "until=24h" --all echo "Docker images and volumes cleanup completed." else echo "Disk usage is below $threshold%. No cleanup needed." From f64ee61d9e7014a5f230a8347186b952dbe483de Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Thu, 21 Aug 2025 00:21:05 -0400 Subject: [PATCH 31/57] [CI] Block the cu126 wheel build while broken (#23285) Signed-off-by: mgoin --- .buildkite/release-pipeline.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index e20ce54ca795a..f96c38bf57db7 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -27,7 +27,12 @@ steps: env: DOCKER_BUILDKIT: "1" + - block: "Build CUDA 12.6 wheel" + key: block-build-cu126-wheel + depends_on: ~ + - label: "Build wheel - CUDA 12.6" + depends_on: block-build-cu126-wheel id: build-wheel-cuda-12-6 agents: queue: cpu_queue_postmerge From f571ff8eb6d9117c6a418f7f925921968dff8ac8 Mon Sep 17 00:00:00 2001 From: 22quinn <33176974+22quinn@users.noreply.github.com> Date: Wed, 20 Aug 2025 21:28:32 -0700 Subject: [PATCH 32/57] [Sampler] Support returning final logprobs (#22387) Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com> Co-authored-by: Nick Hill Co-authored-by: Woosuk Kwon --- docs/usage/v1_guide.md | 7 ++- tests/v1/sample/test_logprobs.py | 10 ++-- vllm/config/__init__.py | 30 ++++++---- vllm/engine/arg_utils.py | 1 + vllm/v1/sample/ops/topk_topp_sampler.py | 65 ++++++++++---------- vllm/v1/sample/sampler.py | 79 +++++++++++++++++++------ vllm/v1/sample/tpu/sampler.py | 2 +- 7 files changed, 125 insertions(+), 69 deletions(-) diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index 9bf0c5842c6be..b89768913681e 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -154,12 +154,15 @@ differences compared to V0: ##### Logprobs Calculation -Logprobs in V1 are now returned immediately once computed from the modelโ€™s raw output (i.e. +By default, logprobs in V1 are now returned immediately once computed from the modelโ€™s raw output (i.e. before applying any logits post-processing such as temperature scaling or penalty adjustments). As a result, the returned logprobs do not reflect the final adjusted probabilities used during sampling. -Support for logprobs with post-sampling adjustments is in progress and will be added in future updates. +You can adjust this behavior by setting the `--logprobs-mode` flag. +Four modes are supported: `raw_logprobs` (default), `processed_logprobs`, `raw_logits`, `processed_logits`. +Raw means the values before applying any logit processors, like bad words. +Processed means the values after applying all processors, including temperature and top_k/top_p. ##### Prompt Logprobs with Prefix Caching diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index 8bd142e87b06e..e835c029634ce 100644 --- a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -456,9 +456,7 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch): assert len(logprob) == vocab_size -@pytest.mark.parametrize( - "logprobs_mode", - ["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"]) +@pytest.mark.parametrize("logprobs_mode", list(LogprobsMode)) def test_logprobs_mode(logprobs_mode: LogprobsMode, monkeypatch: pytest.MonkeyPatch): """Test with LLM engine with different logprobs_mode. @@ -487,12 +485,14 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode, for logprobs in output.logprobs: for token_id in logprobs: logprob = logprobs[token_id] - if "logprobs" in logprobs_mode: + if logprobs_mode in (LogprobsMode.RAW_LOGPROBS, + LogprobsMode.PROCESSED_LOGPROBS): assert logprob.logprob <= 0 if logprob.logprob > 0: positive_values = positive_values + 1 total_token_with_logprobs = total_token_with_logprobs + 1 assert total_token_with_logprobs >= len(results[0].outputs) - if "logits" in logprobs_mode: + if logprobs_mode in (LogprobsMode.RAW_LOGITS, + LogprobsMode.PROCESSED_LOGITS): assert positive_values > 0 del llm diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 959f111ced22e..2973cb92d195b 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -257,11 +257,16 @@ def is_init_field(cls: ConfigType, name: str) -> bool: TokenizerMode = Literal["auto", "slow", "mistral", "custom"] ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] -LogprobsMode = Literal["raw_logprobs", "raw_logits", "processed_logprobs", - "processed_logits"] MMEncoderTPMode = Literal["weights", "data"] +class LogprobsMode(enum.Enum): + RAW_LOGITS = "raw_logits" + RAW_LOGPROBS = "raw_logprobs" + PROCESSED_LOGITS = "processed_logits" + PROCESSED_LOGPROBS = "processed_logprobs" + + @config @dataclass(config=ConfigDict(arbitrary_types_allowed=True)) class ModelConfig: @@ -363,12 +368,13 @@ class ModelConfig: specified in `SamplingParams`. The default value comes the default for the OpenAI Chat Completions API. -1 means no cap, i.e. all (output_length * vocab_size) logprobs are allowed to be returned and it may cause OOM.""" - logprobs_mode: LogprobsMode = "raw_logprobs" + logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS """Indicates the content returned in the logprobs and prompt_logprobs. Supported mode: 1) raw_logprobs, 2) processed_logprobs, 3) raw_logits, 4) processed_logits. - Raw means the values before applying logit processors, like bad words. - Processed means the values after applying such processors. + Raw means the values before applying any logit processors, like bad words. + Processed means the values after applying all processors, including + temperature and top_k/top_p. """ disable_sliding_window: bool = False """Whether to disable sliding window. If True, we will disable the sliding @@ -2586,7 +2592,7 @@ class MultiModalConfig: skip_mm_profiling: bool = False """ - When enabled, skips multimodal memory profiling and only profiles with + When enabled, skips multimodal memory profiling and only profiles with language backbone model during engine initialization. This reduces engine startup time but shifts the responsibility to users for @@ -2649,24 +2655,24 @@ class PoolerConfig: ## for embeddings models normalize: Optional[bool] = None """ - Whether to normalize the embeddings outputs. + Whether to normalize the embeddings outputs. """ dimensions: Optional[int] = None """ - Reduce the dimensions of embeddings if model + Reduce the dimensions of embeddings if model support matryoshka representation. """ ## for classification models activation: Optional[bool] = None """ - Whether to apply activation function to the classification outputs. + Whether to apply activation function to the classification outputs. """ ## for reward models softmax: Optional[bool] = None """ - Whether to apply softmax to the reward outputs. + Whether to apply softmax to the reward outputs. """ step_tag_id: Optional[int] = None """ @@ -2692,9 +2698,9 @@ class PoolerConfig: max_embed_len: Optional[int] = None """ - Maximum input length allowed for embedding generation. When set, allows + Maximum input length allowed for embedding generation. When set, allows inputs longer than max_embed_len to be accepted for embedding models. - This parameter enables accepting long inputs without requiring + This parameter enables accepting long inputs without requiring VLLM_ALLOW_LONG_MAX_MODEL_LEN environment variable. When an input exceeds max_embed_len, it will be handled according to the original max_model_len validation logic. Defaults to None (i.e. set to max_model_len). diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f3afc015f669c..b0f50b4429a82 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -516,6 +516,7 @@ class EngineArgs: model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"]) model_group.add_argument("--logprobs-mode", + choices=[f.value for f in LogprobsMode], **model_kwargs["logprobs_mode"]) model_group.add_argument("--disable-sliding-window", **model_kwargs["disable_sliding_window"]) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index e0434c8f3d713..7bd4a5a380ac0 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -8,6 +8,7 @@ import torch.nn as nn from packaging import version from vllm import envs +from vllm.config import LogprobsMode from vllm.logger import init_logger from vllm.platforms import current_platform @@ -28,9 +29,16 @@ class TopKTopPSampler(nn.Module): Implementations may update the logits tensor in-place. """ - def __init__(self): + def __init__( + self, + logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS) -> None: super().__init__() - if current_platform.is_cuda(): + self.logprobs_mode = logprobs_mode + # flashinfer optimization does not apply if intermediate + # logprobs/logits after top_k/top_p need to be returned + if logprobs_mode not in (LogprobsMode.PROCESSED_LOGITS, + LogprobsMode.PROCESSED_LOGPROBS + ) and current_platform.is_cuda(): if is_flashinfer_available: flashinfer_version = flashinfer.__version__ if version.parse(flashinfer_version) < version.parse("0.2.3"): @@ -63,10 +71,12 @@ class TopKTopPSampler(nn.Module): "native implementation of top-p & top-k sampling. For the " "best performance, please install FlashInfer.") self.forward = self.forward_native - elif current_platform.is_tpu(): - self.forward = self.forward_tpu else: self.forward = self.forward_native + if current_platform.is_tpu(): + self.apply_top_k_top_p = apply_top_k_top_p_tpu + else: + self.apply_top_k_top_p = apply_top_k_top_p def forward_native( self, @@ -74,15 +84,20 @@ class TopKTopPSampler(nn.Module): generators: dict[int, torch.Generator], k: Optional[torch.Tensor], p: Optional[torch.Tensor], - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """ PyTorch-native implementation of top-k and top-p sampling. The logits tensor may be updated in-place. """ - logits = apply_top_k_top_p(logits, k, p) + logits = self.apply_top_k_top_p(logits, k, p) + logits_to_return = None + if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS: + logits_to_return = logits + elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS: + logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32) probs = logits.softmax(dim=-1, dtype=torch.float32) - return random_sample(probs, generators) + return random_sample(probs, generators), logits_to_return def forward_cuda( self, @@ -90,34 +105,24 @@ class TopKTopPSampler(nn.Module): generators: dict[int, torch.Generator], k: Optional[torch.Tensor], p: Optional[torch.Tensor], - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """More optimized implementation for top-k and top-p sampling.""" - if k is None and p is None: - # We prefer `random_sample` over `flashinfer_sample` when sorting is - # not needed. This is because `random_sample` does not require - # CPU-GPU synchronization while `flashinfer_sample` does. - probs = logits.softmax(dim=-1, dtype=torch.float32) - return random_sample(probs, generators) - if generators: - logger.warning_once("FlashInfer 0.2.3+ does not support " - "per-request generators. Falling back to " - "PyTorch-native implementation.") + # We prefer `random_sample` over `flashinfer_sample` when sorting is + # not needed. This is because `random_sample` does not require + # CPU-GPU synchronization while `flashinfer_sample` does. + if (k is None and p is None) or generators: + if generators: + logger.warning_once("FlashInfer 0.2.3+ does not support " + "per-request generators. Falling back to " + "PyTorch-native implementation.") return self.forward_native(logits, generators, k, p) + assert self.logprobs_mode not in ( + LogprobsMode.PROCESSED_LOGITS, LogprobsMode.PROCESSED_LOGPROBS + ), "FlashInfer does not support returning logits/logprobs" # flashinfer sampling functions expect contiguous logits. # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous # because of slicing operation in logits_processor. - return flashinfer_sample(logits.contiguous(), k, p, generators) - - def forward_tpu( - self, - logits: torch.Tensor, - generators: dict[int, torch.Generator], - k: Optional[torch.Tensor], - p: Optional[torch.Tensor], - ) -> torch.Tensor: - logits = apply_top_k_top_p_tpu(logits, k, p) - probs = logits.softmax(dim=-1, dtype=torch.float32) - return random_sample(probs, generators) + return flashinfer_sample(logits.contiguous(), k, p, generators), None def apply_top_k_top_p_tpu( diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 82f51298f1b59..70ec8a0c26ddf 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A layer that samples the next tokens from the model's outputs.""" +from typing import Optional + import torch import torch.nn as nn @@ -18,10 +20,50 @@ _SAMPLING_EPS = 1e-5 class Sampler(nn.Module): + """ + A layer that samples the next tokens from the model's outputs + with the following steps in order: - def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"): + 1. If logprobs are requested: + a) If `logprobs_mode` is `raw_logprobs`, compute logprobs + as the final logprobs to return. + b) If `logprobs_mode` is `raw_logits`, clone the logits + as the final logprobs to return. + 2. Convert logits to float32. + 3. Apply allowed token ids whitelist. + 4. Apply bad words exclusion. + 5. Apply logit processors which are not argmax-invariant, + i.e. that can impact greedy sampling. + a) Min tokens processor + b) Logit bias processor + 6. Apply penalties + a) Repetition penalty + b) Frequency penalty + c) Presence penalty + 7. Sample the next tokens. `sample` method performs the following steps: + a) If not `all_random`, perform greedy sampling. If `all_greedy`, + return the greedily sampled tokens and final logprobs if requested. + b) Apply temperature. + c) Apply logit processors which are argmax-invariant, by default + the min_p processor. + d) Apply top_k and/or top_p. + e) Sample the next tokens with the probability distribution. + f) If `all_random` or temperature >= epsilon (1e-5), return the + randomly sampled tokens and final logprobs if requested. Else, + return the greedily sampled tokens and logprobs if requested. + 8. Gather the logprobs of the top `max_num_logprobs` and sampled token + (if requested). Note that if the sampled token is within the top + `max_num_logprobs`, the logprob will be eventually merged in + `LogprobsProcessor` during output processing. Therefore, the + final output may contain either `max_num_logprobs + 1` or + `max_num_logprobs` logprobs. + 9. Return the final `SamplerOutput`. + """ + + def __init__(self, + logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS): super().__init__() - self.topk_topp_sampler = TopKTopPSampler() + self.topk_topp_sampler = TopKTopPSampler(logprobs_mode) self.pin_memory = is_pin_memory_available() self.logprobs_mode = logprobs_mode @@ -34,13 +76,11 @@ class Sampler(nn.Module): # temperature scaling) for the top-k logprobs. # This is different from the V0 sampler, which uses the logits that # is used for sampling (after penalties and temperature scaling). - # TODO(rob): provide option for logprobs post sampling. - # See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501 num_logprobs = sampling_metadata.max_num_logprobs if num_logprobs is not None: - if self.logprobs_mode == "raw_logprobs": + if self.logprobs_mode == LogprobsMode.RAW_LOGPROBS: raw_logprobs = self.compute_logprobs(logits) - elif self.logprobs_mode == "raw_logits": + elif self.logprobs_mode == LogprobsMode.RAW_LOGITS: raw_logprobs = logits.clone() # Use float32 for the logits. @@ -57,15 +97,10 @@ class Sampler(nn.Module): # Apply penalties (e.g., min_tokens, freq_penalties). logits = self.apply_penalties(logits, sampling_metadata) - # Get the process logprobs or logits. - if num_logprobs is not None: - if self.logprobs_mode == "processed_logprobs": - raw_logprobs = self.compute_logprobs(logits) - elif self.logprobs_mode == "processed_logits": - raw_logprobs = logits.clone() - # Sample the next token. - sampled = self.sample(logits, sampling_metadata) + sampled, processed_logprobs = self.sample(logits, sampling_metadata) + if processed_logprobs is not None: + raw_logprobs = processed_logprobs # Convert sampled token ids to int64 (long) type to ensure compatibility # with subsequent operations that may use these values as indices. # This conversion is necessary because FlashInfer sampling operations @@ -105,7 +140,7 @@ class Sampler(nn.Module): self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """Sample logits based on sampling metadata. The various logits processing functions called in this method @@ -119,7 +154,13 @@ class Sampler(nn.Module): else: greedy_sampled = self.greedy_sample(logits) if sampling_metadata.all_greedy: - return greedy_sampled + processed_logprobs = None + if sampling_metadata.max_num_logprobs is not None: + if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS: + processed_logprobs = logits + elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS: + processed_logprobs = self.compute_logprobs(logits) + return greedy_sampled, processed_logprobs assert sampling_metadata.temperature is not None @@ -132,7 +173,7 @@ class Sampler(nn.Module): logits = processor.apply(logits) # Apply top_k and/or top_p. - random_sampled = self.topk_topp_sampler( + random_sampled, processed_logprobs = self.topk_topp_sampler( logits, sampling_metadata.generators, sampling_metadata.top_k, @@ -140,7 +181,7 @@ class Sampler(nn.Module): ) if greedy_sampled is None: - return random_sampled + return random_sampled, processed_logprobs sampled = torch.where( sampling_metadata.temperature < _SAMPLING_EPS, @@ -148,7 +189,7 @@ class Sampler(nn.Module): random_sampled, out=greedy_sampled, # Reuse tensor ) - return sampled + return sampled, processed_logprobs def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor: return logits.log_softmax(dim=-1, dtype=torch.float32) diff --git a/vllm/v1/sample/tpu/sampler.py b/vllm/v1/sample/tpu/sampler.py index 2c9f4892bc247..04545d587e4a9 100644 --- a/vllm/v1/sample/tpu/sampler.py +++ b/vllm/v1/sample/tpu/sampler.py @@ -65,7 +65,7 @@ class Sampler(nn.Module): logits = self.apply_min_p(logits, sampling_metadata.min_p) # Apply top_k and/or top_p. - random_sampled = self.topk_topp_sampler( + random_sampled, _ = self.topk_topp_sampler( logits, sampling_metadata.generators, sampling_metadata.top_k, From 0c31e28e9520d96c451cc7f023fd0f0af549766a Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 21 Aug 2025 13:03:00 +0800 Subject: [PATCH 33/57] [Bugfix] Fix extra whitespace in strings caused by newline (#23272) Signed-off-by: DarkLight1337 --- benchmarks/benchmark_dataset.py | 6 ++++-- examples/offline_inference/vision_language.py | 15 +++++++-------- vllm/benchmarks/datasets.py | 6 ++++-- vllm/model_executor/model_loader/tpu.py | 11 ++++++----- vllm/model_executor/models/hyperclovax_vision.py | 9 ++++----- vllm/model_executor/models/phi4mm.py | 6 +++--- vllm/transformers_utils/configs/eagle.py | 4 ++-- 7 files changed, 30 insertions(+), 27 deletions(-) diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index e1a856026c4ae..2ea4f9ccaff2b 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -958,8 +958,10 @@ class InstructCoderDataset(HuggingFaceDataset): for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break - prompt = f"{item['input']}\n\n{item['instruction']} Just output \ - the code, do not include any explanation." + prompt = ( + f"{item['input']}\n\n{item['instruction']} Just output " + "the code, do not include any explanation." + ) # apply template prompt = tokenizer.apply_chat_template( diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index e7a7a30dd31a6..8d97ba2668263 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -283,8 +283,10 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData: ) prompts = [ - f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\ - {question}<|assistant|>" + ( + "<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>" + f"{question}<|assistant|>" + ) for question in questions ] @@ -767,15 +769,13 @@ def run_llava_next_video(questions: list[str], modality: str) -> ModelRequestDat def run_llava_onevision(questions: list[str], modality: str) -> ModelRequestData: if modality == "video": prompts = [ - f"<|im_start|>user