[Quantization] support logical_widths for fp8 marlin (#30962)

Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Jinzhen Lin authored 2025-12-21 04:02:57 +08:00, committed by GitHub
commit ee52d9901d (parent 54c8924384)

@@ -142,10 +142,20 @@ def prepare_fp8_layer_for_marlin(
     # marlin kernel only support channel-wise and group-wise quantization
     # we need to convert the scales
     if weight_block_size is None:
+        logical_widths = getattr(layer, "logical_widths", [])
         if scales.nelement() == 1:
             # tensor-wise quantization -> channel-wise quantization
             # (1, 1) =>(repeat)=> (1, size_n)
             scales = scales.view(1, 1).repeat_interleave(part_size_n, 1)
+        elif scales.nelement() == len(logical_widths):
+            # tensor-wise quantization with logical_widths ->
+            # channel-wise quantization
+            assert sum(logical_widths) == part_size_n, (
+                f"Sum of logical_widths ({sum(logical_widths)}) must be equal "
+                f"to part_size_n ({part_size_n})"
+            )
+            lw_tensor = scales.new_tensor(logical_widths, dtype=torch.int64)
+            scales = scales.view(1, -1).repeat_interleave(lw_tensor, dim=1)
         elif scales.nelement() > 1 and scales.nelement() != part_size_n:
             assert part_size_n % scales.nelement() == 0
             s_size = scales.nelement()
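
For illustration, here is a minimal standalone sketch of what the new `logical_widths` branch computes. The names `scales`, `logical_widths`, and `part_size_n` match the diff above; the concrete values are hypothetical and stand in for a fused layer (e.g. merged projections) that carries one fp8 scale per sub-matrix:

import torch

# Hypothetical fused layer: three sub-matrices of widths 4, 2, and 2,
# each quantized tensor-wise with its own scale.
logical_widths = [4, 2, 2]
part_size_n = sum(logical_widths)  # 8 output channels in total

scales = torch.tensor([0.5, 1.0, 2.0])  # one scale per sub-matrix
assert scales.nelement() == len(logical_widths)

# Expand to channel-wise scales: repeat scale i logical_widths[i] times.
lw_tensor = scales.new_tensor(logical_widths, dtype=torch.int64)
channel_scales = scales.view(1, -1).repeat_interleave(lw_tensor, dim=1)

print(channel_scales)
# tensor([[0.5000, 0.5000, 0.5000, 0.5000, 1.0000, 1.0000, 2.0000, 2.0000]])

The key design point is passing a tensor of repeat counts to `repeat_interleave`, which allows each sub-matrix's scale to be repeated a different number of times; the scalar-repeat form used in the existing `nelement() == 1` branch can only expand a single tensor-wise scale uniformly.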