Fix FBGEMM integration (#18002)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin 2025-05-12 19:02:07 -04:00 committed by GitHub
parent dc9905368d
commit f065de4e88
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 13 deletions

View File

@ -63,7 +63,9 @@ class FBGEMMFp8Config(QuantizationConfig):
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignore_list):
if is_layer_skipped(prefix=prefix,
ignored_layers=self.ignore_list,
fused_mapping=self.packed_modules_mapping):
return UnquantizedLinearMethod()
return FBGEMMFp8LinearMethod(self)
return None

View File

@ -86,6 +86,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
part_size_n = layer.output_size_per_partition
part_size_k = layer.input_size_per_partition
weight_block_size = getattr(layer, "weight_block_size", None)
if size_k_first:
assert layer.weight.shape == (part_size_k, part_size_n)
@ -119,14 +120,11 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
scales = layer.weight_scale_inv.to(layer.orig_dtype)
del layer.weight_scale_inv
if layer.weight_block_size is None:
group_size = -1
else:
group_size = layer.weight_block_size[1]
group_size = -1 if weight_block_size is None else weight_block_size[1]
# marlin kernel only support channel-wise and group-wise quantization
# we need to convert the scales
if layer.weight_block_size is None:
if weight_block_size is None:
if scales.nelement() == 1:
# tensor-wise quantization -> channel-wise quantization
# (1, 1) =>(repeat)=> (1, size_n)
@ -149,7 +147,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
# =>(repeat)=> (size_k // block_size[1], size_n)
if not size_k_first:
scales = scales.T.contiguous()
block_n = layer.weight_block_size[0]
block_n = weight_block_size[0]
scales = scales.repeat_interleave(block_n, 1)
# size_n may not divisible by block_size[0]
scales = scales[:, :part_size_n]
@ -173,6 +171,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
e = layer.num_experts
k = layer.hidden_size
n = layer.intermediate_size_per_partition
weight_block_size = getattr(layer, "weight_block_size", None)
# WORKSPACE
device = layer.w13_weight.device
@ -213,10 +212,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
# WEIGHT SCALES
# Permute scales
if layer.weight_block_size is None:
group_size = -1
else:
group_size = layer.weight_block_size[1]
group_size = -1 if weight_block_size is None else weight_block_size[1]
for name in ["w13", "w2"]:
if name + "_weight_scale" in dir(layer):
@ -236,7 +232,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
# marlin kernel only support channel-wise and group-wise quantization
# we need to convert the scales
if layer.weight_block_size is None:
if weight_block_size is None:
if scales.nelement() == e:
# tensor-wise quantization -> channel-wise quantization
# (e, 1, 1) =>(repeat)=> (e, 1, size_n)
@ -259,7 +255,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
# =>(repeat)=> (e, size_k // block_size[1], size_n)
if not size_k_first:
scales = scales.permute(0, 2, 1)
block_n = layer.weight_block_size[0]
block_n = weight_block_size[0]
scales = scales.repeat_interleave(block_n, 2)
# size_n may not divisible by block_size[0]
scales = scales[..., :size_n].contiguous()