mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-22 13:36:14 +08:00)
[Quantization] support logical_widths for fp8 marlin (#30962)
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 54c8924384
commit ee52d9901d
@@ -142,10 +142,20 @@ def prepare_fp8_layer_for_marlin(
     # marlin kernel only support channel-wise and group-wise quantization
     # we need to convert the scales
     if weight_block_size is None:
+        logical_widths = getattr(layer, "logical_widths", [])
         if scales.nelement() == 1:
             # tensor-wise quantization -> channel-wise quantization
             # (1, 1) =>(repeat)=> (1, size_n)
             scales = scales.view(1, 1).repeat_interleave(part_size_n, 1)
+        elif scales.nelement() == len(logical_widths):
+            # tensor-wise quantization with logical_widths ->
+            # channel-wise quantization
+            assert sum(logical_widths) == part_size_n, (
+                f"Sum of logical_widths ({sum(logical_widths)}) must be equal "
+                f"to part_size_n ({part_size_n})"
+            )
+            lw_tensor = scales.new_tensor(logical_widths, dtype=torch.int64)
+            scales = scales.view(1, -1).repeat_interleave(lw_tensor, dim=1)
         elif scales.nelement() > 1 and scales.nelement() != part_size_n:
             assert part_size_n % scales.nelement() == 0
             s_size = scales.nelement()
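For reference, here is a minimal standalone sketch (not part of the commit) of what the new logical_widths branch does: one scalar scale per logical shard is expanded into a channel-wise scale of width part_size_n via repeat_interleave. The values of logical_widths and scales below are made up purely for illustration.

import torch

# Hypothetical shard layout: a fused projection made of two logical shards
# (values chosen only for illustration).
logical_widths = [4, 6]
part_size_n = sum(logical_widths)  # total number of output channels

# One scalar scale per logical shard (per-shard tensor-wise quantization).
scales = torch.tensor([0.5, 0.25])

# Expand each per-shard scale across that shard's output channels so the
# result is a channel-wise scale of shape (1, part_size_n), which is what
# the Marlin kernel expects.
lw_tensor = scales.new_tensor(logical_widths, dtype=torch.int64)
channel_scales = scales.view(1, -1).repeat_interleave(lw_tensor, dim=1)

assert channel_scales.shape == (1, part_size_n)
# channel_scales -> [[0.5, 0.5, 0.5, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25]]

This mirrors the repeat_interleave call in the added branch; in the actual layer the scales come from the layer's weight scale and logical_widths typically holds the per-shard output sizes of a fused linear layer (for example, merged gate/up projections).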