mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-03 01:24:29 +08:00
Fix FBGEMM integration (#18002)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
dc9905368d
commit
f065de4e88
@ -63,7 +63,9 @@ class FBGEMMFp8Config(QuantizationConfig):
|
|||||||
def get_quant_method(self, layer: torch.nn.Module,
|
def get_quant_method(self, layer: torch.nn.Module,
|
||||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
if is_layer_skipped(prefix, self.ignore_list):
|
if is_layer_skipped(prefix=prefix,
|
||||||
|
ignored_layers=self.ignore_list,
|
||||||
|
fused_mapping=self.packed_modules_mapping):
|
||||||
return UnquantizedLinearMethod()
|
return UnquantizedLinearMethod()
|
||||||
return FBGEMMFp8LinearMethod(self)
|
return FBGEMMFp8LinearMethod(self)
|
||||||
return None
|
return None
|
||||||
|
|||||||
@ -86,6 +86,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
|
|||||||
|
|
||||||
part_size_n = layer.output_size_per_partition
|
part_size_n = layer.output_size_per_partition
|
||||||
part_size_k = layer.input_size_per_partition
|
part_size_k = layer.input_size_per_partition
|
||||||
|
weight_block_size = getattr(layer, "weight_block_size", None)
|
||||||
|
|
||||||
if size_k_first:
|
if size_k_first:
|
||||||
assert layer.weight.shape == (part_size_k, part_size_n)
|
assert layer.weight.shape == (part_size_k, part_size_n)
|
||||||
@ -119,14 +120,11 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
|
|||||||
scales = layer.weight_scale_inv.to(layer.orig_dtype)
|
scales = layer.weight_scale_inv.to(layer.orig_dtype)
|
||||||
del layer.weight_scale_inv
|
del layer.weight_scale_inv
|
||||||
|
|
||||||
if layer.weight_block_size is None:
|
group_size = -1 if weight_block_size is None else weight_block_size[1]
|
||||||
group_size = -1
|
|
||||||
else:
|
|
||||||
group_size = layer.weight_block_size[1]
|
|
||||||
|
|
||||||
# marlin kernel only supports channel-wise and group-wise quantization
|
# marlin kernel only supports channel-wise and group-wise quantization
|
||||||
# we need to convert the scales
|
# we need to convert the scales
|
||||||
if layer.weight_block_size is None:
|
if weight_block_size is None:
|
||||||
if scales.nelement() == 1:
|
if scales.nelement() == 1:
|
||||||
# tensor-wise quantization -> channel-wise quantization
|
# tensor-wise quantization -> channel-wise quantization
|
||||||
# (1, 1) =>(repeat)=> (1, size_n)
|
# (1, 1) =>(repeat)=> (1, size_n)
|
||||||
@ -149,7 +147,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
|
|||||||
# =>(repeat)=> (size_k // block_size[1], size_n)
|
# =>(repeat)=> (size_k // block_size[1], size_n)
|
||||||
if not size_k_first:
|
if not size_k_first:
|
||||||
scales = scales.T.contiguous()
|
scales = scales.T.contiguous()
|
||||||
block_n = layer.weight_block_size[0]
|
block_n = weight_block_size[0]
|
||||||
scales = scales.repeat_interleave(block_n, 1)
|
scales = scales.repeat_interleave(block_n, 1)
|
||||||
# size_n may not be divisible by block_size[0]
|
# size_n may not be divisible by block_size[0]
|
||||||
scales = scales[:, :part_size_n]
|
scales = scales[:, :part_size_n]
|
||||||
@ -173,6 +171,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
|
|||||||
e = layer.num_experts
|
e = layer.num_experts
|
||||||
k = layer.hidden_size
|
k = layer.hidden_size
|
||||||
n = layer.intermediate_size_per_partition
|
n = layer.intermediate_size_per_partition
|
||||||
|
weight_block_size = getattr(layer, "weight_block_size", None)
|
||||||
|
|
||||||
# WORKSPACE
|
# WORKSPACE
|
||||||
device = layer.w13_weight.device
|
device = layer.w13_weight.device
|
||||||
@ -213,10 +212,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
|
|||||||
|
|
||||||
# WEIGHT SCALES
|
# WEIGHT SCALES
|
||||||
# Permute scales
|
# Permute scales
|
||||||
if layer.weight_block_size is None:
|
group_size = -1 if weight_block_size is None else weight_block_size[1]
|
||||||
group_size = -1
|
|
||||||
else:
|
|
||||||
group_size = layer.weight_block_size[1]
|
|
||||||
|
|
||||||
for name in ["w13", "w2"]:
|
for name in ["w13", "w2"]:
|
||||||
if name + "_weight_scale" in dir(layer):
|
if name + "_weight_scale" in dir(layer):
|
||||||
@ -236,7 +232,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
|
|||||||
|
|
||||||
# marlin kernel only supports channel-wise and group-wise quantization
|
# marlin kernel only supports channel-wise and group-wise quantization
|
||||||
# we need to convert the scales
|
# we need to convert the scales
|
||||||
if layer.weight_block_size is None:
|
if weight_block_size is None:
|
||||||
if scales.nelement() == e:
|
if scales.nelement() == e:
|
||||||
# tensor-wise quantization -> channel-wise quantization
|
# tensor-wise quantization -> channel-wise quantization
|
||||||
# (e, 1, 1) =>(repeat)=> (e, 1, size_n)
|
# (e, 1, 1) =>(repeat)=> (e, 1, size_n)
|
||||||
@ -259,7 +255,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
|
|||||||
# =>(repeat)=> (e, size_k // block_size[1], size_n)
|
# =>(repeat)=> (e, size_k // block_size[1], size_n)
|
||||||
if not size_k_first:
|
if not size_k_first:
|
||||||
scales = scales.permute(0, 2, 1)
|
scales = scales.permute(0, 2, 1)
|
||||||
block_n = layer.weight_block_size[0]
|
block_n = weight_block_size[0]
|
||||||
scales = scales.repeat_interleave(block_n, 2)
|
scales = scales.repeat_interleave(block_n, 2)
|
||||||
# size_n may not be divisible by block_size[0]
|
# size_n may not be divisible by block_size[0]
|
||||||
scales = scales[..., :size_n].contiguous()
|
scales = scales[..., :size_n].contiguous()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user