Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2025-12-15 10:57:10 +08:00

[BugFix] Fix parameter names and process_after_weight_loading for W4A16 MoE Group Act Order (#11528)

Signed-off-by: ElizaWszola <eliza@neuralmagic.com>
Co-authored-by: ElizaWszola <eliza@neuralmagic.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>

Parent: 2cbeedad09
Commit: eb5cb5e528
@@ -38,7 +38,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):

     @abstractmethod
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
         raise NotImplementedError

@@ -65,22 +65,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""

     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
         # Fused gate_up_proj (column parallel)
-        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                    2 * intermediate_size,
-                                                    hidden_size,
-                                                    dtype=params_dtype),
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            hidden_size,
+            dtype=params_dtype),
                                         requires_grad=False)
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)

         # down_proj (row parallel)
-        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                   hidden_size,
-                                                   intermediate_size,
-                                                   dtype=params_dtype),
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            dtype=params_dtype),
                                        requires_grad=False)
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
@@ -289,13 +291,20 @@ class FusedMoE(torch.nn.Module):
         self.quant_method = quant_config.get_quant_method(self, prefix)
         assert self.quant_method is not None

-        self.quant_method.create_weights(
-            layer=self,
-            num_experts=num_experts,
-            hidden_size=hidden_size,
-            intermediate_size=self.intermediate_size_per_partition,
-            params_dtype=params_dtype,
-            weight_loader=self.weight_loader)
+        moe_quant_params = {
+            "num_experts": num_experts,
+            "hidden_size": hidden_size,
+            "intermediate_size_per_partition":
+            self.intermediate_size_per_partition,
+            "params_dtype": params_dtype,
+            "weight_loader": self.weight_loader,
+        }
+        # need full intermediate size pre-sharding for WNA16 act order
+        if (self.quant_method.__class__.__name__ ==
+                "CompressedTensorsWNA16MoEMethod"):
+            moe_quant_params["intermediate_size_full"] = intermediate_size
+
+        self.quant_method.create_weights(layer=self, **moe_quant_params)

     def _load_per_tensor_weight_scale(self, shard_id: str,
                                       param: torch.nn.Parameter,
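Note (illustration only, not part of the diff): every quant method now receives the per-partition intermediate size, and only the WNA16 method additionally gets the unsharded size. A minimal sketch with assumed sizes and an assumed method name variable:

    # Hypothetical numbers for illustration: full FFN size 5632 split across
    # tp_size=2 ranks; only the WNA16 method also receives the full size.
    intermediate_size = 5632
    tp_size = 2
    intermediate_size_per_partition = intermediate_size // tp_size  # 2816

    moe_quant_params = {
        "num_experts": 8,
        "hidden_size": 2048,
        "intermediate_size_per_partition": intermediate_size_per_partition,
    }
    quant_method_name = "CompressedTensorsWNA16MoEMethod"  # assumed for the example
    if quant_method_name == "CompressedTensorsWNA16MoEMethod":
        moe_quant_params["intermediate_size_full"] = intermediate_size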
@@ -312,19 +321,30 @@ class FusedMoE(torch.nn.Module):
         elif shard_id == "w2":
             param_data[expert_id] = loaded_weight

-    def _load_model_weight_or_group_weight_scale(self, shard_dim: int,
+    def _load_model_weight_or_group_weight_scale(self,
+                                                 shard_dim: int,
                                                  expert_data: torch.Tensor,
                                                  shard_id: str,
                                                  loaded_weight: torch.Tensor,
-                                                 tp_rank: int):
-        # Load grouped weight scales for group quantization
-        # or model weights
+                                                 tp_rank: int,
+                                                 load_full_w2: bool = False):
+        """
+        Load grouped weight scales for group quantization or model weights
+            :param shard_dim: dimension to shard
+            :param expert_data: parameter for a particular expert
+            :param shard_id: either w1, w2, or w3
+            :param loaded_weight: checkpoint weight to load into the param
+            :param tp_rank: tensor parallel rank
+            :param load_full_w2: whether or not the w2 loaded should be sharded.
+        """
         if shard_id == "w2":
-            self._load_w2(shard_id=shard_id,
-                          shard_dim=shard_dim,
+            # In the case where we have actorder/g_idx, we do not partition the
+            # w2 scales, as indicated by `load_full` argument, for all tp cases
+            self._load_w2(shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
-                          tp_rank=tp_rank)
+                          tp_rank=tp_rank,
+                          load_full=load_full_w2)
         elif shard_id in ("w1", "w3"):
             self._load_w13(shard_id=shard_id,
                            shard_dim=shard_dim,
@@ -364,15 +384,21 @@ class FusedMoE(torch.nn.Module):
             expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
         expert_data.copy_(loaded_weight)

-    def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
-                 shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
+    def _load_w2(self,
+                 expert_data: torch.Tensor,
+                 shard_dim: int,
+                 loaded_weight: torch.Tensor,
+                 tp_rank: int,
+                 load_full: bool = False):

         # Index the loaded weight for tp sharding.
         # down_proj: "RowParallel" so tp sharding on input_dim
         # Narrow parameter and load.
         shard_size = expert_data.shape[shard_dim]
-        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
-                                             shard_size)
+        if not load_full:
+            loaded_weight = loaded_weight.narrow(shard_dim,
+                                                 shard_size * tp_rank,
+                                                 shard_size)
         # w2, down_proj: Load into only logical weight of w2.
         expert_data.copy_(loaded_weight)
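A standalone sketch (not from the patch) of the two loading paths in `_load_w2`, with toy sizes chosen only for illustration:

    import torch

    # Toy sizes (assumptions): full w2 scale has 8 rows, tp_size=2, so each
    # rank's parameter slice holds 4 rows; rank 1 loads rows 4..7.
    full_w2_scale = torch.arange(8 * 3, dtype=torch.float32).reshape(8, 3)
    shard_dim, tp_rank, shard_size = 0, 1, 4

    # load_full=False (ordinary weights/scales): narrow to this rank's shard.
    expert_data = torch.empty(shard_size, 3)
    expert_data.copy_(full_w2_scale.narrow(shard_dim, shard_size * tp_rank,
                                           shard_size))

    # load_full=True (act-order w2 scales): the parameter was allocated at full
    # size, so the checkpoint tensor is copied without narrowing.
    expert_data_full = torch.empty(8, 3)
    expert_data_full.copy_(full_w2_scale)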
@@ -387,8 +413,7 @@ class FusedMoE(torch.nn.Module):
                    shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):

         if shard_id == "w2":
-            self._load_w2(shard_id=shard_id,
-                          shard_dim=shard_dim,
+            self._load_w2(shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
                           tp_rank=tp_rank)
@@ -416,7 +441,7 @@ class FusedMoE(torch.nn.Module):
         ]
         # Fetch the dim to shard the parameter/loaded weight
         # based on the shard id. This will be whatever
-        # dimension intermediate_size is used.
+        # dimension intermediate_size_per_partition is used.
         SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

         expert_data = param.data[expert_id]
@@ -424,11 +449,11 @@ class FusedMoE(torch.nn.Module):

         # is_transposed: if the dim to shard the weight
         # should be flipped. Required by GPTQ, compressed-tensors
-        # should be whatever dimension intermediate_size is
+        # should be whatever dimension intermediate_size_per_partition is
         is_transposed = getattr(param, "is_transposed", False)
         shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
         if is_transposed:
-            shard_dim = ~shard_dim
+            shard_dim = int(not shard_dim)

         # Case input scale: input_scale loading is only supported for fp8
         if "input_scale" in weight_name:
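A side note on the `shard_dim` flip above (an observation, not part of the diff): bitwise NOT turns 0/1 into the negative indices -1/-2, which only selects the intended axis by accident on 2-D per-expert tensors, while the boolean flip is explicit:

    for shard_dim in (0, 1):
        print(~shard_dim, int(not shard_dim))  # prints: -1 1, then -2 0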
@@ -480,7 +505,8 @@ class FusedMoE(torch.nn.Module):
                     shard_dim=shard_dim,
                     loaded_weight=loaded_weight,
                     expert_data=expert_data,
-                    tp_rank=tp_rank)
+                    tp_rank=tp_rank,
+                    load_full_w2=getattr(param, "load_full_w2", False))
             elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                 self._load_per_tensor_weight_scale(shard_id=shard_id,
                                                    param=param,
@@ -303,7 +303,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
         self.quant_config = quant_config

     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
         extra_weight_attrs.update({
             "is_transposed":
@@ -312,17 +312,18 @@ class AWQMoEMethod(FusedMoEMethodBase):
             FusedMoeWeightScaleSupported.GROUP.value,
         })

-        w13_qweight = Parameter(torch.empty(num_experts,
-                                            hidden_size,
-                                            2 * intermediate_size //
-                                            self.quant_config.pack_factor,
-                                            dtype=torch.int32),
-                                requires_grad=False)
+        w13_qweight = Parameter(
+            torch.empty(num_experts,
+                        hidden_size,
+                        2 * intermediate_size_per_partition //
+                        self.quant_config.pack_factor,
+                        dtype=torch.int32),
+            requires_grad=False)
         layer.register_parameter("w13_qweight", w13_qweight)
         set_weight_attrs(w13_qweight, extra_weight_attrs)

         w2_qweight = Parameter(torch.empty(num_experts,
-                                           intermediate_size,
+                                           intermediate_size_per_partition,
                                            hidden_size //
                                            self.quant_config.pack_factor,
                                            dtype=torch.int32),
@@ -331,13 +332,14 @@ class AWQMoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w2_qweight, extra_weight_attrs)

         num_groups_w13 = hidden_size // self.quant_config.group_size
-        num_groups_w2 = intermediate_size // self.quant_config.group_size
+        num_groups_w2 = (intermediate_size_per_partition //
+                         self.quant_config.group_size)

         # WEIGHT_SCALES
         # Allocate 2 scales for w1 and w3 respectively.
         w13_scales = Parameter(torch.empty(num_experts,
                                            num_groups_w13,
-                                           intermediate_size * 2,
+                                           intermediate_size_per_partition * 2,
                                            dtype=params_dtype),
                                requires_grad=False)
         layer.register_parameter("w13_scales", w13_scales)
@@ -353,12 +355,13 @@ class AWQMoEMethod(FusedMoEMethodBase):

         # WEIGHT_ZERO_POINT
         # Allocate 2 zero points for w1 and w3 respectively.
-        w13_qzeros = Parameter(torch.empty(num_experts,
-                                           num_groups_w13,
-                                           2 * intermediate_size //
-                                           self.quant_config.pack_factor,
-                                           dtype=torch.int32),
-                               requires_grad=False)
+        w13_qzeros = Parameter(
+            torch.empty(num_experts,
+                        num_groups_w13,
+                        2 * intermediate_size_per_partition //
+                        self.quant_config.pack_factor,
+                        dtype=torch.int32),
+            requires_grad=False)
         layer.register_parameter("w13_qzeros", w13_qzeros)
         set_weight_attrs(w13_qzeros, extra_weight_attrs)
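For orientation, a toy calculation (numbers are assumptions, not from the patch) of how the packed column counts and group counts above scale with the per-partition intermediate size:

    # Toy numbers (assumptions): 4-bit AWQ, group_size 128, hidden_size 2048,
    # per-partition intermediate size 1408 (e.g. 2816 split across tp_size=2).
    num_bits, group_size = 4, 128
    hidden_size, intermediate_size_per_partition = 2048, 1408

    pack_factor = 32 // num_bits                                           # 8
    w13_qweight_cols = 2 * intermediate_size_per_partition // pack_factor  # 352
    num_groups_w13 = hidden_size // group_size                             # 16
    num_groups_w2 = intermediate_size_per_partition // group_size          # 11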
@@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     WNA16_SUPPORTED_BITS)
+from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
 from vllm.model_executor.utils import set_weight_attrs
@@ -75,24 +76,26 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         self.static_input_scales = not self.input_quant.dynamic

     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):

         params_dtype = torch.float8_e4m3fn

         # WEIGHTS
-        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                    2 * intermediate_size,
-                                                    hidden_size,
-                                                    dtype=params_dtype),
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            hidden_size,
+            dtype=params_dtype),
                                         requires_grad=False)
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)

-        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                   hidden_size,
-                                                   intermediate_size,
-                                                   dtype=params_dtype),
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            dtype=params_dtype),
                                        requires_grad=False)
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
@@ -254,6 +257,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         self.packed_factor = 32 // config.num_bits
         self.strategy = config.strategy
         self.group_size = config.group_size
+        self.actorder = config.actorder
         assert config.symmetric, (
             "Only symmetric quantization is supported for MoE")

@@ -266,9 +270,16 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
                 f"{WNA16_SUPPORTED_BITS}")

     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):

         assert params_dtype == torch.float16, (
             "float16 is required for MoE compressed models. Set dtype=torch.float16"  # noqa: E501
         )

+        intermediate_size_full = extra_weight_attrs.pop(
+            "intermediate_size_full")
+
         # Will transpose the loaded weight along the
         # intermediate and hidden dim sizes. Will
         # shard for TP along the transposed dims
@@ -276,35 +287,45 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             "is_transposed": True,
             "quant_method": self.strategy
         })
-        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                    hidden_size //
-                                                    self.packed_factor,
-                                                    2 * intermediate_size,
-                                                    dtype=torch.int32),
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size // self.packed_factor,
+            2 * intermediate_size_per_partition,
+            dtype=torch.int32),
                                         requires_grad=False)
         layer.register_parameter("w13_weight_packed", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)

-        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                   intermediate_size //
-                                                   self.packed_factor,
-                                                   hidden_size,
-                                                   dtype=torch.int32),
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            intermediate_size_per_partition // self.packed_factor,
+            hidden_size,
+            dtype=torch.int32),
                                        requires_grad=False)
         layer.register_parameter("w2_weight_packed", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)

+        # In the case where we have actorder/g_idx,
+        # we do not partition the w2 scales
+        load_full_w2 = self.actorder and self.group_size != -1
+        w2_scales_size = (intermediate_size_full
+                          if load_full_w2 else intermediate_size_per_partition)
+
+        self.is_k_full = (not self.actorder) or (
+            intermediate_size_per_partition == intermediate_size_full)
+
         if self.strategy == "channel":
             num_groups_w2 = num_groups_w13 = 1
             self.group_size = -1
         else:
-            num_groups_w2 = intermediate_size // self.group_size
+            num_groups_w2 = w2_scales_size // self.group_size
             num_groups_w13 = hidden_size // self.group_size

-        w13_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                  num_groups_w13,
-                                                  2 * intermediate_size,
-                                                  dtype=params_dtype),
+        w13_scale = torch.nn.Parameter(torch.ones(
+            num_experts,
+            num_groups_w13,
+            2 * intermediate_size_per_partition,
+            dtype=params_dtype),
                                        requires_grad=False)
         layer.register_parameter("w13_weight_scale", w13_scale)
         set_weight_attrs(w13_scale, extra_weight_attrs)
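For clarity, a worked example (all sizes assumed, not taken from the patch) of the new act-order bookkeeping above:

    # Toy numbers (assumptions): grouped act-order checkpoint, group_size 128,
    # full intermediate size 2816, tp_size 2.
    actorder, group_size = "group", 128
    intermediate_size_full = 2816
    intermediate_size_per_partition = intermediate_size_full // 2    # 1408

    load_full_w2 = bool(actorder) and group_size != -1               # True
    w2_scales_size = (intermediate_size_full
                      if load_full_w2 else intermediate_size_per_partition)  # 2816
    num_groups_w2 = w2_scales_size // group_size                     # 22
    is_k_full = (not actorder) or (
        intermediate_size_per_partition == intermediate_size_full)   # False with TP > 1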
@@ -316,6 +337,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
                                       requires_grad=False)
         layer.register_parameter("w2_weight_scale", w2_scale)
         set_weight_attrs(w2_scale, extra_weight_attrs)
+        set_weight_attrs(w2_scale, {"load_full_w2": load_full_w2})

         w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
                                              requires_grad=False)
@@ -335,18 +357,18 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             ),
             requires_grad=False,
         )
-        layer.register_parameter("w13_g_idx", w13_g_idx)
+        layer.register_parameter("w13_weight_g_idx", w13_g_idx)
         set_weight_attrs(w13_g_idx, extra_weight_attrs)

         w2_g_idx = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-                intermediate_size,
+                intermediate_size_per_partition,
                 dtype=torch.int32,
             ),
             requires_grad=False,
         )
-        layer.register_parameter("w2_g_idx", w2_g_idx)
+        layer.register_parameter("w2_weight_g_idx", w2_g_idx)
         set_weight_attrs(w2_g_idx, extra_weight_attrs)

         w13_g_idx_sort_indices = torch.nn.Parameter(
@@ -364,7 +386,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         w2_g_idx_sort_indices = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-                intermediate_size,
+                intermediate_size_per_partition,
                 dtype=torch.int32,
             ),
             requires_grad=False,
@@ -422,24 +444,55 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         size_k2 = layer.w2_weight_packed.shape[2]
         size_k13 = layer.w13_weight_packed.shape[2]

-        num_experts = layer.w13_g_idx.shape[0]
-        device = layer.w13_g_idx.device
-        layer.w13_g_idx = torch.nn.Parameter(
-            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
-            requires_grad=False,
-        )
-        layer.w2_g_idx = torch.nn.Parameter(
-            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
-            requires_grad=False,
-        )
-        layer.w13_g_idx_sort_indices = torch.nn.Parameter(
-            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
-            requires_grad=False,
-        )
-        layer.w2_g_idx_sort_indices = torch.nn.Parameter(
-            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
-            requires_grad=False,
-        )
+        num_experts = layer.w13_weight_g_idx.shape[0]
+        device = layer.w13_weight_g_idx.device
+
+        # when running models with grouped act order,
+        # resort to g_idx values provided in checkpoint
+        if self.actorder == "group":
+            w13_g_idx_sort_indices = torch.empty_like(layer.w13_weight_g_idx)
+            w2_g_idx_sort_indices = torch.empty_like(layer.w2_weight_g_idx)
+            w13_sorted_g_idx = torch.empty_like(layer.w13_weight_g_idx)
+            w2_sorted_g_idx = torch.empty_like(layer.w2_weight_g_idx)
+
+            for e in range(num_experts):
+                w13_g_idx_sort_indices[e] = torch.argsort(
+                    layer.w13_weight_g_idx[e]).to(torch.int32)
+                w2_g_idx_sort_indices[e] = torch.argsort(
+                    layer.w2_weight_g_idx[e]).to(torch.int32)
+                w13_sorted_g_idx[e] = layer.w13_weight_g_idx[e][
+                    w13_g_idx_sort_indices[e]]
+                w2_sorted_g_idx[e] = layer.w2_weight_g_idx[e][
+                    w2_g_idx_sort_indices[e]]
+
+            replace_parameter(layer, "w13_weight_g_idx", w13_sorted_g_idx)
+            replace_parameter(layer, "w2_weight_g_idx", w2_sorted_g_idx)
+            replace_parameter(layer, "w13_g_idx_sort_indices",
+                              w13_g_idx_sort_indices)
+            replace_parameter(layer, "w2_g_idx_sort_indices",
+                              w2_g_idx_sort_indices)
+
+        else:
+            layer.w13_weight_g_idx = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32,
+                            device=device),
+                requires_grad=False,
+            )
+            layer.w2_weight_g_idx = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32,
+                            device=device),
+                requires_grad=False,
+            )
+            layer.w13_g_idx_sort_indices = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32,
+                            device=device),
+                requires_grad=False,
+            )
+            layer.w2_g_idx_sort_indices = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32,
+                            device=device),
+                requires_grad=False,
+            )

         marlin_w13_qweight = ops.gptq_marlin_moe_repack(
             layer.w13_weight_packed,
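A standalone sketch (not part of the diff) of the per-expert g_idx re-sorting used above, on assumed toy values:

    import torch

    # Toy g_idx for 2 experts and 8 input channels with 2 groups (assumed values):
    g_idx = torch.tensor([[1, 0, 1, 0, 0, 1, 0, 1],
                          [0, 1, 0, 1, 1, 0, 1, 0]], dtype=torch.int32)

    sort_indices = torch.empty_like(g_idx)
    sorted_g_idx = torch.empty_like(g_idx)
    for e in range(g_idx.shape[0]):
        order = torch.argsort(g_idx[e])       # per-expert channel permutation
        sort_indices[e] = order.to(torch.int32)
        sorted_g_idx[e] = g_idx[e][order]     # rows become non-decreasing
    # sorted_g_idx and sort_indices are what get handed to the fused Marlin
    # kernel as g_idx and sort_indices in the call below.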
@@ -511,9 +564,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             router_logits,
             topk_weights,
             topk_ids,
-            g_idx1=layer.w13_g_idx,
-            g_idx2=layer.w2_g_idx,
+            g_idx1=layer.w13_weight_g_idx,
+            g_idx2=layer.w2_weight_g_idx,
             sort_indices1=layer.w13_g_idx_sort_indices,
             sort_indices2=layer.w2_g_idx_sort_indices,
             num_bits=self.num_bits,
-        )
+            is_k_full=self.is_k_full)
@@ -62,7 +62,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
                        **kwargs):

         assert params_dtype == torch.float16, (
-            "float16 is required for marlin24 compressd models. Set dtype=torch.float16"  # noqa: E501
+            "float16 is required for marlin24 compressed models. Set dtype=torch.float16"  # noqa: E501
         )

         pack_factor = 32 // self.quant_type.size_bits

@@ -52,7 +52,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
         self.quant_config = quant_config

     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):

         int8_dtype = torch.int8
@@ -64,26 +64,29 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
         extra_weight_attrs['weight_loader'] = wrapped_weight_loader

         # Fused gate_up_proj (column parallel)
-        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                    2 * intermediate_size,
-                                                    hidden_size,
-                                                    dtype=int8_dtype),
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            hidden_size,
+            dtype=int8_dtype),
                                         requires_grad=False)
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)

         # down_proj (row parallel)
-        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                   hidden_size,
-                                                   intermediate_size,
-                                                   dtype=int8_dtype),
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            dtype=int8_dtype),
                                        requires_grad=False)
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)

-        w13_scale = torch.nn.Parameter(torch.zeros(num_experts,
-                                                   2 * intermediate_size,
-                                                   dtype=torch.float32),
+        w13_scale = torch.nn.Parameter(torch.zeros(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            dtype=torch.float32),
                                        requires_grad=False)
         layer.register_parameter("w13_scale", w13_scale)
@@ -386,8 +386,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self.block_quant = self.quant_config.weight_block_size is not None

     def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
-                       intermediate_size: int, params_dtype: torch.dtype,
-                       **extra_weight_attrs):
+                       intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):

         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
@@ -402,30 +402,34 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             # scales, the output_size of the weights for both the gate and up
             # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
-            if intermediate_size % block_n != 0:
+            if intermediate_size_per_partition % block_n != 0:
                 raise ValueError(
                     f"The output_size of gate's and up's weight = "
-                    f"{intermediate_size} is not divisible by "
+                    f"{intermediate_size_per_partition} is not divisible by "
                     f"weight quantization block_n = {block_n}.")
-            if (tp_size > 1 and intermediate_size % block_k != 0):
+            if (tp_size > 1
+                    and intermediate_size_per_partition % block_k != 0):
                 # Required by row parallel
-                raise ValueError(f"The input_size of down's weight = "
-                                 f"{intermediate_size} is not divisible by "
-                                 f"weight quantization block_k = {block_k}.")
+                raise ValueError(
+                    f"The input_size of down's weight = "
+                    f"{intermediate_size_per_partition} is not divisible by "
+                    f"weight quantization block_k = {block_k}.")

         # WEIGHTS
-        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                    2 * intermediate_size,
-                                                    hidden_size,
-                                                    dtype=params_dtype),
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            hidden_size,
+            dtype=params_dtype),
                                         requires_grad=False)
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)

-        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                   hidden_size,
-                                                   intermediate_size,
-                                                   dtype=params_dtype),
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            dtype=params_dtype),
                                        requires_grad=False)
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
@@ -446,7 +450,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             w13_weight_scale = torch.nn.Parameter(
                 torch.ones(
                     num_experts,
-                    2 * ((intermediate_size + block_n - 1) // block_n),
+                    2 * ((intermediate_size_per_partition + block_n - 1) //
+                         block_n),
                     (hidden_size + block_k - 1) // block_k,
                     dtype=torch.float32,
                 ),
@@ -456,7 +461,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 torch.ones(
                     num_experts,
                     (hidden_size + block_n - 1) // block_n,
-                    (intermediate_size + block_k - 1) // block_k,
+                    (intermediate_size_per_partition + block_k - 1) // block_k,
                     dtype=torch.float32,
                 ),
                 requires_grad=False,
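For reference, a toy calculation (assumed sizes, not from the patch) of the block-wise scale shapes allocated above via ceil division:

    # Toy numbers (assumptions): block_n = block_k = 128, hidden_size 4096,
    # per-partition intermediate size 1536.
    block_n = block_k = 128
    hidden_size, intermediate_size_per_partition = 4096, 1536

    # Ceil-division gives the number of quantization blocks per dimension.
    w13_scale_rows = 2 * ((intermediate_size_per_partition + block_n - 1) // block_n)  # 24
    w13_scale_cols = (hidden_size + block_k - 1) // block_k                            # 32
    w2_scale_rows = (hidden_size + block_n - 1) // block_n                             # 32
    w2_scale_cols = (intermediate_size_per_partition + block_k - 1) // block_k         # 12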
@@ -317,7 +317,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -326,7 +326,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         # Supports only sym for now (no zp)
         if self.quant_config.group_size != -1:
             scales_size13 = hidden_size // self.quant_config.group_size
-            scales_size2 = intermediate_size // self.quant_config.group_size
+            scales_size2 = (intermediate_size_per_partition //
+                            self.quant_config.group_size)
             strategy = FusedMoeWeightScaleSupported.GROUP.value
         else:
             scales_size13 = 1
@@ -342,7 +343,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             torch.empty(
                 num_experts,
                 hidden_size // self.quant_config.pack_factor,
-                2 * intermediate_size,
+                2 * intermediate_size_per_partition,
                 dtype=torch.int32,
             ),
             requires_grad=False,
@@ -353,7 +354,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         w2_qweight = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-                intermediate_size // self.quant_config.pack_factor,
+                intermediate_size_per_partition //
+                self.quant_config.pack_factor,
                 hidden_size,
                 dtype=torch.int32,
             ),
@@ -365,7 +367,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         w13_scales = torch.nn.Parameter(
             torch.empty(num_experts,
                         scales_size13,
-                        2 * intermediate_size,
+                        2 * intermediate_size_per_partition,
                         dtype=torch.half),
             requires_grad=False,
         )
@@ -385,7 +387,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         w13_qzeros = torch.nn.Parameter(
             torch.empty(num_experts,
                         scales_size13,
-                        2 * intermediate_size // self.quant_config.pack_factor,
+                        2 * intermediate_size_per_partition //
+                        self.quant_config.pack_factor,
                         dtype=params_dtype),
             requires_grad=False,
         )
@@ -414,7 +417,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         w2_g_idx = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-                intermediate_size,
+                intermediate_size_per_partition,
                 dtype=torch.int32,
             ),
             requires_grad=False,
@@ -435,7 +438,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         w2_g_idx_sort_indices = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-                intermediate_size,
+                intermediate_size_per_partition,
                 dtype=torch.int32,
             ),
             requires_grad=False,
@@ -60,24 +60,26 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
         self.static_input_scales = not self.input_quant.get("is_dynamic")

     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):

         params_dtype = torch.float8_e4m3fn

         # WEIGHTS
-        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                    2 * intermediate_size,
-                                                    hidden_size,
-                                                    dtype=params_dtype),
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            hidden_size,
+            dtype=params_dtype),
                                         requires_grad=False)
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)

-        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                   hidden_size,
-                                                   intermediate_size,
-                                                   dtype=params_dtype),
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            dtype=params_dtype),
                                        requires_grad=False)
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)