diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
index 93d238a0524d8..1fb5223b07d76 100644
--- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
+++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
@@ -142,10 +142,20 @@ def prepare_fp8_layer_for_marlin(
     # marlin kernel only support channel-wise and group-wise quantization
     # we need to convert the scales
     if weight_block_size is None:
+        logical_widths = getattr(layer, "logical_widths", [])
         if scales.nelement() == 1:
             # tensor-wise quantization -> channel-wise quantization
             # (1, 1) =>(repeat)=> (1, size_n)
             scales = scales.view(1, 1).repeat_interleave(part_size_n, 1)
+        elif scales.nelement() == len(logical_widths):
+            # tensor-wise quantization with logical_widths ->
+            # channel-wise quantization
+            assert sum(logical_widths) == part_size_n, (
+                f"Sum of logical_widths ({sum(logical_widths)}) must be equal "
+                f"to part_size_n ({part_size_n})"
+            )
+            lw_tensor = scales.new_tensor(logical_widths, dtype=torch.int64)
+            scales = scales.view(1, -1).repeat_interleave(lw_tensor, dim=1)
         elif scales.nelement() > 1 and scales.nelement() != part_size_n:
             assert part_size_n % scales.nelement() == 0
             s_size = scales.nelement()
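
The new branch handles fused layers that carry one tensor-wise scale per logical partition: `repeat_interleave` with a per-shard repeats tensor expands each scale across the columns of its partition, yielding channel-wise scales of width `part_size_n`. Below is a minimal standalone sketch of that expansion (not part of the patch; the shard widths and scale values are made up for illustration):

```python
import torch

# Hypothetical example values: two fused logical partitions of widths 4 and 8,
# each with its own tensor-wise scale.
part_size_n = 12
logical_widths = [4, 8]
scales = torch.tensor([0.5, 2.0])

assert sum(logical_widths) == part_size_n

# Same pattern as in the patch: build an int64 repeats tensor and expand
# each per-partition scale across that partition's output channels.
lw_tensor = scales.new_tensor(logical_widths, dtype=torch.int64)
channel_scales = scales.view(1, -1).repeat_interleave(lw_tensor, dim=1)

print(channel_scales.shape)  # torch.Size([1, 12])
print(channel_scales)        # [[0.5, 0.5, 0.5, 0.5, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]]
```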