From ee52d9901d368da35a80048c299ccb53e78a9a01 Mon Sep 17 00:00:00 2001
From: Jinzhen Lin
Date: Sun, 21 Dec 2025 04:02:57 +0800
Subject: [PATCH] [Quantization] support logical_widths for fp8 marlin
 (#30962)

Signed-off-by: Jinzhen Lin
Signed-off-by: Jinzhen Lin
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 .../layers/quantization/utils/marlin_utils_fp8.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
index 93d238a0524d8..1fb5223b07d76 100644
--- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
+++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
@@ -142,10 +142,20 @@ def prepare_fp8_layer_for_marlin(
     # marlin kernel only support channel-wise and group-wise quantization
     # we need to convert the scales
     if weight_block_size is None:
+        logical_widths = getattr(layer, "logical_widths", [])
         if scales.nelement() == 1:
             # tensor-wise quantization -> channel-wise quantization
             # (1, 1) =>(repeat)=> (1, size_n)
             scales = scales.view(1, 1).repeat_interleave(part_size_n, 1)
+        elif scales.nelement() == len(logical_widths):
+            # tensor-wise quantization with logical_widths ->
+            # channel-wise quantization
+            assert sum(logical_widths) == part_size_n, (
+                f"Sum of logical_widths ({sum(logical_widths)}) must be equal "
+                f"to part_size_n ({part_size_n})"
+            )
+            lw_tensor = scales.new_tensor(logical_widths, dtype=torch.int64)
+            scales = scales.view(1, -1).repeat_interleave(lw_tensor, dim=1)
         elif scales.nelement() > 1 and scales.nelement() != part_size_n:
             assert part_size_n % scales.nelement() == 0
             s_size = scales.nelement()
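
Note (not part of the patch): the new branch expands one tensor-wise scale per fused partition into channel-wise scales by repeating each scale across that partition's output channels. A minimal standalone sketch of that expansion, using small made-up logical_widths and scale values purely for illustration:

    import torch

    # Hypothetical per-partition scales for a fused projection; values and
    # widths are illustrative, not taken from the patch.
    logical_widths = [4, 2, 2]                 # output channels per partition
    scales = torch.tensor([0.5, 0.25, 0.125])  # one tensor-wise scale per partition

    # Same expansion the patch performs: repeat each partition's scale across
    # its channels, producing a channel-wise (1, part_size_n) scale tensor.
    lw_tensor = scales.new_tensor(logical_widths, dtype=torch.int64)
    channel_scales = scales.view(1, -1).repeat_interleave(lw_tensor, dim=1)

    print(channel_scales)
    # tensor([[0.5000, 0.5000, 0.5000, 0.5000, 0.2500, 0.2500, 0.1250, 0.1250]])
    assert channel_scales.shape[1] == sum(logical_widths)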