From c55c02899853a31b1a8a0f48b2ca3ea9dae80586 Mon Sep 17 00:00:00 2001 From: Jiangyun Zhu Date: Sun, 24 Aug 2025 12:42:38 +0800 Subject: [PATCH 01/13] [gpt-oss] Streaming Output for Python Tool (#23409) Signed-off-by: zjy0516 --- vllm/entrypoints/openai/serving_responses.py | 70 ++++++++++++-------- 1 file changed, 44 insertions(+), 26 deletions(-) diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index 6b131bbb04d19..5adcb310e3468 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -1069,7 +1069,48 @@ class OpenAIServingResponses(OpenAIServing): delta=ctx.parser.last_content_delta, sequence_number=-1, )) - + # built-in tools will be triggered on the analysis channel + # However, occasionally built-in tools will + # still be output to commentary. + elif (ctx.parser.current_channel == "commentary" + or ctx.parser.current_channel == "analysis" + ) and ctx.parser.current_recipient == "python": + if not sent_output_item_added: + sent_output_item_added = True + yield _send_event( + openai_responses_types. + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=openai_responses_types. + ResponseCodeInterpreterToolCallParam( + type="code_interpreter_call", + id=current_item_id, + code=None, + container_id="auto", + outputs=None, + status="in_progress", + ), + )) + yield _send_event( + openai_responses_types. + ResponseCodeInterpreterCallInProgressEvent( + type= + "response.code_interpreter_call.in_progress", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + )) + yield _send_event( + openai_responses_types. 
+ ResponseCodeInterpreterCallCodeDeltaEvent( + type="response.code_interpreter_call_code.delta", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + delta=ctx.parser.last_content_delta, + )) if ctx.is_assistant_action_turn() and len(ctx.parser.messages) > 0: previous_item = ctx.parser.messages[-1] if (self.tool_server is not None @@ -1165,30 +1206,6 @@ class OpenAIServingResponses(OpenAIServing): and self.tool_server.has_tool("python") and previous_item.recipient is not None and previous_item.recipient.startswith("python")): - yield _send_event( - openai_responses_types.ResponseOutputItemAddedEvent( - type="response.output_item.added", - sequence_number=-1, - output_index=current_output_index, - item=openai_responses_types. - ResponseCodeInterpreterToolCallParam( - type="code_interpreter_call", - id=current_item_id, - code="", - container_id="auto", - outputs=[], - status="in_progress", - ), - )) - yield _send_event( - openai_responses_types. - ResponseCodeInterpreterCallInProgressEvent( - type="response.code_interpreter_call.in_progress", - sequence_number=-1, - output_index=current_output_index, - item_id=current_item_id, - )) - # TODO: do we need to add delta event here? yield _send_event( openai_responses_types. ResponseCodeInterpreterCallCodeDoneEvent( @@ -1196,7 +1213,8 @@ class OpenAIServingResponses(OpenAIServing): sequence_number=-1, output_index=current_output_index, item_id=current_item_id, - code=previous_item.content[0].text)) + code=previous_item.content[0].text, + )) yield _send_event( openai_responses_types. 
ResponseCodeInterpreterCallInterpretingEvent( From 053278a5dc7a81d751f8e63c1ed793062b32cbce Mon Sep 17 00:00:00 2001 From: Benji Beck Date: Sat, 23 Aug 2025 21:55:53 -0700 Subject: [PATCH 02/13] Migrate Pixtral inputs to TensorSchema (#23472) Signed-off-by: Benji Beck --- vllm/model_executor/models/pixtral.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index c01074e2122bb..461b9c85d1c22 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -5,7 +5,7 @@ import math from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass, fields from functools import cached_property -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import torch import torch.nn as nn @@ -48,6 +48,7 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import (MistralTokenizer, cached_tokenizer_from_config) +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix, @@ -68,15 +69,20 @@ except ImportError: PATCH_MERGE = "patch_merge" -class PixtralImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - - images: Union[torch.Tensor, list[torch.Tensor]] +class PixtralImagePixelInputs(TensorSchema): """ - Shape: `(batch_size * num_images, num_channels, image_width, image_height)` - + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height of each image + - w: Width of each image + The result of stacking `ImageEncoding.tokens` from each prompt. 
""" + type: Literal["pixel_values"] = "pixel_values" + + images: Annotated[Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"})] class PixtralProcessorAdapter: @@ -381,10 +387,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, if images is None: return None - if not isinstance(images, (torch.Tensor, list)): - raise ValueError("Incorrect type of images. " - f"Got type: {type(images)}") - return PixtralImagePixelInputs( type="pixel_values", images=flatten_bn(images), From 9dc30b7068ae07ceca89663e9f8403d00217256d Mon Sep 17 00:00:00 2001 From: 22quinn <33176974+22quinn@users.noreply.github.com> Date: Sat, 23 Aug 2025 21:56:17 -0700 Subject: [PATCH 03/13] [Bugfix] Add strong reference to CUDA pluggable allocator callbacks (#23477) Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com> Signed-off-by: youkaichao Co-authored-by: Eric Marcus Co-authored-by: youkaichao --- vllm/device_allocator/cumem.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 942e866ed97ee..7963fb15c4191 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -152,8 +152,13 @@ class CuMemAllocator: self.pointer_to_data: dict[int, AllocationData] = {} self.current_tag: str = CuMemAllocator.default_tag self.allocator_and_pools: dict[str, Any] = {} + # Creating strong references to the two callbacks here to prevent + # these ephemeral bound-method objects being garbage collected. 
+ # See discussions in https://github.com/vllm-project/vllm/pull/22724 + self.python_malloc_callback = self._python_malloc_callback + self.python_free_callback = self._python_free_callback - def python_malloc_callback(self, allocation_handle: HandleType) -> None: + def _python_malloc_callback(self, allocation_handle: HandleType) -> None: """ Internal method to store the allocation data when memory is allocated in the memory pool.""" @@ -162,7 +167,7 @@ class CuMemAllocator: allocation_handle, self.current_tag) return - def python_free_callback(self, ptr: int) -> HandleType: + def _python_free_callback(self, ptr: int) -> HandleType: """ Internal method to look up the allocation data when memory is freed in the memory pool.""" @@ -212,9 +217,9 @@ class CuMemAllocator: def wake_up(self, tags: Optional[list[str]] = None) -> None: """ Wake up the allocator from sleep mode. - All data that is previously offloaded will be loaded back to GPU + All data that is previously offloaded will be loaded back to GPU memory, and the rest of the data will have empty memory. - + :param tags: The tags of the memory allocation that will be loaded back to GPU memory. If None, all memory allocation will be loaded back to GPU memory. 
From a75277285ba6fa178c9cb9185fec7ec5943fff6b Mon Sep 17 00:00:00 2001 From: Benji Beck Date: Sat, 23 Aug 2025 21:56:56 -0700 Subject: [PATCH 04/13] Migrate Paligemma inputs to TensorSchema (#23470) Signed-off-by: Benji Beck --- vllm/model_executor/models/paligemma.py | 67 +++++++++++-------------- 1 file changed, 29 insertions(+), 38 deletions(-) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 7d6a6207c7c89..95abb190e0a46 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Mapping, Sequence -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import torch from torch import nn @@ -21,6 +21,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel @@ -32,19 +33,27 @@ from .vision import get_vision_encoder_info logger = init_logger(__name__) -class PaliGemmaImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - data: torch.Tensor - """Shape: `(batch_size * num_images, num_channels, height, width)`""" - - -class PaliGemmaImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: torch.Tensor - """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. 
+class PaliGemmaImagePixelInputs(TensorSchema): """ + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height + - w: Width + """ + type: Literal["pixel_values"] = "pixel_values" + data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] + + +class PaliGemmaImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - ifs: Image feature size + - hs: Hidden size (must match language model backbone) + """ + type: Literal["image_embeds"] = "image_embeds" + data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs, @@ -279,19 +288,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - actual_dims = tuple(data.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. " - f"You supplied {tuple(data.shape)}.") - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[PaliGemmaImageInputs]: pixel_values = kwargs.pop("pixel_values", None) @@ -301,22 +297,17 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, return None if pixel_values is not None: - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. 
" - f"Got type: {type(pixel_values)}") - pixel_values = flatten_bn(pixel_values, concat=True) - return PaliGemmaImagePixelInputs( - type="pixel_values", - data=self._validate_pixel_values(pixel_values), - ) + h = w = self.config.vision_config.image_size + return PaliGemmaImagePixelInputs(type="pixel_values", + data=pixel_values, + resolve_bindings={ + "h": h, + "w": w + }) if image_embeds is not None: - if not isinstance(image_embeds, torch.Tensor): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") - image_embeds = flatten_bn(image_embeds, concat=True) return PaliGemmaImageEmbeddingInputs( From e76e23354033f167b778f5e49fd384e301681d65 Mon Sep 17 00:00:00 2001 From: czhu-cohere Date: Sun, 24 Aug 2025 02:18:04 -0400 Subject: [PATCH 05/13] [kernel] Support W4A8 on Hopper (#23198) Signed-off-by: czhu-cohere --- CMakeLists.txt | 27 ++ benchmarks/kernels/benchmark_machete.py | 33 ++ benchmarks/kernels/weight_shapes.py | 6 + .../cutlass_w4a8/w4a8_mm_entry.cu | 418 ++++++++++++++++++ csrc/torch_bindings.cpp | 20 + .../kernels/quantization/test_cutlass_w4a8.py | 259 +++++++++++ vllm/_custom_ops.py | 48 ++ .../compressed_tensors/compressed_tensors.py | 43 +- .../compressed_tensors/schemes/__init__.py | 4 +- .../schemes/compressed_tensors_w4a8_fp8.py | 160 +++++++ .../kernels/mixed_precision/__init__.py | 3 + .../kernels/mixed_precision/cutlass.py | 114 +++++ 12 files changed, 1128 insertions(+), 7 deletions(-) create mode 100644 csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu create mode 100644 tests/kernels/quantization/test_cutlass_w4a8.py create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py create mode 100644 vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py diff --git a/CMakeLists.txt b/CMakeLists.txt index a1deefb07f09c..aca42c3fe5553 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -750,6 +750,33 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") 
"found in CUDA target architectures") endif() endif() + + # Only build W4A8 kernels if we are building for something compatible with sm90a + cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS) + set(SRCS + "csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu") + + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${W4A8_ARCHS}") + + list(APPEND VLLM_EXT_SRC "${SRCS}") + + message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 + AND W4A8_ARCHS) + message(STATUS "Not building W4A8 kernels as CUDA Compiler version is " + "not >= 12.0, we recommend upgrading to CUDA 12.0 or " + "later if you intend on running w4a16 quantized models on " + "Hopper.") + else() + message(STATUS "Not building W4A8 kernels as no compatible archs " + "found in CUDA target architectures") + endif() + endif() + # if CUDA endif endif() diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index a9c4d30d9b189..1b1c3b321cce4 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -284,6 +284,25 @@ def machete_create_bench_fn( ) +def cutlass_w4a8_create_bench_fn( + bt: BenchmarkTensors, out_type=torch.dtype, schedule=None +) -> Callable: + w_q = bt.w_q.t().contiguous().t() # make col major + w_q = ops.cutlass_encode_and_reorder_int4b(w_q) + # expects fp8 scales + w_s = ops.cutlass_pack_scale_fp8(bt.w_g_s.to(torch.float8_e4m3fn)) + + return lambda: ops.cutlass_w4a8_mm( + a=bt.a, + b_q=w_q, + b_group_scales=w_s, + b_group_size=bt.group_size, + b_channel_scales=bt.w_ch_s, + a_token_scales=bt.w_tok_s, + maybe_schedule=schedule, + ) + + # impl # bench @@ -385,6 +404,20 @@ def bench( ) ) + # cutlass w4a8 + if types.act_type == torch.float8_e4m3fn and group_size == 128: + timers.append( + bench_fns( + label, + sub_label, + f"cutlass w4a8 
({name_type_string})", + [ + cutlass_w4a8_create_bench_fn(bt, out_type=types.output_type) + for bt in benchmark_tensors + ], + ) + ) + if sweep_schedules: global _SWEEP_SCHEDULES_RESULTS diff --git a/benchmarks/kernels/weight_shapes.py b/benchmarks/kernels/weight_shapes.py index a27f02394afbd..9a057990bda5f 100644 --- a/benchmarks/kernels/weight_shapes.py +++ b/benchmarks/kernels/weight_shapes.py @@ -95,4 +95,10 @@ WEIGHT_SHAPES = { ([2048, 2816], 1), ([1408, 2048], 0), ], + "CohereLabs/c4ai-command-a-03-2025": [ + ([12288, 14336], 1), + ([12288, 12288], 0), + ([12288, 73728], 1), + ([36864, 12288], 0), + ], } diff --git a/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu b/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu new file mode 100644 index 0000000000000..fdac47c425d61 --- /dev/null +++ b/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu @@ -0,0 +1,418 @@ +// +// Based off of: +// https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu +// + +#include +#include +#include +#include "cutlass_extensions/torch_utils.hpp" + +#include "core/registration.h" + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/mixed_dtype_utils.hpp" + +#include "cutlass_extensions/common.hpp" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm::cutlass_w4a8 { + +using namespace cute; + +// ------------------------------------------------------------------------------------- +// Static configuration shared across all instantiations +// ------------------------------------------------------------------------------------- +using MmaType = cutlass::float_e4m3_t; // A/scale element type +using QuantType = 
cutlass::int4b_t; // B element type (packed int4) + +static int constexpr TileShapeK = 128 * 8 / sizeof_bits::value; +static int constexpr ScalePackSize = 8; // pack 8 scale elements together +static int constexpr PackFactor = 8; // 8 4-bit packed into int32 + +// A matrix configuration +using ElementA = MmaType; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +using LayoutA_Transpose = + typename cutlass::layout::LayoutTranspose::type; +constexpr int AlignmentA = + 128 / cutlass::sizeof_bits< + ElementA>::value; // Memory access granularity/alignment of A + // matrix in units of elements (up to 16 bytes) +using StrideA = cutlass::detail::TagToStrideA_t; + +// B matrix configuration +using ElementB = QuantType; // Element type for B matrix operand +using LayoutB = + cutlass::layout::ColumnMajor; // Layout type for B matrix operand +using LayoutB_Transpose = + typename cutlass::layout::LayoutTranspose::type; +constexpr int AlignmentB = + 128 / cutlass::sizeof_bits< + ElementB>::value; // Memory access granularity/alignment of B + // matrix in units of elements (up to 16 bytes) +using StrideB = cutlass::detail::TagToStrideB_t; + +// Define the CuTe layout for reordered quantized tensor B +// LayoutAtomQuant places values that will be read by the same thread in +// contiguous locations in global memory. 
It specifies the reordering within a +// single warp's fragment +using LayoutAtomQuant = + decltype(cutlass::compute_memory_reordering_atom()); +using LayoutB_Reordered = decltype(cute::tile_to_shape( + LayoutAtomQuant{}, Layout, StrideB>{})); + +// Group-wise scales +using ElementScale = MmaType; +using LayoutScale = cutlass::layout::RowMajor; + +// Per-tok, per-chan scales +using ElementSChannel = float; + +// C/D matrix configuration +using ElementC = + cutlass::bfloat16_t; // Element type for C and D matrix operands +using LayoutC = + cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = + 128 / cutlass::sizeof_bits< + ElementC>::value; // Memory access granularity/alignment of C + // matrix in units of elements (up to 16 bytes) + +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue computation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedCooperative; // Kernel to launch + // based on the default + // setting in the + // Collective Builder +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + +// ---------------------------------------------------------------------------- +// Kernel template — Tile/Cluster shapes +// ---------------------------------------------------------------------------- +template +struct W4A8GemmKernel { + using TileShape = + decltype(cute::append(TileShape_MN{}, cute::Int{})); + using ClusterShape = ClusterShape_MNK; + + // Epilogue per-tok, per-chan scales + 
using ChTokScalesEpilogue = + typename vllm::c3x::ScaledEpilogue; + using EVTCompute = typename ChTokScalesEpilogue::EVTCompute; + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, + ElementAccumulator, ElementSChannel, + // Transpose layout of D here since we use explicit swap + transpose + // the void type for C tells the builder to allocate 0 smem for the C + // matrix. We can enable this if beta == 0 by changing ElementC to + // void below. + ElementC, typename cutlass::layout::LayoutTranspose::type, + AlignmentC, ElementD, + typename cutlass::layout::LayoutTranspose::type, AlignmentD, + EpilogueSchedule, // This is the only epi supporting the required + // swap + transpose. + EVTCompute>::CollectiveOp; + + // The Scale information must get paired with the operand that will be scaled. + // In this example, B is scaled so we make a tuple of B's information and the + // scale information. 
+ using CollectiveMainloopShuffled = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple>, + LayoutB_Reordered, AlignmentB, ElementA, LayoutA_Transpose, + AlignmentA, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopShuffled, CollectiveEpilogue>; + using GemmShuffled = + cutlass::gemm::device::GemmUniversalAdapter; + + using StrideC = typename GemmKernelShuffled::StrideC; + using StrideD = typename GemmKernelShuffled::StrideD; + using StrideS = typename CollectiveMainloopShuffled::StrideScale; + + static torch::Tensor mm(torch::Tensor const& A, + torch::Tensor const& B, // already packed + torch::Tensor const& group_scales, // already packed + int64_t group_size, + torch::Tensor const& channel_scales, + torch::Tensor const& token_scales, + std::optional const& maybe_out_type) { + // TODO: param validation + int m = A.size(0); + int k = A.size(1); + int n = B.size(1); + + // Allocate output + const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); + auto device = A.device(); + auto stream = at::cuda::getCurrentCUDAStream(device.index()); + torch::Tensor D = + torch::empty({m, n}, torch::TensorOptions() + .dtype(equivalent_scalar_type_v) + .device(device)); + // prepare arg pointers + auto A_ptr = static_cast(A.const_data_ptr()); + auto B_ptr = static_cast(B.const_data_ptr()); + auto D_ptr = static_cast(D.data_ptr()); + // can we avoid hardcoding the 8 here + auto S_ptr = + static_cast const*>( + group_scales.const_data_ptr()); + + // runtime layout for B + auto shape_B = cute::make_shape(n, k, 1); + LayoutB_Reordered layout_B_reordered = + cute::tile_to_shape(LayoutAtomQuant{}, shape_B); + + // strides + int const scale_k = cutlass::ceil_div(k, group_size); + 
StrideA stride_A = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); + // Reverse stride here due to swap and transpose + StrideD stride_D = + cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(n, m, 1)); + StrideS stride_S = cutlass::make_cute_packed_stride( + StrideS{}, cute::make_shape(n, scale_k, 1)); + + // Create a structure of gemm kernel arguments suitable for invoking an + // instance of Gemm auto arguments = + // args_from_options(options); + /// Populates a Gemm::Arguments structure from the given arguments + /// Swap the A and B tensors, as well as problem shapes here. + using Args = typename GemmShuffled::Arguments; + using MainloopArguments = typename GemmKernelShuffled::MainloopArguments; + using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments; + + MainloopArguments mainloop_arguments{ + B_ptr, layout_B_reordered, A_ptr, stride_A, + S_ptr, stride_S, group_size}; + + EpilogueArguments epilogue_arguments{ + ChTokScalesEpilogue::prepare_args(channel_scales, token_scales), + nullptr, + {}, // no C + D_ptr, + stride_D}; + + Args arguments{cutlass::gemm::GemmUniversalMode::kGemm, + {n, m, k, 1}, // shape + mainloop_arguments, + epilogue_arguments}; + + // Workspace + size_t workspace_size = GemmShuffled::get_workspace_size(arguments); + torch::Tensor workspace = + torch::empty(workspace_size, + torch::TensorOptions().dtype(torch::kU8).device(device)); + + // Run GEMM + GemmShuffled gemm; + CUTLASS_CHECK(gemm.can_implement(arguments)); + CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream)); + CUTLASS_CHECK(gemm.run(stream)); + + return D; + } +}; + +// ---------------------------------------------------------------------------- +// Kernel instantiations and dispatch logic +// ---------------------------------------------------------------------------- +using Kernel_256x128_1x1x1 = + W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_256x64_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; 
+using Kernel_256x32_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_256x16_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x256_2x1x1 = + W4A8GemmKernel, Shape<_2, _1, _1>>; +using Kernel_128x256_1x1x1 = + W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x128_1x1x1 = + W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x64_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x32_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x16_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; + +torch::Tensor mm_dispatch(torch::Tensor const& A, + torch::Tensor const& B, // already packed + torch::Tensor const& group_scales, // already packed + int64_t group_size, + torch::Tensor const& channel_scales, + torch::Tensor const& token_scales, + std::optional const& maybe_out_type, + const std::string& schedule) { + if (schedule == "256x128_1x1x1") { + return Kernel_256x128_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "256x64_1x1x1") { + return Kernel_256x64_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "256x32_1x1x1") { + return Kernel_256x32_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "256x16_1x1x1") { + return Kernel_256x16_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x256_2x1x1") { + return Kernel_128x256_2x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x256_1x1x1") { + return Kernel_128x256_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x128_1x1x1") { + return Kernel_128x128_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == 
"128x64_1x1x1") { + return Kernel_128x64_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x32_1x1x1") { + return Kernel_128x32_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x16_1x1x1") { + return Kernel_128x16_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } + TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule); + return {}; +} + +torch::Tensor mm(torch::Tensor const& A, + torch::Tensor const& B, // already packed + torch::Tensor const& group_scales, // already packed + int64_t group_size, torch::Tensor const& channel_scales, + torch::Tensor const& token_scales, + std::optional const& maybe_out_type, + std::optional maybe_schedule) { + // requested a specific schedule + if (maybe_schedule) { + return mm_dispatch(A, B, group_scales, group_size, channel_scales, + token_scales, maybe_out_type, *maybe_schedule); + } + std::string schedule; + int M = A.size(0); + int K = A.size(1); + int N = B.size(1); + // heuristic + if (M <= 16) { + schedule = (K == 16384 && N == 18432) ? "256x16_1x1x1" : "128x16_1x1x1"; + } else if (M <= 32) { + schedule = (K == 16384 && N == 18432) ? 
"256x32_1x1x1" : "128x32_1x1x1"; + } else if (M <= 64) { + if (K == 16384 && N == 18432) + schedule = "256x64_1x1x1"; + else if (N <= 8192 && K <= 8192) + schedule = "128x32_1x1x1"; + else + schedule = "128x64_1x1x1"; + } else if (M <= 128) { + if (K == 16384 && N == 18432) + schedule = "256x128_1x1x1"; + else if (N <= 8192) + schedule = "128x64_1x1x1"; + else + schedule = "128x128_1x1x1"; + } else if (M <= 256) { + if (N <= 4096) + schedule = "128x64_1x1x1"; + else if (N <= 8192) + schedule = "128x128_1x1x1"; + else + schedule = "128x256_1x1x1"; + } else if (M <= 512 && N <= 4096) { + schedule = "128x128_1x1x1"; + } else if (M <= 1024) { + schedule = "128x256_1x1x1"; + } else { + schedule = "128x256_2x1x1"; + } + return mm_dispatch(A, B, group_scales, group_size, channel_scales, + token_scales, maybe_out_type, schedule); +} + +// ---------------------------------------------------------------------------- +// Pre-processing utils +// ---------------------------------------------------------------------------- +torch::Tensor pack_scale_fp8(torch::Tensor const& scales) { + TORCH_CHECK(scales.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(scales.is_contiguous()); + TORCH_CHECK(scales.is_cuda()); + + auto packed_scales = torch::empty( + {scales.numel() * ScalePackSize}, + torch::TensorOptions().dtype(scales.dtype()).device(scales.device())); + auto scales_ptr = static_cast(scales.const_data_ptr()); + auto packed_scales_ptr = + static_cast*>( + packed_scales.data_ptr()); + + cutlass::pack_scale_fp8(scales_ptr, packed_scales_ptr, scales.numel()); + + return packed_scales; +} + +torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) { + TORCH_CHECK(B.dtype() == torch::kInt32); + TORCH_CHECK(B.dim() == 2); + + torch::Tensor B_packed = torch::empty_like(B); + + int k = B.size(0) * PackFactor; // logical k + int n = B.size(1); + + auto B_ptr = static_cast(B.const_data_ptr()); + auto B_packed_ptr = static_cast(B_packed.data_ptr()); + auto shape_B = 
cute::make_shape(n, k, 1); + auto layout_B = make_layout(shape_B, LayoutRight{}); // row major + LayoutB_Reordered layout_B_reordered = + cute::tile_to_shape(LayoutAtomQuant{}, shape_B); + + cutlass::unified_encode_int4b(B_ptr, B_packed_ptr, n * k); + cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered); + + return B_packed; +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("cutlass_w4a8_mm", &mm); + m.impl("cutlass_pack_scale_fp8", &pack_scale_fp8); + m.impl("cutlass_encode_and_reorder_int4b", &encode_and_reorder_int4b); +} + +} // namespace vllm::cutlass_w4a8 \ No newline at end of file diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 4edb7af50f102..7ae054dc19fbd 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -309,6 +309,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, " "SymInt size_n, int num_bits) -> Tensor"); // conditionally compiled so impl registrations are in source file + + // CUTLASS w4a8 GEMM + ops.def( + "cutlass_w4a8_mm(" + " Tensor A," + " Tensor B," + " Tensor group_scales," + " int group_size," + " Tensor channel_scales," + " Tensor token_scales," + " ScalarType? out_type," + " str? maybe_schedule" + ") -> Tensor", + {stride_tag}); + // pack scales + ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor"); + // encode and reorder weight matrix + ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor"); + // conditionally compiled so impl registration is in source file + #endif // Dequantization for GGML. diff --git a/tests/kernels/quantization/test_cutlass_w4a8.py b/tests/kernels/quantization/test_cutlass_w4a8.py new file mode 100644 index 0000000000000..7832f8179d0ec --- /dev/null +++ b/tests/kernels/quantization/test_cutlass_w4a8.py @@ -0,0 +1,259 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for the CUTLASS W4A8 kernel. 
+ +Run `pytest tests/kernels/test_cutlass_w4a8.py`. +""" + +from dataclasses import dataclass +from typing import Optional + +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_rows, quantize_weights) +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types + +# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel +# unit tests to a common utility function. Currently the use of +# `is_quant_method_supported` conflates kernels with quantization methods +# an assumption which is breaking down as quantizations methods can have +# have kernels and some kernels support multiple quantization methods. +IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9 + +MNK_SHAPES = [(1, 128, 128), (1, 512, 1024), (1, 4096, 4096), (1, 8192, 28672), + (13, 8192, 4096), (26, 4096, 8192), (64, 4096, 4096), + (64, 8192, 28672), (257, 128, 4096), (257, 4096, 4096), + (1024, 4096, 8192), (1024, 8192, 4096)] + +# TODO(czhu): get supported schedules from fn +SCHEDULES = [ + '128x16_1x1x1', '256x16_1x1x1', '128x32_1x1x1', '256x32_1x1x1', + '128x64_1x1x1', '256x64_1x1x1', '128x128_1x1x1', '256x128_1x1x1', + '128x256_1x1x1', '128x256_2x1x1' +] + + +@dataclass +class TypeConfig: + act_type: torch.dtype + weight_type: ScalarType + output_type: Optional[torch.dtype] + group_scale_type: Optional[torch.dtype] + channel_scale_type: Optional[torch.dtype] + token_scale_type: Optional[torch.dtype] + + +@dataclass +class Tensors: + w_ref: torch.Tensor + a_ref: torch.Tensor + a: torch.Tensor + w_q: torch.Tensor + w_g_s: torch.Tensor + w_ch_s: torch.Tensor + w_tok_s: torch.Tensor + + +# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints, +# Ch Scales Type, Tok Scales Type) +TestTypeTuple = tuple[list[torch.dtype], ScalarType, Optional[torch.dtype], + Optional[torch.dtype], bool] +TEST_TYPES = [ + *( + 
TypeConfig(act_type=torch.float8_e4m3fn, + weight_type=w_type, + output_type=o_type, + group_scale_type=torch.float8_e4m3fn, + channel_scale_type=torch.float32, + token_scale_type=torch.float32) + for w_type in [scalar_types.int4] + # TODO(czhu): fp16 out type + for o_type in [torch.bfloat16]), +] + +# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel +# unit tests to a common utility function. Currently the use of +# `is_quant_method_supported` conflates kernels with quantization methods +# an assumption which is breaking down as quantizations methods can have +# have kernels and some kernels support multiple quantization methods. +IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90) + + +# For testing quantized linear kernels +def to_fp8(tensor: torch.Tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return tensor.clamp(min=finfo.min, + max=finfo.max).to(dtype=torch.float8_e4m3fn) + + +def cutlass_quantize_and_pack(atype: torch.dtype, + w: torch.Tensor, + wtype: ScalarType, + stype: Optional[torch.dtype], + group_size: Optional[int], + zero_points: bool = False): + assert wtype.is_integer(), "TODO: support floating point weights" + + w_ref, w_q, w_s, w_zp = quantize_weights(w, + wtype, + group_size=group_size, + zero_points=zero_points) + + # since scales are cast to fp8, we need to compute w_ref this way + w_ref = ((w_q).to(torch.float32) * w_s.to(atype).to( + torch.float32).repeat_interleave(group_size, dim=0)).to(atype) + + # bit mask prevents sign extending int4 when packing + w_q = pack_rows(w_q & 0x0F, wtype.size_bits, *w_q.shape) + w_q = w_q.t().contiguous().t() # convert to col major + + w_q_packed = ops.cutlass_encode_and_reorder_int4b(w_q) + w_s_packed = ops.cutlass_pack_scale_fp8(w_s.to(atype)) + + return w_ref, w_q_packed, w_s_packed, w_zp + + +def create_test_tensors(shape: tuple[int, int, int], types: TypeConfig, + group_size: Optional[int]) -> Tensors: + m, n, k = shape + + print("create_test_tensors, 
shape:", shape, "types:", types, "group_size:", + group_size) + + a = to_fp8(torch.randn((m, k), device="cuda")) + w = to_fp8(torch.randn((k, n), device="cuda")) + + if types.group_scale_type is not None: + w = w.to(types.group_scale_type) + if w.dtype.itemsize == 1: + w = w.to(torch.float16) + + w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack( + a.dtype, w, types.weight_type, types.group_scale_type, group_size, + False) + + a_ref = a.to(torch.float32) + w_ref = w_ref.to(torch.float32) + + # for the practical use case we need per-tok scales for fp8 activations + w_tok_s = torch.randn((m, ), device='cuda', dtype=types.token_scale_type) + # weights are already per-group quantized, use placeholder here + w_ch_s = torch.ones((n, ), device='cuda', dtype=types.channel_scale_type) + + return Tensors(w_ref=w_ref, + a_ref=a_ref, + a=a, + w_q=w_q_packed, + w_g_s=w_s, + w_ch_s=w_ch_s, + w_tok_s=w_tok_s) + + +def mm_test_helper(types: TypeConfig, + tensors: Tensors, + group_size: Optional[int] = None, + schedule: Optional[str] = None): + # CUTLASS upstream uses fp8 with fastaccum as reference + # https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L406 + output_ref = torch._scaled_mm( + tensors.a_ref.to(types.act_type), + tensors.w_ref.to(types.act_type).t().contiguous().t(), # col major + tensors.w_tok_s.unsqueeze(1), + tensors.w_ch_s.unsqueeze(0), + out_dtype=types.output_type, + use_fast_accum=True) + + output = ops.cutlass_w4a8_mm( + a=tensors.a, + b_q=tensors.w_q, + b_group_scales=tensors.w_g_s, + b_group_size=group_size, + b_channel_scales=tensors.w_ch_s, + a_token_scales=tensors.w_tok_s, + ) + + print(output) + print(output_ref) + + torch.testing.assert_close(output, + output_ref.to(output.dtype), + rtol=1e-3, + atol=1e-3) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="CUTLASS W4A8 is not supported on this GPU type.") +@pytest.mark.parametrize("shape", + MNK_SHAPES, + ids=lambda x: "x".join(str(v) 
for v in x)) +@pytest.mark.parametrize("types", TEST_TYPES) +@pytest.mark.parametrize("schedule", SCHEDULES) +def test_cutlass_w4a8(shape, types: TypeConfig, schedule): + group_sizes = [128] + for group_size in group_sizes: + tensors = create_test_tensors(shape, types, group_size) + mm_test_helper(types, tensors, group_size, schedule) + + +# Test to make sure cuda graphs work +class W4A8Layer(torch.nn.Module): + + def __init__(self, **kwargs): + super().__init__() + self.kwargs = kwargs + + def forward(self, a): + return ops.cutlass_w4a8_mm(a=a, **self.kwargs) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="CUTLASS W4A8 is not supported on this GPU type.") +def test_w4a8_cuda_graph(): + m, n, k = 512, 4096, 4096 + + a = to_fp8(torch.randn((m, k), device="cuda")) + b = to_fp8(torch.randn((k, n), device="cuda")) + + wtype = scalar_types.int4 + stype = torch.float8_e4m3fn + group_size = 128 + zero_points = False + + w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack( + a.dtype, b.to(torch.float16), wtype, stype, group_size, zero_points) + + w_tok_s = torch.randn((m, ), device='cuda', dtype=torch.float32) + w_ch_s = torch.ones((n, ), device='cuda', dtype=torch.float32) + + # Construct a trivial model with a single layer that calls the kernel + model = W4A8Layer( + b_q=w_q_packed, + b_group_scales=w_s, + b_group_size=group_size, + b_channel_scales=w_ch_s, + a_token_scales=w_tok_s, + ) + + output_ref = torch._scaled_mm( + a, + w_ref.to(a.dtype).t().contiguous().t(), # col major + w_tok_s.unsqueeze(1), + w_ch_s.unsqueeze(0), + out_dtype=torch.bfloat16, + use_fast_accum=True) + + # Run the model with a cuda graph + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + output = model(a) + + output.zero_() + g.replay() + + torch.testing.assert_close(output, output_ref, rtol=1e-3, atol=1e-3) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 0043456e0009a..3e3b43ce2abe3 100644 --- 
a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -474,6 +474,30 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): return torch.empty_like(b_q_weight, memory_format=torch.contiguous_format) + @register_fake("_C::cutlass_w4a8_mm") + def cutlass_w4a8_mm_fake( + a: torch.Tensor, + # b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b + b_q: torch.Tensor, + b_group_scales: torch.Tensor, + b_group_size: int, + b_channel_scales: torch.Tensor, + a_token_scales: torch.Tensor, + out_type: Optional[torch.dtype] = None, + maybe_schedule: Optional[str] = None) -> torch.Tensor: + m = a.size(0) + n = b_q.size(1) + out_dtype = out_type if out_type is not None else torch.bfloat16 + return torch.empty((m, n), device=a.device, dtype=out_dtype) + + @register_fake("_C::cutlass_pack_scale_fp8") + def cutlass_pack_scale_fp8_fake(scales: torch.Tensor) -> torch.Tensor: + return torch.empty_like(scales, memory_format=torch.contiguous_format) + + @register_fake("_C::cutlass_encode_and_reorder_int4b") + def cutlass_encode_and_reorder_int4b_fake(b: torch.Tensor) -> torch.Tensor: + return torch.empty_like(b, memory_format=torch.contiguous_format) + if hasattr(torch.ops._C, "allspark_w8a16_gemm"): @@ -1032,6 +1056,30 @@ def machete_prepack_B( group_scales_type) +# CUTLASS W4A8 +def cutlass_w4a8_mm( + a: torch.Tensor, + # b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b + b_q: torch.Tensor, + b_group_scales: torch.Tensor, + b_group_size: int, + b_channel_scales: torch.Tensor, + a_token_scales: torch.Tensor, + out_type: Optional[torch.dtype] = None, + maybe_schedule: Optional[str] = None) -> torch.Tensor: + return torch.ops._C.cutlass_w4a8_mm(a, b_q, b_group_scales, b_group_size, + b_channel_scales, a_token_scales, + out_type, maybe_schedule) + + +def cutlass_pack_scale_fp8(scales: torch.Tensor) -> torch.Tensor: + return torch.ops._C.cutlass_pack_scale_fp8(scales) + + +def cutlass_encode_and_reorder_int4b(b: torch.Tensor) -> torch.Tensor: + return 
torch.ops._C.cutlass_encode_and_reorder_int4b(b) + + if hasattr(torch.ops._C, "permute_cols"): @register_fake("_C::permute_cols") diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 637a84372990a..ce74375aab426 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -26,10 +26,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24, CompressedTensorsScheme, CompressedTensorsW4A4Fp4, - CompressedTensorsW4A8Int, CompressedTensorsW4A16Fp4, - CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, - CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, - CompressedTensorsWNA16) + CompressedTensorsW4A8Fp8, CompressedTensorsW4A8Int, + CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24, + CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, + CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( find_matched_target, is_activation_quantization_format, should_ignore_layer) @@ -200,8 +200,10 @@ class CompressedTensorsConfig(QuantizationConfig): format ) if format is not None else is_activation_quantization_format( quant_format) - if act_quant_format: - input_activations = quant_config.get("input_activations") + # TODO(czhu): w4a8fp8 is in packed-quantized format + # but needs input activation quantization + input_activations = quant_config.get("input_activations") + if act_quant_format or input_activations: # The only case where we have activation quant supported # but no input_activations provided in the config # should be w8a16fp8 w8a16fp8 can also run 
for cases where @@ -352,6 +354,28 @@ class CompressedTensorsConfig(QuantizationConfig): input_quant.strategy == QuantizationStrategy.TENSOR) return is_symmetric_activation and is_per_tensor_activation + def _is_fp8_w4a8(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + if not weight_quant or not input_quant: + return False + is_weight_4_bits = weight_quant.num_bits == 4 + is_activation_8_bits = input_quant.num_bits == 8 + weight_strategy = ( + weight_quant.strategy == QuantizationStrategy.GROUP.value) + is_token = (weight_strategy and input_quant.strategy + == QuantizationStrategy.TOKEN.value) + is_dynamic = not weight_quant.dynamic and input_quant.dynamic + is_symmetric = weight_quant.symmetric and input_quant.symmetric + # Only per-group symmetric weight (4bit) + # + per-tok symmetric activation (8bit) quantization supported. + return (is_weight_4_bits and is_activation_8_bits and is_token + and is_symmetric and is_dynamic) + + def _is_fp8_w4a8_sm90(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + return (self._check_scheme_supported(90, error=False, match_exact=True) + and self._is_fp8_w4a8(weight_quant, input_quant)) + def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: return (self._check_scheme_supported(90, error=False, match_exact=True) @@ -405,6 +429,13 @@ class CompressedTensorsConfig(QuantizationConfig): if self._is_fp4a16_nvfp4(weight_quant, input_quant): return CompressedTensorsW4A16Fp4() + if self._is_fp8_w4a8_sm90(weight_quant, input_quant): + return CompressedTensorsW4A8Fp8(num_bits=weight_quant.num_bits, + strategy=weight_quant.strategy, + symmetric=weight_quant.symmetric, + group_size=weight_quant.group_size, + actorder=weight_quant.actorder) + if self._is_wNa16_group_channel(weight_quant, input_quant): if (self.quant_format == CompressionFormat.marlin_24.value and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS): diff --git 
a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py index 734fa603ba7b9..cac65cca5093f 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -3,6 +3,7 @@ from .compressed_tensors_scheme import CompressedTensorsScheme from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4 +from .compressed_tensors_w4a8_fp8 import CompressedTensorsW4A8Fp8 from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS, CompressedTensorsW4A16Sparse24) @@ -21,5 +22,6 @@ __all__ = [ "CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8", "WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS", "CompressedTensors24", "CompressedTensorsW4A16Fp4", - "CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int" + "CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int", + "CompressedTensorsW4A8Fp8" ] diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py new file mode 100644 index 0000000000000..f6cc49c2316ba --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py @@ -0,0 +1,160 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Callable, Optional + +import torch +from compressed_tensors.quantization import ActivationOrdering + +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( + MPLinearLayerConfig, 
choose_mp_linear_kernel) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_repeat_scales_on_all_ranks) +# yapf conflicts with isort for this block +# yapf: disable +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter) +# yapf: enable +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + +__all__ = ["CompressedTensorsW4A8Fp8"] +W4A8_SUPPORTED_TYPES_MAP = { + 4: scalar_types.int4, +} +W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys()) + + +class CompressedTensorsW4A8Fp8(CompressedTensorsScheme): + _kernel_backends_being_used: set[str] = set() + + def __init__(self, + strategy: str, + num_bits: int, + group_size: Optional[int] = None, + symmetric: Optional[bool] = True, + actorder: Optional[ActivationOrdering] = None): + + self.pack_factor = 32 // num_bits + self.strategy = strategy + self.symmetric = symmetric + self.group_size = -1 if group_size is None else group_size + self.has_g_idx = actorder == ActivationOrdering.GROUP + + if self.group_size != 128 or self.strategy != "group": + raise ValueError("W4A8 kernels require group quantization " \ + "with group size 128") + + if num_bits not in W4A8_SUPPORTED_TYPES_MAP: + raise ValueError( + f"Unsupported num_bits = {num_bits}. 
" + f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}") + + self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits] + + @classmethod + def get_min_capability(cls) -> int: + # hopper + return 90 + + def create_weights(self, layer: torch.nn.Module, output_size: int, + input_size: int, output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + output_size_per_partition = sum(output_partition_sizes) + + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_type, + act_type=torch.float8_e4m3fn, # always use fp8(e4m3) + group_size=self.group_size, + zero_points=not self.symmetric, + has_g_idx=self.has_g_idx + ) + + kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for CompressedTensorsW4A8Fp8", + kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + + # If group_size is -1, we are in channelwise case. + group_size = self.group_size if self.group_size != -1 else input_size + row_parallel = (input_size != input_size_per_partition) + partition_scales = not marlin_repeat_scales_on_all_ranks( + self.has_g_idx, self.group_size, row_parallel) + + scales_and_zp_size = input_size // group_size + + if partition_scales: + assert input_size_per_partition % group_size == 0 + scales_and_zp_size = input_size_per_partition // group_size + + weight = PackedvLLMParameter(input_dim=1, + output_dim=0, + weight_loader=weight_loader, + packed_factor=self.pack_factor, + packed_dim=1, + data=torch.empty( + output_size_per_partition, + input_size_per_partition // + self.pack_factor, + dtype=torch.int32, + )) + + # TODO(czhu): allocate the packed fp8 scales memory here? 
+ # the scales will be expanded by 8x via `cutlass_pack_scale_fp8` + weight_scale_args = { + "weight_loader": + weight_loader, + "data": + torch.empty( + output_size_per_partition, + scales_and_zp_size, + dtype=params_dtype, + ) + } + + if not partition_scales: + weight_scale = ChannelQuantScaleParameter(output_dim=0, + **weight_scale_args) + else: + weight_scale = GroupQuantScaleParameter(output_dim=0, + input_dim=1, + **weight_scale_args) + + # A 2D array defining the original shape of the weights + # before packing + weight_shape = BasevLLMParameter(data=torch.empty(2, + dtype=torch.int64), + weight_loader=weight_loader) + + layer.register_parameter("weight_packed", weight) + layer.register_parameter("weight_scale", weight_scale) + layer.register_parameter("weight_shape", weight_shape) + + self.kernel = kernel_type(mp_linear_kernel_config, + w_q_param_name="weight_packed", + w_s_param_name="weight_scale", + w_zp_param_name="weight_zero_point", + w_gidx_param_name="weight_g_idx") + + # Checkpoints are serialized in compressed-tensors format, which is + # different from the format the kernel may want. Handle repacking here. 
+ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index a5084f6ee92cd..4bcfcd04b3d8b 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -10,6 +10,8 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas imp BitBLASLinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import ( # noqa: E501 ConchLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.cutlass import ( # noqa: E501 + CutlassW4A8LinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.dynamic_4bit import ( # noqa: E501 Dynamic4bitLinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501 @@ -24,6 +26,7 @@ from vllm.platforms import current_platform # in priority/performance order (when available) _POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [ + CutlassW4A8LinearKernel, MacheteLinearKernel, AllSparkLinearKernel, MarlinLinearKernel, diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py new file mode 100644 index 0000000000000..f1d49693fc016 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch 
+ +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class CutlassW4A8LinearKernel(MPLinearKernel): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # dynamic per-tok fp8 activation quantization + self.quant_fp8 = QuantFP8(static=False, + group_shape=GroupShape.PER_TOKEN) + + @classmethod + def get_min_capability(cls) -> int: + return 90 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + if not current_platform.is_cuda(): + return False, "CUTLASS only supported on CUDA" + + if not current_platform.is_device_capability(90): + return False, "CUTLASS W4A8 requires compute capability of 90 "\ + "(Hopper)" + + if c.act_type != torch.float8_e4m3fn: + return False, "CUTLASS W4A8 only supports FP8 (e4m3) activations" + + if c.has_g_idx: + return False, "Act reordering not supported by CUTLASS W4A8" + + if c.zero_points: + return False, "Zero points not supported by CUTLASS W4A8" + + if c.weight_type != scalar_types.int4: + return False, f"Quant type ({c.weight_type}) not supported by "\ + "CUTLASS W4A8, only supported int4" + + # TODO(czhu): support -1 (column-wise) + if c.group_size != 128: + return False, "Only group_size 128 is supported" + + in_features, out_features = c.partition_weight_shape + if in_features % 128 or out_features % 128: + return False, "K and N must be divisible by 128, got "\ + f"{c.partition_weight_shape}" + return True, None + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + 
def process_weights_after_loading(self, layer: torch.nn.Module): + c = self.config + + # TODO(czhu): optimize speed/mem usage + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + x.data = ops.cutlass_encode_and_reorder_int4b( + x.data.t().contiguous().t()) + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = x.data.contiguous().to(torch.float8_e4m3fn) + x.data = ops.cutlass_pack_scale_fp8(x.data) + return x + + # Encode/reorder weights and pack scales + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + # TODO(czhu): support loading channel scales + self.w_ch_s = torch.ones((c.partition_weight_shape[1], ), + dtype=torch.float32, + device='cuda') + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + assert bias is None, "bias not supported by CUTLASS W4A8" + c = self.config + w_q, w_s, _, _ = self._get_weight_params(layer) + + x_2d = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + + x_2d, act_scales = self.quant_fp8(x_2d) + output = ops.cutlass_w4a8_mm(a=x_2d, + b_q=w_q, + b_group_scales=w_s, + b_group_size=c.group_size, + a_token_scales=act_scales, + b_channel_scales=self.w_ch_s) + + return output.reshape(out_shape) From 1b9b16649c10453fe25ff28313dffa175194a84b Mon Sep 17 00:00:00 2001 From: "rongfu.leng" Date: Sun, 24 Aug 2025 16:06:34 +0800 Subject: [PATCH 06/13] [Misc] update dict parse to EPLBConfig from json dumps to dict unpacking (#23305) Signed-off-by: rongfu.leng --- vllm/config/parallel.py | 9 +-------- vllm/engine/arg_utils.py | 3 +-- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index f7b8b1d0a5658..9ea883d4a03cd 100644 --- 
a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -6,7 +6,7 @@ from dataclasses import field from typing import TYPE_CHECKING, Any, Literal, Optional, Union import torch -from pydantic import TypeAdapter, model_validator +from pydantic import model_validator from pydantic.dataclasses import dataclass from torch.distributed import ProcessGroup, ReduceOp from typing_extensions import Self @@ -56,13 +56,6 @@ class EPLBConfig: This is turned off by default since it will cause communication overhead. """ - @classmethod - def from_cli(cls, cli_value: str) -> "EPLBConfig": - """Parse the CLI value for the compilation config. - -O1, -O2, -O3, etc. is handled in FlexibleArgumentParser. - """ - return TypeAdapter(EPLBConfig).validate_json(cli_value) - @config @dataclass diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 965264ee3097a..3ab1115f14462 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -455,8 +455,7 @@ class EngineArgs: self.compilation_config = CompilationConfig( **self.compilation_config) if isinstance(self.eplb_config, dict): - self.eplb_config = EPLBConfig.from_cli(json.dumps( - self.eplb_config)) + self.eplb_config = EPLBConfig(**self.eplb_config) # Setup plugins from vllm.plugins import load_general_plugins load_general_plugins() From 5e021b4981c14724b57744df182d876e6a07e4b9 Mon Sep 17 00:00:00 2001 From: TeeKen Lau <13831887+teekenl@users.noreply.github.com> Date: Sun, 24 Aug 2025 20:12:47 +1000 Subject: [PATCH 07/13] (Misc): add missing test for zero truncation size. 
(#23457) Signed-off-by: teekenl --- tests/entrypoints/openai/test_truncation.py | 22 +++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/entrypoints/openai/test_truncation.py b/tests/entrypoints/openai/test_truncation.py index 18ddc493c9283..121c0413e1af7 100644 --- a/tests/entrypoints/openai/test_truncation.py +++ b/tests/entrypoints/openai/test_truncation.py @@ -64,6 +64,28 @@ async def test_smaller_truncation_size(client: openai.AsyncOpenAI): assert response["usage"]["prompt_tokens"] == truncation_size +@pytest.mark.asyncio +async def test_zero_truncation_size(client: openai.AsyncOpenAI): + truncation_size = 0 + kwargs: dict[str, Any] = { + "model": MODEL_NAME, + "input": input, + "truncate_prompt_tokens": truncation_size + } + + with pytest.raises(openai.BadRequestError) as err: + await client.post(path="embeddings", cast_to=object, body={**kwargs}) + + assert err.value.status_code == 400 + error_details = err.value.response.json()["error"] + + assert error_details["type"] == "BadRequestError" + assert "This model's maximum context length is" in error_details["message"] + assert "tokens in the input for embedding generation" in error_details[ + "message"] + assert "Please reduce the length of the input" in error_details["message"] + + @pytest.mark.asyncio async def test_bigger_truncation_size(client: openai.AsyncOpenAI): truncation_size = max_model_len + 1 From 416f05929ac66f5ae364936b70087fc60cacee4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=AA=E5=BF=97=E9=B9=8F?= Date: Sun, 24 Aug 2025 20:52:24 +0800 Subject: [PATCH 08/13] [New Model]Donut model (#23229) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 汪志鹏 --- docs/models/supported_models.md | 1 + examples/offline_inference/dolphin.py | 311 ++++++++++++ .../encoder_decoder_multimodal.py | 46 ++ .../multimodal/processing/test_common.py | 2 + tests/models/registry.py | 3 + vllm/engine/llm_engine.py | 2 +- 
vllm/model_executor/models/donut.py | 398 +++++++++++++++ vllm/model_executor/models/registry.py | 1 + vllm/model_executor/models/swin.py | 475 ++++++++++++++++++ vllm/multimodal/profiling.py | 2 +- vllm/v1/engine/processor.py | 2 +- 11 files changed, 1240 insertions(+), 3 deletions(-) create mode 100644 examples/offline_inference/dolphin.py create mode 100644 vllm/model_executor/models/donut.py create mode 100644 vllm/model_executor/models/swin.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 297d98142b5f2..3159d3bd1c819 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -615,6 +615,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ | ✅︎ | | `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I+ | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | ✅︎ | | `DeepseekVLV2ForCausalLM`^ | DeepSeek-VL2 | T + I+ | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ | +| `DonutForConditionalGeneration`^ | Donut | T + I | `ByteDance/Dolphin`, `naver-clova-ix/donut-base-finetuned-docvqa`, etc. | | | | | `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | | | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ | | `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. 
| ✅︎ | ✅︎ | ⚠️ | diff --git a/examples/offline_inference/dolphin.py b/examples/offline_inference/dolphin.py new file mode 100644 index 0000000000000..d2ba27cd1e027 --- /dev/null +++ b/examples/offline_inference/dolphin.py @@ -0,0 +1,311 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import copy +import os +from dataclasses import dataclass + +import cv2 +import numpy as np +import regex as re +from PIL import Image +from transformers import DonutProcessor + +from vllm import LLM, SamplingParams +from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt +from vllm.multimodal.utils import fetch_image + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +@dataclass +class ImageDimensions: + original_w: int + original_h: int + padded_w: int + padded_h: int + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def map_to_original_coordinates( + x1, y1, x2, y2, dims: ImageDimensions +) -> tuple[int, int, int, int]: + try: + top = (dims.padded_h - dims.original_h) // 2 + left = (dims.padded_w - dims.original_w) // 2 + orig_x1 = max(0, x1 - left) + orig_y1 = max(0, y1 - top) + orig_x2 = min(dims.original_w, x2 - left) + orig_y2 = min(dims.original_h, y2 - top) + if orig_x2 <= orig_x1: + orig_x2 = min(orig_x1 + 1, dims.original_w) + if orig_y2 <= orig_y1: + orig_y2 = min(orig_y1 + 1, dims.original_h) + return int(orig_x1), int(orig_y1), int(orig_x2), int(orig_y2) + except Exception as e: + print(f"map_to_original_coordinates error: {str(e)}") + return 0, 0, min(100, dims.original_w), min(100, dims.original_h) + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def adjust_box_edges(image, boxes: list[list[float]], max_pixels=15, threshold=0.2): + if isinstance(image, str): + image = cv2.imread(image) + img_h, img_w = image.shape[:2] + new_boxes = [] + for box in boxes: + best_box = copy.deepcopy(box) + + def 
check_edge(img, current_box, i, is_vertical): + edge = current_box[i] + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + _, binary = cv2.threshold( + gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU + ) + if is_vertical: + line = binary[current_box[1] : current_box[3] + 1, edge] + else: + line = binary[edge, current_box[0] : current_box[2] + 1] + transitions = np.abs(np.diff(line)) + return np.sum(transitions) / len(transitions) + + edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)] + current_box = copy.deepcopy(box) + current_box[0] = min(max(current_box[0], 0), img_w - 1) + current_box[1] = min(max(current_box[1], 0), img_h - 1) + current_box[2] = min(max(current_box[2], 0), img_w - 1) + current_box[3] = min(max(current_box[3], 0), img_h - 1) + + for i, direction, is_vertical in edges: + best_score = check_edge(image, current_box, i, is_vertical) + if best_score <= threshold: + continue + for step in range(max_pixels): + current_box[i] += direction + if i == 0 or i == 2: + current_box[i] = min(max(current_box[i], 0), img_w - 1) + else: + current_box[i] = min(max(current_box[i], 0), img_h - 1) + score = check_edge(image, current_box, i, is_vertical) + if score < best_score: + best_score = score + best_box = copy.deepcopy(current_box) + if score <= threshold: + break + new_boxes.append(best_box) + return new_boxes + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None): + try: + x1, y1 = int(coords[0] * dims.padded_w), int(coords[1] * dims.padded_h) + x2, y2 = int(coords[2] * dims.padded_w), int(coords[3] * dims.padded_h) + x1, y1, x2, y2 = ( + max(0, min(x1, dims.padded_w - 1)), + max(0, min(y1, dims.padded_h - 1)), + max(0, min(x2, dims.padded_w)), + max(0, min(y2, dims.padded_h)), + ) + if x2 <= x1: + x2 = min(x1 + 1, dims.padded_w) + if y2 <= y1: + y2 = min(y1 + 1, dims.padded_h) + new_boxes = adjust_box_edges(padded_image, [[x1, y1, x2, 
y2]]) + x1, y1, x2, y2 = new_boxes[0] + x1, y1, x2, y2 = ( + max(0, min(x1, dims.padded_w - 1)), + max(0, min(y1, dims.padded_h - 1)), + max(0, min(x2, dims.padded_w)), + max(0, min(y2, dims.padded_h)), + ) + if x2 <= x1: + x2 = min(x1 + 1, dims.padded_w) + if y2 <= y1: + y2 = min(y1 + 1, dims.padded_h) + if previous_box is not None: + prev_x1, prev_y1, prev_x2, prev_y2 = previous_box + if (x1 < prev_x2 and x2 > prev_x1) and (y1 < prev_y2 and y2 > prev_y1): + y1 = prev_y2 + y1 = min(y1, dims.padded_h - 1) + if y2 <= y1: + y2 = min(y1 + 1, dims.padded_h) + new_previous_box = [x1, y1, x2, y2] + orig_x1, orig_y1, orig_x2, orig_y2 = map_to_original_coordinates( + x1, y1, x2, y2, dims + ) + return x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box + except Exception as e: + print(f"process_coordinates error: {str(e)}") + orig_x1, orig_y1, orig_x2, orig_y2 = ( + 0, + 0, + min(100, dims.original_w), + min(100, dims.original_h), + ) + return 0, 0, 100, 100, orig_x1, orig_y1, orig_x2, orig_y2, [0, 0, 100, 100] + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def prepare_image(image) -> tuple[np.ndarray, ImageDimensions]: + try: + image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) + original_h, original_w = image_cv.shape[:2] + max_size = max(original_h, original_w) + top = (max_size - original_h) // 2 + bottom = max_size - original_h - top + left = (max_size - original_w) // 2 + right = max_size - original_w - left + padded_image = cv2.copyMakeBorder( + image_cv, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0) + ) + padded_h, padded_w = padded_image.shape[:2] + dimensions = ImageDimensions( + original_w=original_w, + original_h=original_h, + padded_w=padded_w, + padded_h=padded_h, + ) + return padded_image, dimensions + except Exception as e: + print(f"prepare_image error: {str(e)}") + h, w = image.height, image.width + dimensions = ImageDimensions(original_w=w, original_h=h, padded_w=w, padded_h=h) + return 
np.zeros((h, w, 3), dtype=np.uint8), dimensions + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def parse_layout_string(bbox_str): + """Parse layout string using regular expressions""" + pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)" + matches = re.finditer(pattern, bbox_str) + + parsed_results = [] + for match in matches: + coords = [float(match.group(i)) for i in range(1, 5)] + label = match.group(5).strip() + parsed_results.append((coords, label)) + + return parsed_results + + +model_id = "ByteDance/Dolphin" + +# The input image size for Dolphin is 896 x 896, +# and the patch_size is 4 x 4. +# Therefore, the initial number of patches is: +# Height: 896 / 4 = 224 patches +# Width: 896 / 4 = 224 patches + +# The Dolphin model uses a staged downsampling approach, +# defined by the "depths": [2, 2, 14, 2] configuration. +# Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed, +# which halves the feature map's dimensions (dividing both height and width by 2). +# Before Stage 2: The size changes from 224 x 224 to (224/2) x (224/2) = 112 x 112. +# Before Stage 3: The size changes from 112 x 112 to (112/2) x (112/2) = 56 x 56. +# Before Stage 4: The size changes from 56 x 56 to (56/2) x (56/2) = 28 x 28. + +# Because vLLM needs to fill the image features with an encoder_prompt, +# and the encoder_prompt will have `<pad>` tokens added when tokenized, +# we need to construct an encoder_prompt with a length of 28 x 28 - 1 = 783. +encoder_prompt = "".join(["0"] * 783) +sampling_params = SamplingParams( + temperature=0.0, + max_tokens=2048, +) + +processor = DonutProcessor.from_pretrained(model_id) +llm = LLM( + model=model_id, + dtype="float16", + max_num_seqs=8, + hf_overrides={"architectures": ["DonutForConditionalGeneration"]}, +) + +parser = argparse.ArgumentParser() +parser.add_argument( + "--image_path", type=str, default=None, help="Path to a local image file." 
+) +args = parser.parse_args() + +if args.image_path: + if not os.path.exists(args.image_path): + raise FileNotFoundError(f"Error: File not found at {args.image_path}") + image = Image.open(args.image_path).convert("RGB") +else: + image = fetch_image( + "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg" + ) + + +prompt = "Parse the reading order of this document. " +decoder_prompt = f"{prompt}" +decoder_prompt_tokens = TokensPrompt( + prompt_token_ids=processor.tokenizer(decoder_prompt, add_special_tokens=False)[ + "input_ids" + ] +) +enc_dec_prompt = ExplicitEncoderDecoderPrompt( + encoder_prompt=TextPrompt(prompt=encoder_prompt, multi_modal_data={"image": image}), + decoder_prompt=decoder_prompt_tokens, +) +layout_outputs = llm.generate(prompts=enc_dec_prompt, sampling_params=sampling_params) +layout_result_str = layout_outputs[0].outputs[0].text +print(f"Layout analysis output:\n{layout_result_str}") + +padded_image, dims = prepare_image(image) +layout_results = parse_layout_string(layout_result_str) +text_table_elements = [] +previous_box = None +reading_order = 0 +for bbox_coords, label in layout_results: + if label == "fig": + continue + try: + x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = ( + process_coordinates(bbox_coords, padded_image, dims, previous_box) + ) + cropped = padded_image[y1:y2, x1:x2] + if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3: + pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)) + prompt_ocr = ( + "Parse the table in the image. " + if label == "tab" + else "Read text in the image. 
" + ) + text_table_elements.append( + { + "crop": pil_crop, + "prompt": prompt_ocr, + "reading_order": reading_order, + } + ) + reading_order += 1 + except Exception as e: + print(f"Error processing bbox (label: {label}): {str(e)}") + continue + +if text_table_elements: + batch_prompts = [] + for elem in text_table_elements: + decoder_prompt_str = f"{elem['prompt']}" + decoder_prompt_tokens = TokensPrompt( + prompt_token_ids=processor.tokenizer( + decoder_prompt_str, add_special_tokens=False + )["input_ids"] + ) + enc_dec_prompt = ExplicitEncoderDecoderPrompt( + encoder_prompt=TextPrompt( + prompt=encoder_prompt, multi_modal_data={"image": elem["crop"]} + ), + decoder_prompt=decoder_prompt_tokens, + ) + batch_prompts.append(enc_dec_prompt) + batch_outputs = llm.generate(prompts=batch_prompts, sampling_params=sampling_params) + for i, output in enumerate(batch_outputs): + text_table_elements[i]["text"] = output.outputs[0].text.strip() + +print("------" * 8) +text_table_elements.sort(key=lambda x: x["reading_order"]) +for elem in text_table_elements: + print(elem.get("text", "")) diff --git a/examples/offline_inference/encoder_decoder_multimodal.py b/examples/offline_inference/encoder_decoder_multimodal.py index d27a902edb7e7..655f9f3fce7ae 100644 --- a/examples/offline_inference/encoder_decoder_multimodal.py +++ b/examples/offline_inference/encoder_decoder_multimodal.py @@ -13,6 +13,7 @@ from typing import NamedTuple from vllm import LLM, EngineArgs, PromptType, SamplingParams from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset +from vllm.multimodal.utils import fetch_image from vllm.utils import FlexibleArgumentParser @@ -21,6 +22,50 @@ class ModelRequestData(NamedTuple): prompts: Sequence[PromptType] +def run_donut(): + engine_args = EngineArgs( + model="naver-clova-ix/donut-base-finetuned-docvqa", + max_num_seqs=2, + limit_mm_per_prompt={"image": 1}, + dtype="float16", + hf_overrides={"architectures": 
["DonutForConditionalGeneration"]}, + ) + + # The input image size for donut-base-finetuned-docvqa is 2560 x 1920, + # and the patch_size is 4 x 4. + # Therefore, the initial number of patches is: + # Height: 1920 / 4 = 480 patches + # Width: 2560 / 4 = 640 patches + # The Swin model uses a staged downsampling approach, + # defined by the "depths": [2, 2, 14, 2] configuration. + # Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed, + # which halves the feature map's dimensions (dividing both height and width by 2). + # Before Stage 2: The size changes from 480 x 640 to (480/2) x (640/2) = 240 x 320. + # Before Stage 3: The size changes from 240 x 320 to (240/2) x (320/2) = 120 x 160. + # Before Stage 4: The size changes from 120 x 160 to (120/2) x (160/2) = 60 x 80. + # Because vLLM needs to fill the image features with an encoder_prompt, + # and the encoder_prompt will have `<pad>` tokens added when tokenized, + # we need to construct an encoder_prompt with a length of 60 x 80 - 1 = 4799. 
+ prompts = [ + { + "encoder_prompt": { + "prompt": "".join(["$"] * 4799), + "multi_modal_data": { + "image": fetch_image( + "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg" + ) # noqa: E501 + }, + }, + "decoder_prompt": "What time is the coffee break?", # noqa: E501 + }, + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + def run_florence2(): engine_args = EngineArgs( model="microsoft/Florence-2-large", @@ -118,6 +163,7 @@ def run_whisper(): model_example_map = { + "donut": run_donut, "florence2": run_florence2, "mllama": run_mllama, "whisper": run_whisper, diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index adc8b2510d677..a604d11f0e769 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -160,6 +160,7 @@ def _test_processing_correctness( # incorrect token ids. So we need use `add_special_tokens=False` here # to leave bos_token to be added by the processor. 
_ADD_SPECIAL_TOKENS_OVERRIDES = { + "donut": False, "mllama": False, "ovis": False, "ovis2_5": False, @@ -270,6 +271,7 @@ def _test_processing_correctness_one( "facebook/chameleon-7b", "CohereLabs/command-a-vision-07-2025", "deepseek-ai/deepseek-vl2-tiny", + "naver-clova-ix/donut-base-finetuned-docvqa", "microsoft/Florence-2-base", "adept/fuyu-8b", "google/gemma-3-4b-it", diff --git a/tests/models/registry.py b/tests/models/registry.py index 25dbbd7fa9832..b34c6f2e5dc84 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -513,6 +513,9 @@ _MULTIMODAL_EXAMPLE_MODELS = { is_available_online=False, ), # [Encoder-decoder] + "DonutForConditionalGeneration": _HfExamplesInfo("naver-clova-ix/donut-base-finetuned-docvqa", # noqa: E501 + hf_overrides={"architectures": ["DonutForConditionalGeneration"], "model_type": "donut"}, # noqa: E501 + extras={"dolphin": "ByteDance/Dolphin"}), # noqa: E501 # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer # Therefore, we borrow the BartTokenizer from the original Bart model "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501 diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index bbe958351e87c..dbf8d3ba50146 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1822,7 +1822,7 @@ class LLMEngine: assert isinstance(mm_processor, EncDecMultiModalProcessor) if mm_processor.pad_dummy_encoder_prompt: - return # Skip encoder length check for Whisper + return # Skip encoder length check for Whisper and Donut if model_config.is_multimodal_model: suggestion = ( diff --git a/vllm/model_executor/models/donut.py b/vllm/model_executor/models/donut.py new file mode 100644 index 0000000000000..b1f6a0af6b3de --- /dev/null +++ b/vllm/model_executor/models/donut.py @@ -0,0 +1,398 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +from collections.abc 
import Iterable, Mapping, Sequence +from typing import Literal, Optional, TypedDict, Union + +import torch +import torch.nn as nn +from transformers import BatchFeature, NougatProcessor + +from vllm.config import VllmConfig +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.bart import BartParallelLMHead, MBartDecoder +from vllm.model_executor.models.interfaces import (MultiModalEmbeddings, + SupportsMultiModal, + SupportsV0Only) +from vllm.model_executor.models.swin import SwinModel +from vllm.model_executor.models.utils import (AutoWeightsLoader, + _flatten_embeddings, flatten_bn) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargsItems) +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import (BaseProcessingInfo, + EncDecMultiModalProcessor, + PromptIndexTargets, PromptInsertion, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder + + +class MBartDecoderWrapper(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.decoder = MBartDecoder(config, + cache_config, + quant_config=quant_config, + prefix=f"{prefix}.decoder") + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +class DonutLanguageForConditionalGeneration(nn.Module, SupportsV0Only): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + + self.config = config + self.model = MBartDecoderWrapper(vllm_config=vllm_config, + prefix=f"{prefix}.model") + 
embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 + + self.vocab_size = config.vocab_size + self.lm_head = BartParallelLMHead(self.vocab_size, + config.d_model, + embed_scale=embed_scale) + + self.logits_processor = LogitsProcessor(self.vocab_size, + config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + r""" + Args: + input_ids + torch.Tensor of *decoder* input token ids. + positions + torch.Tensor of *decoder* position indices. + Returns: + Output torch.Tensor + """ + + return self.model(decoder_input_ids=input_ids, + decoder_positions=positions, + encoder_hidden_states=inputs_embeds) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if "final_logits_bias" in name: + continue + # if self.config.tie_word_embeddings and "embed_tokens" in name: + # continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class DonutImagePixelInputs(TypedDict): + type: 
Literal["pixel_values"] + data: torch.Tensor + """Shape: (batch_size, num_channel, height, width)""" + + +class DonutProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_hf_processor(self): + return self.ctx.get_hf_processor() + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_num_image_tokens(self) -> int: + return 1 + + +class DonutDummyInputsBuilder(BaseDummyInputsBuilder[DonutProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = self.info.get_hf_config( + ).encoder.image_size + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + +class DonutMultiModalProcessor(EncDecMultiModalProcessor[DonutProcessingInfo]): + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + return False + + def create_encoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + return prompt + + def create_decoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + return prompt + + @property + def pad_dummy_encoder_prompt(self) -> bool: + return True + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + hf_processor = self.info.get_hf_processor() + if mm_data: + processed_outputs = super()._call_hf_processor( + prompt, mm_data, mm_kwargs, tok_kwargs) + if isinstance(hf_processor, NougatProcessor): 
+ processed_outputs["input_ids"] = processed_outputs["labels"] + else: + tokenizer = hf_processor.tokenizer + processed_outputs = tokenizer(prompt, + add_special_tokens=False, + return_tensors="pt") + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor() + tokenizer = hf_processor.tokenizer + pad_token_id = tokenizer.pad_token_id + num_image_tokens = self.info.get_num_image_tokens() + image_tokens = [pad_token_id] * num_image_tokens + + return [ + PromptInsertion( + modality="image", + target=PromptIndexTargets.start(), + insertion=image_tokens, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor(DonutMultiModalProcessor, + info=DonutProcessingInfo, + dummy_inputs=DonutDummyInputsBuilder) +class DonutForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsV0Only): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + processor_config = vllm_config.model_config.hf_image_processor_config + + self.config = config + self.vision_config = config.encoder + self.processor_config = processor_config + self.encoder = SwinModel(config=config.encoder) + + self.decoder = DonutLanguageForConditionalGeneration( + vllm_config=vllm_config.with_hf_config(config.decoder), + prefix=f"{prefix}.decoder", + ) + self.pad_token_id = config.pad_token_id + + def _validate_pixel_values( + self, data: Union[torch.Tensor, list[torch.Tensor]] + ) -> Union[torch.Tensor, list[torch.Tensor]]: + + # size = self.processor_config["size"] + h, w = self.config.encoder.image_size + 
expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape) + + if actual_dims != expected_dims: + raise ValueError( + "The expected shape of pixel values per batch " + f"is {expected_dims}. You supplied {actual_dims}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input(self, **kwargs: object): + pixel_values: Optional[Union[list[list[torch.Tensor]], + list[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "pixel_values", None) + image_embeds: Optional[Union[list[list[torch.Tensor]], + list[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None and image_embeds is not None: + raise ValueError( + "Both pixel values and image embeds are provided.") + + if pixel_values is not None: + return DonutImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values( + flatten_bn(pixel_values, concat=True)), + ) + + if image_embeds is not None: + raise NotImplementedError + + raise AssertionError("This line should be unreachable.") + + def _process_image_input( + self, image_input: DonutImagePixelInputs) -> torch.Tensor: + assert image_input["type"] == "pixel_values" + pixel_values = image_input["data"] + dtype = next(self.encoder.parameters()).dtype + pixel_values = pixel_values.to(dtype) + return self.encoder(pixel_values) + + def get_language_model(self) -> torch.nn.Module: + return self.decoder + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings, + ) -> torch.Tensor: + return _flatten_embeddings(multimodal_embeddings) + + def 
forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + *, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + r""" + Args: + input_ids + torch.Tensor of *decoder* input token ids. + positions + torch.Tensor of *decoder* position indices. + encoder_input_ids + torch.Tensor of *encoder* input token ids. + encoder_positions + torch.Tensor of *encoder* position indices + Returns: + Output torch.Tensor + """ + + inputs_embeds = None + if encoder_input_ids.numel() > 0: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(encoder_input_ids, + vision_embeddings) + + hidden_states = self.decoder(input_ids, + positions, + inputs_embeds=inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.decoder.compute_logits(hidden_states, sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 465c25f094806..ebf78771e40a4 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -252,6 +252,7 @@ _MULTIMODAL_MODELS = { "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501 "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501 # [Encoder-decoder] + "DonutForConditionalGeneration": ("donut", "DonutForConditionalGeneration"), "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), 
# noqa: E501 diff --git a/vllm/model_executor/models/swin.py b/vllm/model_executor/models/swin.py new file mode 100644 index 0000000000000..30b441f5b4df0 --- /dev/null +++ b/vllm/model_executor/models/swin.py @@ -0,0 +1,475 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn +from transformers import SwinConfig +from transformers.models.swin.modeling_swin import SwinEmbeddings +from transformers.models.swin.modeling_swin import SwinLayer as HFSwinLayer +from transformers.models.swin.modeling_swin import SwinPatchMerging +from transformers.pytorch_utils import meshgrid + +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + + +class SwinSelfAttention(nn.Module): + + def __init__( + self, + config: SwinConfig, + dim: int, + num_heads: int, + window_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + if dim % num_heads != 0: + raise ValueError( + f"The hidden size ({dim}) is not a multiple of the number of " + f"attention heads ({num_heads})") + + self.num_attention_heads = num_heads + self.attention_head_size = int(dim / num_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.window_size = (window_size if isinstance(window_size, Iterable) + else (window_size, window_size)) + self.scale = self.attention_head_size**-0.5 + + self.relative_position_bias_table = nn.Parameter( + torch.zeros( + (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), + num_heads)) + + # get pair-wise relative position index for each token inside 
the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, + None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) + + self.relative_position_index = nn.Parameter(relative_position_index, + requires_grad=False) + + self.qkv = QKVParallelLinear( + hidden_size=dim, + head_size=self.attention_head_size, + total_num_heads=self.num_attention_heads, + bias=config.qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + ) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def _get_rel_pos_bias(self) -> torch.Tensor: + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)] + relative_position_bias = relative_position_bias.view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], -1) + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() + return relative_position_bias.unsqueeze(0) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor, ...]: + batch_size, dim, num_channels = hidden_states.shape + + qkv_output, _ = self.qkv(hidden_states) + query_layer, key_layer, value_layer = qkv_output.chunk(3, dim=-1) + + key_layer = self.transpose_for_scores(key_layer) + value_layer = 
self.transpose_for_scores(value_layer) + query_layer = self.transpose_for_scores(query_layer) + + attention_scores = self._get_rel_pos_bias() + if attention_mask is not None: + mask_shape = attention_mask.shape[0] + attention_mask_expanded = attention_mask.view( + 1, mask_shape, 1, dim, + dim).expand(batch_size // mask_shape, mask_shape, + self.num_attention_heads, dim, dim) + attention_scores = attention_scores + \ + attention_mask_expanded.unsqueeze( + 1).unsqueeze(0) + attention_scores = attention_scores.view(-1, + self.num_attention_heads, + dim, dim) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_scores, + dropout_p=0., + ) + attention_probs = None + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, + attention_probs) if output_attentions else (context_layer, ) + + return outputs + + +class SwinSelfOutput(nn.Module): + + def __init__( + self, + config: SwinConfig, + dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.dense = RowParallelLinear( + input_size=dim, + output_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) + + def forward(self, hidden_states: torch.Tensor, + input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + + return hidden_states + + +class SwinAttention(nn.Module): + + def __init__(self, + config: SwinConfig, + dim: int, + num_heads: int, + window_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.self = SwinSelfAttention(config, + dim, + num_heads, + window_size, + quant_config=quant_config, + prefix=f"{prefix}.self") + self.output = SwinSelfOutput(config, + dim, + 
quant_config=quant_config, + prefix=f"{prefix}.output") + self.pruned_heads = set() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + self_outputs = self.self(hidden_states, attention_mask, head_mask, + output_attentions) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output, ) + self_outputs[1:] + return outputs + + +class SwinIntermediate(nn.Module): + + def __init__(self, + config: SwinConfig, + dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.dense = ColumnParallelLinear(dim, + int(config.mlp_ratio * dim), + quant_config=quant_config, + prefix=f"{prefix}.dense") + self.intermediate_act_fn = get_act_fn(config.hidden_act) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class SwinOutput(nn.Module): + + def __init__(self, + config: SwinConfig, + dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.dense = RowParallelLinear(int(config.mlp_ratio * dim), + dim, + quant_config=quant_config, + prefix=f"{prefix}.dense") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + return hidden_states + + +class SwinLayer(HFSwinLayer): + + def __init__( + self, + config: SwinConfig, + dim: int, + input_resolution: int, + num_heads: int, + drop_path_rate: float = 0.0, + shift_size: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__( + config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + 
drop_path_rate=drop_path_rate, + shift_size=shift_size, + ) + + self.attention = SwinAttention(config, + dim, + num_heads, + window_size=self.window_size, + quant_config=quant_config, + prefix=f"{prefix}.attention") + self.intermediate = SwinIntermediate(config, + dim, + quant_config=quant_config, + prefix=f"{prefix}.intermediate") + self.output = SwinOutput(config, + dim, + quant_config=quant_config, + prefix=f"{prefix}.output") + + +class SwinStage(nn.Module): + + def __init__( + self, + config: SwinConfig, + dim: int, + input_resolution: int, + depth: int, + num_heads: int, + drop_path: list[float], + downsample: Optional[SwinPatchMerging] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.dim = dim + self.blocks = nn.ModuleList([ + SwinLayer(config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + drop_path_rate=drop_path[layer_idx], + shift_size=0 if + (layer_idx % 2 == 0) else config.window_size // 2, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, + dim=dim, + norm_layer=nn.LayerNorm) + else: + self.downsample = None + + self.pointing = False + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + height, width = input_dimensions + for i, layer_module in enumerate(self.blocks): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module(hidden_states, input_dimensions, + layer_head_mask, output_attentions, + always_partition) + + hidden_states = layer_outputs[0] + + hidden_states_before_downsampling = hidden_states + if 
self.downsample is not None: + height_downsampled, width_downsampled = (height + 1) // 2, (width + + 1) // 2 + output_dimensions = (height, width, height_downsampled, + width_downsampled) + hidden_states = self.downsample(hidden_states_before_downsampling, + input_dimensions) + else: + output_dimensions = (height, width, height, width) + + stage_outputs = (hidden_states, hidden_states_before_downsampling, + output_dimensions) + + if output_attentions: + stage_outputs += layer_outputs[1:] + return stage_outputs + + +class SwinEncoder(nn.Module): + + def __init__( + self, + config: SwinConfig, + grid_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.num_layers = len(config.depths) + self.config = config + dpr = [ + x.item() for x in torch.linspace( + 0, config.drop_path_rate, sum(config.depths), device="cpu") + ] + self.layers = nn.ModuleList([ + SwinStage(config=config, + dim=int(config.embed_dim * 2**layer_idx), + input_resolution=(grid_size[0] // (2**layer_idx), + grid_size[1] // (2**layer_idx)), + depth=config.depths[layer_idx], + num_heads=config.num_heads[layer_idx], + drop_path=dpr[sum(config.depths[:layer_idx] + ):sum(config.depths[:layer_idx + 1])], + downsample=SwinPatchMerging if + (layer_idx < self.num_layers - 1) else None, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}") + for layer_idx in range(self.num_layers) + ]) + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + for i, layer_module in enumerate(self.layers): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module(hidden_states, input_dimensions, + layer_head_mask, output_attentions, + always_partition) + + hidden_states = layer_outputs[0] + 
output_dimensions = layer_outputs[2] + + input_dimensions = (output_dimensions[-2], output_dimensions[-1]) + + return hidden_states + + +class SwinModel(nn.Module): + config_class: SwinConfig + + def __init__( + self, + config: SwinConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.num_layers = len(config.depths) + self.num_features = int(config.embed_dim * 2**(self.num_layers - 1)) + + self.embeddings = SwinEmbeddings(config) + self.encoder = SwinEncoder(config, + self.embeddings.patch_grid, + quant_config=quant_config, + prefix=f"{prefix}.encoder") + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + ) -> tuple[torch.Tensor]: + embedding_output, input_dimensions = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, + input_dimensions, + head_mask=head_mask, + output_attentions=output_attentions, + ) + + return encoder_outputs + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("qkv", "query", "q"), + ("qkv", "key", "k"), + ("qkv", "value", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index 2da9b4c72189a..ea2efbdd8b524 100644 --- 
a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -209,7 +209,7 @@ class MultiModalProfiler(Generic[_I]): if processor.pad_dummy_encoder_prompt: num_tokens_to_pad = max(total_len, seq_len) - total_len encoder_prompt_token_ids.extend([0] * num_tokens_to_pad) - # NOTE: Whisper allows total_len > seq_len. + # NOTE: Whisper and Donut allows total_len > seq_len. elif total_len > seq_len and not envs.VLLM_USE_V1: # `max_num_batched_tokens` is defined by `SchedulerConfig` logger.warning_once( diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 69f8e531e01b1..219857dc7b778 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -389,7 +389,7 @@ class Processor: assert isinstance(mm_processor, EncDecMultiModalProcessor) if mm_processor.pad_dummy_encoder_prompt: - return # Skip encoder length check for Whisper + return # Skip encoder length check for Whisper and Donut if model_config.is_multimodal_model: suggestion = ( From e2db1164a186a9d2592299aaa5aea3f013711db3 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 24 Aug 2025 21:30:47 +0800 Subject: [PATCH 09/13] [Model] Enable BLOOM on V1 (#23488) Signed-off-by: DarkLight1337 --- docs/models/supported_models.md | 2 +- vllm/model_executor/models/bloom.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 3159d3bd1c819..8fb1019f2bdfb 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -328,7 +328,7 @@ th { | `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. 
| ✅︎ | ✅︎ | ✅︎ | | `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ | -| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | | +| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | ✅︎ | | `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | | | `MBartForConditionalGeneration` | mBART | `facebook/mbart-large-en-ro`, `facebook/mbart-large-50`, etc. | | | | | `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `zai-org/chatglm2-6b`, `zai-org/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ | diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 6e4a399f3cc6e..126404584892f 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -43,7 +43,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsPP, SupportsQuant, SupportsV0Only +from .interfaces import SupportsPP, SupportsQuant from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -313,7 +313,7 @@ class BloomModel(nn.Module): return loaded_params -class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant): +class BloomForCausalLM(nn.Module, SupportsPP, SupportsQuant): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() From ad78868450a3596bed37dac05be9049019953e94 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 24 Aug 2025 14:03:36 -0700 Subject: [PATCH 10/13] [Misc] Remove unused slot_mapping buffer (#23502) Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_model_runner.py | 3 --- 1 file changed, 3 deletions(-) diff 
--git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ed4a4e55f1212..ec9887b8010a2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -254,9 +254,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.seq_lens = torch.zeros(self.max_num_reqs, dtype=torch.int32, device=self.device) - self.slot_mapping = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device=self.device) # None in the first PP rank. The rest are set after load_model. self.intermediate_tensors: Optional[IntermediateTensors] = None From c7fc6b1354a20f5dbdd2fb806cd4b7da27d46f63 Mon Sep 17 00:00:00 2001 From: Lucia Fang <116399278+luccafong@users.noreply.github.com> Date: Sun, 24 Aug 2025 15:35:41 -0700 Subject: [PATCH 11/13] fix incompatibililty with non cuda platform for nvfp4 (#23478) Signed-off-by: Lu Fang Co-authored-by: Lucia (Lu) Fang --- vllm/compilation/fusion.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 413948799de35..0d8d562514e31 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -47,8 +47,10 @@ QUANT_OPS: dict[QuantKey, OpOverload] = { torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 - kNvfp4Quant: torch.ops._C.scaled_fp4_quant.default, # noqa: E501 } +if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): + QUANT_OPS[ + kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 class FusedRMSQuantKey(NamedTuple): From 47455c424f62a20b75a7cfd872e17c5ba11c9f3a Mon Sep 17 00:00:00 2001 From: Didier Durand <2927957+didier-durand@users.noreply.github.com> Date: Mon, 25 Aug 2025 02:04:04 +0200 Subject: [PATCH 12/13] [Doc: ]fix various typos in multiple files (#23487) Signed-off-by: Didier Durand Signed-off-by: Harry Mellor 
<19981378+hmellor@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- .buildkite/nightly-benchmarks/nightly-descriptions.md | 2 +- docs/deployment/frameworks/anything-llm.md | 2 +- docs/design/fused_moe_modular_kernel.md | 2 +- docs/design/metrics.md | 2 +- docs/design/paged_attention.md | 2 +- docs/features/quantization/inc.md | 2 +- docs/getting_started/installation/cpu.md | 6 +++--- docs/getting_started/installation/intel_gaudi.md | 4 ++-- vllm/config/cache.py | 2 +- 9 files changed, 12 insertions(+), 12 deletions(-) diff --git a/.buildkite/nightly-benchmarks/nightly-descriptions.md b/.buildkite/nightly-benchmarks/nightly-descriptions.md index 8afde017d383e..37e2980eea974 100644 --- a/.buildkite/nightly-benchmarks/nightly-descriptions.md +++ b/.buildkite/nightly-benchmarks/nightly-descriptions.md @@ -17,7 +17,7 @@ Latest reproduction guilde: [github issue link](https://github.com/vllm-project/ - SGLang: `lmsysorg/sglang:v0.3.2-cu121` - LMDeploy: `openmmlab/lmdeploy:v0.6.1-cu12` - TensorRT-LLM: `nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3` - - *NOTE: we uses r24.07 as the current implementation only works for this version. We are going to bump this up.* + - *NOTE: we use r24.07 as the current implementation only works for this version. We are going to bump this up.* - Check [nightly-pipeline.yaml](nightly-pipeline.yaml) for the concrete docker images, specs and commands we use for the benchmark. - Hardware - 8x Nvidia A100 GPUs diff --git a/docs/deployment/frameworks/anything-llm.md b/docs/deployment/frameworks/anything-llm.md index e62a33b2085ca..0b41e73b030cc 100644 --- a/docs/deployment/frameworks/anything-llm.md +++ b/docs/deployment/frameworks/anything-llm.md @@ -18,7 +18,7 @@ vllm serve Qwen/Qwen1.5-32B-Chat-AWQ --max-model-len 4096 - Download and install [Anything LLM desktop](https://anythingllm.com/desktop). 
-- On the bottom left of open settings, AI Prooviders --> LLM: +- On the bottom left of open settings, AI Providers --> LLM: - LLM Provider: Generic OpenAI - Base URL: http://{vllm server host}:{vllm server port}/v1 - Chat Model Name: `Qwen/Qwen1.5-32B-Chat-AWQ` diff --git a/docs/design/fused_moe_modular_kernel.md b/docs/design/fused_moe_modular_kernel.md index 4b917ab408eec..3c4c7d2102170 100644 --- a/docs/design/fused_moe_modular_kernel.md +++ b/docs/design/fused_moe_modular_kernel.md @@ -226,7 +226,7 @@ Doing this will add the new implementation to the test suite. The unit test file [test_modular_kernel_combinations.py](gh-file:tests/kernels/moe/test_modular_kernel_combinations.py) can also be executed as a standalone script. Example: `python3 -m tests.kernels.moe.test_modular_kernel_combinations --pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts` -As a side-effect, this script can be used to test `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` compatibility. When invoked +As a side effect, this script can be used to test `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` compatibility. When invoked with incompatible types, the script will error. ### How To Profile diff --git a/docs/design/metrics.md b/docs/design/metrics.md index b01838883f31e..b24364247b3f8 100644 --- a/docs/design/metrics.md +++ b/docs/design/metrics.md @@ -565,7 +565,7 @@ model and then validate those tokens with the larger model. - `vllm:spec_decode_num_emitted_tokens_total` (Counter) There is a PR under review () to add "prompt lookup (ngram)" -seculative decoding to v1. Other techniques will follow. We should +speculative decoding to v1. Other techniques will follow. We should revisit the v0 metrics in this context. !!! 
note diff --git a/docs/design/paged_attention.md b/docs/design/paged_attention.md index fb991a35caf30..d87b2a639df12 100644 --- a/docs/design/paged_attention.md +++ b/docs/design/paged_attention.md @@ -422,7 +422,7 @@ a total of 128 * 16 / 256 = 8 inner iterations for a warp to handle a whole block of value tokens. And each `accs` in each thread contains 8 elements that accumulated at 8 different head positions. For the thread 0, the `accs` variable will have 8 elements, which -are 0th, 32th … 224th elements of a value head that are accumulated +are 0th, 32nd … 224th elements of a value head that are accumulated from all assigned 8 tokens. ## LV diff --git a/docs/features/quantization/inc.md b/docs/features/quantization/inc.md index 13b151bc7f380..5e86e9388f328 100644 --- a/docs/features/quantization/inc.md +++ b/docs/features/quantization/inc.md @@ -7,7 +7,7 @@ Intel Gaudi supports quantization of various modules and functions, including, b [Supported Modules\\Supported Functions\\Custom Patched Modules](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-modules). !!! note - Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vllm-hpu-extention](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package. + Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vLLM HPU extension](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package. !!! note `QUANT_CONFIG` is an environment variable that points to the measurement or quantization [JSON config file](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-json-config-file-options). 
diff --git a/docs/getting_started/installation/cpu.md b/docs/getting_started/installation/cpu.md index 7a34d47d8e494..e76ec35e1edcb 100644 --- a/docs/getting_started/installation/cpu.md +++ b/docs/getting_started/installation/cpu.md @@ -170,7 +170,7 @@ This value is 4GB by default. Larger space can support more concurrent requests, First of all, please make sure the thread-binding and KV cache space are properly set and take effect. You can check the thread-binding by running a vLLM benchmark and observing CPU cores usage via `htop`. -Inference batch size is a important parameter for the performance. Larger batch usually provides higher throughput, smaller batch provides lower latency. Tuning max batch size starts from default value to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM: +Inference batch size is an important parameter for the performance. Larger batch usually provides higher throughput, smaller batch provides lower latency. Tuning max batch size starts from default value to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM: - `--max-num-batched-tokens`, defines the limit of token numbers in a single batch, has more impacts on the first token performance. The default value is set as: - Offline Inference: `4096 * world_size` @@ -179,7 +179,7 @@ Inference batch size is a important parameter for the performance. Larger batch - Offline Inference: `256 * world_size` - Online Serving: `128 * world_size` -vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more detials of tuning TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommend to use TP and PP togther if there are enough CPU sockets and memory nodes. 
+vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more details of tuning TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommend to use TP and PP together if there are enough CPU sockets and memory nodes. ### Which quantization configs does vLLM CPU support? @@ -190,6 +190,6 @@ vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage mu ### (x86 only) What is the purpose of `VLLM_CPU_MOE_PREPACK` and `VLLM_CPU_SGL_KERNEL`? -- Both of them requires `amx` CPU flag. +- Both of them require `amx` CPU flag. - `VLLM_CPU_MOE_PREPACK` can provides better performance for MoE models - `VLLM_CPU_SGL_KERNEL` can provides better performance for MoE models and small-batch scenarios. diff --git a/docs/getting_started/installation/intel_gaudi.md b/docs/getting_started/installation/intel_gaudi.md index 61b2b02aa10ba..ff912efec9ca8 100644 --- a/docs/getting_started/installation/intel_gaudi.md +++ b/docs/getting_started/installation/intel_gaudi.md @@ -261,13 +261,13 @@ Lower value corresponds to less usable graph memory reserved for prefill stage, User can also configure the strategy for capturing HPU Graphs for prompt and decode stages separately. Strategy affects the order of capturing graphs. There are two strategies implemented: -- `max_bs` - graph capture queue will sorted in descending order by their batch sizes. Buckets with equal batch sizes are sorted by sequence length in ascending order (e.g. `(64, 128)`, `(64, 256)`, `(32, 128)`, `(32, 256)`, `(1, 128)`, `(1,256)`), default strategy for decode +- `max_bs` - graph capture queue will be sorted in descending order by their batch sizes. Buckets with equal batch sizes are sorted by sequence length in ascending order (e.g. 
`(64, 128)`, `(64, 256)`, `(32, 128)`, `(32, 256)`, `(1, 128)`, `(1,256)`), default strategy for decode - `min_tokens` - graph capture queue will be sorted in ascending order by the number of tokens each graph processes (`batch_size*sequence_length`), default strategy for prompt When there's large amount of requests pending, vLLM scheduler will attempt to fill the maximum batch size for decode as soon as possible. When a request is finished, decode batch size decreases. When that happens, vLLM will attempt to schedule a prefill iteration for requests in the waiting queue, to fill the decode batch size to its previous state. This means that in a full load scenario, decode batch size is often at its maximum, which makes large batch size HPU Graphs crucial to capture, as reflected by `max_bs` strategy. On the other hand, prefills will be executed most frequently with very low batch sizes (1-4), which is reflected in `min_tokens` strategy. !!! note - `VLLM_GRAPH_PROMPT_RATIO` does not set a hard limit on memory taken by graphs for each stage (prefill and decode). vLLM will first attempt to use up entirety of usable prefill graph memory (usable graph memory * `VLLM_GRAPH_PROMPT_RATIO`) for capturing prefill HPU Graphs, next it will attempt do the same for decode graphs and usable decode graph memory pool. If one stage is fully captured, and there is unused memory left within usable graph memory pool, vLLM will attempt further graph capture for the other stage, until no more HPU Graphs can be captured without exceeding reserved memory pool. The behavior on that mechanism can be observed in the example below. + `VLLM_GRAPH_PROMPT_RATIO` does not set a hard limit on memory taken by graphs for each stage (prefill and decode). vLLM will first attempt to use up entirety of usable prefill graph memory (usable graph memory * `VLLM_GRAPH_PROMPT_RATIO`) for capturing prefill HPU Graphs, next it will attempt to do the same for decode graphs and usable decode graph memory pool. 
If one stage is fully captured, and there is unused memory left within usable graph memory pool, vLLM will attempt further graph capture for the other stage, until no more HPU Graphs can be captured without exceeding reserved memory pool. The behavior on that mechanism can be observed in the example below. Each described step is logged by vLLM server, as follows (negative values correspond to memory being released): diff --git a/vllm/config/cache.py b/vllm/config/cache.py index ae11dec3ca5e2..a9550d4390ad6 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -116,7 +116,7 @@ class CacheConfig: In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254), some layers can skip tokens corresponding to prefill. This flag enables attention metadata for eligible layers to be overriden with metadata - necessary for implementating this optimization in some models (e.g. Gemma3n) + necessary for implementing this optimization in some models (e.g. Gemma3n) """ def compute_hash(self) -> str: From 504d91431499e302bbd5a3d8a1432cd427ec8d5d Mon Sep 17 00:00:00 2001 From: Ming Yang Date: Sun, 24 Aug 2025 18:06:35 -0700 Subject: [PATCH 13/13] [Perf] Add Triton config for DeepSeek V3 FP8 EP32 H200 (#23504) Signed-off-by: Ming Yang --- .../kernels/benchmark_w8a8_block_fp8.py | 2 +- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 154 ++++++++++++++++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 26 +++ .../quantization/utils/configs/README.md | 3 + 4 files changed, 184 insertions(+), 1 deletion(-) create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/README.md diff --git a/benchmarks/kernels/benchmark_w8a8_block_fp8.py 
b/benchmarks/kernels/benchmark_w8a8_block_fp8.py index 4fcdbadd65ecd..e648a91077fdb 100644 --- a/benchmarks/kernels/benchmark_w8a8_block_fp8.py +++ b/benchmarks/kernels/benchmark_w8a8_block_fp8.py @@ -11,8 +11,8 @@ from datetime import datetime from typing import Any import torch -import tqdm import triton +from tqdm import tqdm from vllm.model_executor.layers.quantization.utils.fp8_utils import ( _w8a8_block_fp8_matmul, diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..d677d69c57a25 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,154 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + 
"num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8192": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..fbca5ce05d018 --- /dev/null +++ 
b/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/README.md b/vllm/model_executor/layers/quantization/utils/configs/README.md new file mode 100644 index 0000000000000..1110ced4fa063 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/README.md @@ -0,0 +1,3 @@ +# Quantization Kernel Config + +Use scripts under `benchmarks/kernels/` to generate these config files.