Fix FBGEMM integration (#18002)
Signed-off-by: mgoin <mgoin64@gmail.com>
parent dc9905368d
commit f065de4e88
@@ -63,7 +63,9 @@ class FBGEMMFp8Config(QuantizationConfig):
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, LinearBase):
-            if is_layer_skipped(prefix, self.ignore_list):
+            if is_layer_skipped(prefix=prefix,
+                                ignored_layers=self.ignore_list,
+                                fused_mapping=self.packed_modules_mapping):
                 return UnquantizedLinearMethod()
             return FBGEMMFp8LinearMethod(self)
         return None
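Why the new keyword arguments: quantized checkpoints list unfused submodule names (q_proj, k_proj, v_proj) in their ignore list, while vLLM builds fused modules (qkv_proj), so the skip check has to translate prefixes through the packed-modules mapping before matching. Below is a minimal sketch of that idea; the helper body and the example mapping are illustrative, not vLLM's exact is_layer_skipped implementation.

from types import MappingProxyType
from typing import Mapping, Sequence

def is_layer_skipped_sketch(
        prefix: str,
        ignored_layers: Sequence[str],
        fused_mapping: Mapping[str, Sequence[str]] = MappingProxyType({}),
) -> bool:
    # e.g. prefix = "model.layers.0.self_attn.qkv_proj"
    proj_name = prefix.split(".")[-1]
    if proj_name in fused_mapping:
        # A fused module is skipped only if every shard it packs is ignored.
        shard_prefixes = [
            prefix.replace(proj_name, shard)
            for shard in fused_mapping[proj_name]
        ]
        return all(p in ignored_layers for p in shard_prefixes)
    return prefix in ignored_layers

With fused_mapping={"qkv_proj": ["q_proj", "k_proj", "v_proj"]}, a qkv_proj prefix is skipped only when all three unfused names appear in ignored_layers; the old positional two-argument call had no mapping to consult and could never match them.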
@@ -86,6 +86,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,

     part_size_n = layer.output_size_per_partition
     part_size_k = layer.input_size_per_partition
+    weight_block_size = getattr(layer, "weight_block_size", None)

     if size_k_first:
         assert layer.weight.shape == (part_size_k, part_size_n)
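The getattr with a None default matters because layers quantized through paths like FBGEMM do not set a weight_block_size attribute, so the plain layer.weight_block_size accesses later in this function would raise AttributeError; caching the value once also lets the hunks below drop the repeated lookups. (That failure mode is inferred from this diff, which replaces every such access with the cached local.) A tiny stand-alone illustration, using an ordinary nn.Linear as a stand-in for such a layer:

import torch

layer = torch.nn.Linear(8, 8)     # stand-in: has no weight_block_size attribute
weight_block_size = getattr(layer, "weight_block_size", None)
assert weight_block_size is None  # falls through to the channel-wise path
# layer.weight_block_size         # would raise AttributeError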
@@ -119,14 +120,11 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
         scales = layer.weight_scale_inv.to(layer.orig_dtype)
         del layer.weight_scale_inv

-    if layer.weight_block_size is None:
-        group_size = -1
-    else:
-        group_size = layer.weight_block_size[1]
+    group_size = -1 if weight_block_size is None else weight_block_size[1]

     # marlin kernel only support channel-wise and group-wise quantization
     # we need to convert the scales
-    if layer.weight_block_size is None:
+    if weight_block_size is None:
         if scales.nelement() == 1:
             # tensor-wise quantization -> channel-wise quantization
             # (1, 1) =>(repeat)=> (1, size_n)
@@ -149,7 +147,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
         # =>(repeat)=> (size_k // block_size[1], size_n)
         if not size_k_first:
             scales = scales.T.contiguous()
-        block_n = layer.weight_block_size[0]
+        block_n = weight_block_size[0]
         scales = scales.repeat_interleave(block_n, 1)
         # size_n may not divisible by block_size[0]
         scales = scales[:, :part_size_n]
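A stand-alone sketch of the block-wise to group-wise conversion above, with illustrative shapes: treating weight_block_size as [block_n, block_k], Marlin's group_size is block_k along k, and the per-tile scales are repeated along n so each output column carries its own scale, then truncated because size_n need not be a multiple of block_n.

import torch

block_n, block_k = 128, 128          # weight_block_size = [block_n, block_k]
part_size_k, part_size_n = 512, 300  # per-partition weight shape, k-first

# One scale per (block_k x block_n) tile: (4, 3) here, n rounded up.
scales = torch.rand(part_size_k // block_k, -(-part_size_n // block_n))

group_size = block_k                           # group-wise along k
scales = scales.repeat_interleave(block_n, 1)  # (4, 3) -> (4, 384)
scales = scales[:, :part_size_n]               # (4, 384) -> (4, 300)

assert scales.shape == (part_size_k // block_k, part_size_n)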
@@ -173,6 +171,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
     e = layer.num_experts
     k = layer.hidden_size
     n = layer.intermediate_size_per_partition
+    weight_block_size = getattr(layer, "weight_block_size", None)

     # WORKSPACE
     device = layer.w13_weight.device
@@ -213,10 +212,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,

     # WEIGHT SCALES
     # Permute scales
-    if layer.weight_block_size is None:
-        group_size = -1
-    else:
-        group_size = layer.weight_block_size[1]
+    group_size = -1 if weight_block_size is None else weight_block_size[1]

     for name in ["w13", "w2"]:
         if name + "_weight_scale" in dir(layer):
@@ -236,7 +232,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,

         # marlin kernel only support channel-wise and group-wise quantization
         # we need to convert the scales
-        if layer.weight_block_size is None:
+        if weight_block_size is None:
             if scales.nelement() == e:
                 # tensor-wise quantization -> channel-wise quantization
                 # (e, 1, 1) =>(repeat)=> (e, 1, size_n)
@@ -259,7 +255,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
             # =>(repeat)=> (e, size_k // block_size[1], size_n)
             if not size_k_first:
                 scales = scales.permute(0, 2, 1)
-            block_n = layer.weight_block_size[0]
+            block_n = weight_block_size[0]
             scales = scales.repeat_interleave(block_n, 2)
             # size_n may not divisible by block_size[0]
             scales = scales[..., :size_n].contiguous()
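The MoE path is the same conversion with a leading expert dimension, so the transpose becomes permute(0, 2, 1) and the repeat runs along dim 2; a sketch with illustrative shapes:

import torch

e, block_n, block_k = 4, 128, 128
size_k, size_n = 512, 300

scales = torch.rand(e, size_k // block_k, -(-size_n // block_n))
scales = scales.repeat_interleave(block_n, 2)  # (4, 4, 3) -> (4, 4, 384)
scales = scales[..., :size_n].contiguous()     # -> (4, 4, 300)

assert scales.shape == (e, size_k // block_k, size_n)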