[Misc] Update w2 scale loading for GPTQMarlinMoE (#12757)
commit 7ca9934fe7
parent 0408efc6d0
@@ -1,5 +1,7 @@
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
+compressed-tensors, nm-testing/test-w4a16-mixtral-actorder-group, main
 gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
+gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, gptq-8bit-128g-actorder_True
 awq_marlin, casperhansen/deepseek-coder-v2-instruct-awq, main
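The two new entries add act-order checkpoints to the weight-loading test matrix: a compressed-tensors W4A16 Mixtral variant quantized with grouped act order, and TheBloke's GPTQ Mixtral at the gptq-8bit-128g-actorder_True revision.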
@@ -302,8 +302,8 @@ class FusedMoE(torch.nn.Module):
             "weight_loader": self.weight_loader,
         }
         # need full intermediate size pre-sharding for WNA16 act order
-        if (self.quant_method.__class__.__name__ ==
-                "CompressedTensorsWNA16MoEMethod"):
+        if (self.quant_method.__class__.__name__
+                in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
             moe_quant_params["intermediate_size_full"] = intermediate_size

         self.quant_method.create_weights(layer=self, **moe_quant_params)
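Why the full, pre-sharding intermediate size matters here: with GPTQ act order (desc_act), w2's input dimension is permuted globally, so per-group scales must be sized against the full intermediate size rather than one rank's shard. A minimal standalone sketch of the arithmetic (all numbers are illustrative assumptions, not values from this commit):

    # Illustrative only: group-scale counts for w2 with and without
    # tensor-parallel sharding of the intermediate size.
    intermediate_size = 14336   # full intermediate size (assumed)
    tp_size = 4                 # tensor-parallel degree (assumed)
    group_size = 128

    per_partition = intermediate_size // tp_size    # 3584 rows on this rank
    sharded_groups = per_partition // group_size    # 28 groups: too few under act order
    full_groups = intermediate_size // group_size   # 112 groups: what desc_act indexes into
    print(sharded_groups, full_groups)              # 28 112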
@@ -323,13 +323,18 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
-        # Currently assuming is_k_full is always True
-        # (input size per partition is the same as full input size)
-        # Supports only sym for now (no zp)
+        intermediate_size_full = extra_weight_attrs.pop(
+            "intermediate_size_full")
+
+        self.is_k_full = (not self.quant_config.desc_act) or (
+            intermediate_size_per_partition == intermediate_size_full)
+
         if self.quant_config.group_size != -1:
             scales_size13 = hidden_size // self.quant_config.group_size
-            scales_size2 = (intermediate_size_per_partition //
-                            self.quant_config.group_size)
+            w2_scales_size = (intermediate_size_full
+                              if self.quant_config.desc_act else
+                              intermediate_size_per_partition)
+            scales_size2 = (w2_scales_size // self.quant_config.group_size)
             strategy = FusedMoeWeightScaleSupported.GROUP.value
         else:
             scales_size13 = 1
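A self-contained sketch of the two rules this hunk introduces; w2_scale_groups is a hypothetical helper and the sizes are illustrative:

    def w2_scale_groups(desc_act: bool, group_size: int,
                        intermediate_size_per_partition: int,
                        intermediate_size_full: int) -> int:
        # Mirrors the hunk above: under act order the w2 scales span the
        # full intermediate size; otherwise only this rank's shard.
        w2_scales_size = (intermediate_size_full
                          if desc_act else intermediate_size_per_partition)
        return w2_scales_size // group_size

    # is_k_full: K is "full" when there is no act order, or when this rank's
    # shard already equals the full intermediate size (no TP split).
    desc_act, per_partition, full = True, 3584, 14336
    is_k_full = (not desc_act) or (per_partition == full)
    print(w2_scale_groups(desc_act, 128, per_partition, full), is_k_full)  # 112 False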
@@ -385,6 +390,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         )
         layer.register_parameter("w2_scales", w2_scales)
         set_weight_attrs(w2_scales, extra_weight_attrs)
+        # dont shard the w2 scales when running act order
+        set_weight_attrs(w2_scales,
+                         {"load_full_w2": self.quant_config.desc_act})
         # up_proj scales
         w13_qzeros = torch.nn.Parameter(
             torch.empty(num_experts,
@@ -406,6 +414,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         )
         layer.register_parameter("w2_qzeros", w2_qzeros)
         set_weight_attrs(w2_qzeros, extra_weight_attrs)
+        # dont shard the w2 scales when running act order
+        set_weight_attrs(w2_qzeros,
+                         {"load_full_w2": self.quant_config.desc_act})
         w13_g_idx = torch.nn.Parameter(
             torch.empty(
                 num_experts,
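Both w2_scales and w2_qzeros are tagged with load_full_w2 so the weight loader keeps them unsharded under act order. A simplified sketch of the idea (hypothetical function, not vLLM's actual weight_loader; sharding along dim 0 for brevity):

    import torch

    def load_w2_scale(param: torch.nn.Parameter, loaded: torch.Tensor,
                      shard_id: int, num_shards: int) -> None:
        # If create_weights tagged the parameter with load_full_w2=True
        # (act order), keep the full checkpoint tensor; otherwise copy
        # only this rank's slice of the sharded dimension.
        if getattr(param, "load_full_w2", False):
            param.data.copy_(loaded)
        else:
            shard = loaded.shape[0] // num_shards
            param.data.copy_(loaded.narrow(0, shard_id * shard, shard))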
@@ -575,4 +586,4 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             sort_indices1=layer.w13_g_idx_sort_indices,
             sort_indices2=layer.w2_g_idx_sort_indices,
             num_bits=self.quant_config.quant_type.size_bits,
-        ).to(orig_dtype)
+            is_k_full=self.is_k_full).to(orig_dtype)
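The flag computed in create_weights is now threaded through to the fused_marlin_moe call, replacing the removed hard-coded assumption ("Currently assuming is_k_full is always True"): when act order is combined with a tensor-parallel split of the intermediate size, each rank's reduction dimension is only a slice of K, and the kernel is told so explicitly.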