[BugFix] Fix parameter names and process_after_weight_loading for W4A16 MoE Group Act Order (#11528)

Signed-off-by: ElizaWszola <eliza@neuralmagic.com>
Co-authored-by: ElizaWszola <eliza@neuralmagic.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
Dipika Sikka 2025-01-23 16:40:33 -05:00 committed by GitHub
parent 2cbeedad09
commit eb5cb5e528
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 243 additions and 148 deletions

View File

@ -38,7 +38,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
@abstractmethod
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
raise NotImplementedError
@ -65,22 +65,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
2 * intermediate_size,
hidden_size,
dtype=params_dtype),
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(torch.empty(num_experts,
hidden_size,
intermediate_size,
dtype=params_dtype),
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
@ -289,13 +291,20 @@ class FusedMoE(torch.nn.Module):
self.quant_method = quant_config.get_quant_method(self, prefix)
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
num_experts=num_experts,
hidden_size=hidden_size,
intermediate_size=self.intermediate_size_per_partition,
params_dtype=params_dtype,
weight_loader=self.weight_loader)
moe_quant_params = {
"num_experts": num_experts,
"hidden_size": hidden_size,
"intermediate_size_per_partition":
self.intermediate_size_per_partition,
"params_dtype": params_dtype,
"weight_loader": self.weight_loader,
}
# need full intermediate size pre-sharding for WNA16 act order
if (self.quant_method.__class__.__name__ ==
"CompressedTensorsWNA16MoEMethod"):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.quant_method.create_weights(layer=self, **moe_quant_params)
def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter,
@ -312,19 +321,30 @@ class FusedMoE(torch.nn.Module):
elif shard_id == "w2":
param_data[expert_id] = loaded_weight
def _load_model_weight_or_group_weight_scale(self, shard_dim: int,
def _load_model_weight_or_group_weight_scale(self,
shard_dim: int,
expert_data: torch.Tensor,
shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int):
# Load grouped weight scales for group quantization
# or model weights
tp_rank: int,
load_full_w2: bool = False):
"""
Load grouped weight scales for group quantization or model weights
:param shard_dim: dimension to shard
:param expert_data: parameter for a particular expert
:param shard_id: either w1, w2, or w3
:param loaded_weight: checkpoint weight to load into the param
:param tp_rank: tensor parallel rank
:param load_full_w2: whether or not the w2 loaded should be sharded.
"""
if shard_id == "w2":
self._load_w2(shard_id=shard_id,
shard_dim=shard_dim,
# In the case where we have actorder/g_idx, we do not partition the
# w2 scales, as indicated by `load_full` argument, for all tp cases
self._load_w2(shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
tp_rank=tp_rank,
load_full=load_full_w2)
elif shard_id in ("w1", "w3"):
self._load_w13(shard_id=shard_id,
shard_dim=shard_dim,
@ -364,15 +384,21 @@ class FusedMoE(torch.nn.Module):
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
expert_data.copy_(loaded_weight)
def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
def _load_w2(self,
expert_data: torch.Tensor,
shard_dim: int,
loaded_weight: torch.Tensor,
tp_rank: int,
load_full: bool = False):
# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
# Narrow parameter and load.
shard_size = expert_data.shape[shard_dim]
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
shard_size)
if not load_full:
loaded_weight = loaded_weight.narrow(shard_dim,
shard_size * tp_rank,
shard_size)
# w2, down_proj: Load into only logical weight of w2.
expert_data.copy_(loaded_weight)
@ -387,8 +413,7 @@ class FusedMoE(torch.nn.Module):
shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):
if shard_id == "w2":
self._load_w2(shard_id=shard_id,
shard_dim=shard_dim,
self._load_w2(shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
@ -416,7 +441,7 @@ class FusedMoE(torch.nn.Module):
]
# Fetch the dim to shard the parameter/loaded weight
# based on the shard id. This will be whatever
# dimension intermediate_size is used.
# dimension intermediate_size_per_partition is used.
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
expert_data = param.data[expert_id]
@ -424,11 +449,11 @@ class FusedMoE(torch.nn.Module):
# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
# should be whatever dimension intermediate_size is
# should be whatever dimension intermediate_size_per_partition is
is_transposed = getattr(param, "is_transposed", False)
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
if is_transposed:
shard_dim = ~shard_dim
shard_dim = int(not shard_dim)
# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
@ -480,7 +505,8 @@ class FusedMoE(torch.nn.Module):
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
tp_rank=tp_rank,
load_full_w2=getattr(param, "load_full_w2", False))
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
self._load_per_tensor_weight_scale(shard_id=shard_id,
param=param,

View File

@ -303,7 +303,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
extra_weight_attrs.update({
"is_transposed":
@ -312,17 +312,18 @@ class AWQMoEMethod(FusedMoEMethodBase):
FusedMoeWeightScaleSupported.GROUP.value,
})
w13_qweight = Parameter(torch.empty(num_experts,
hidden_size,
2 * intermediate_size //
self.quant_config.pack_factor,
dtype=torch.int32),
requires_grad=False)
w13_qweight = Parameter(
torch.empty(num_experts,
hidden_size,
2 * intermediate_size_per_partition //
self.quant_config.pack_factor,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w13_qweight", w13_qweight)
set_weight_attrs(w13_qweight, extra_weight_attrs)
w2_qweight = Parameter(torch.empty(num_experts,
intermediate_size,
intermediate_size_per_partition,
hidden_size //
self.quant_config.pack_factor,
dtype=torch.int32),
@ -331,13 +332,14 @@ class AWQMoEMethod(FusedMoEMethodBase):
set_weight_attrs(w2_qweight, extra_weight_attrs)
num_groups_w13 = hidden_size // self.quant_config.group_size
num_groups_w2 = intermediate_size // self.quant_config.group_size
num_groups_w2 = (intermediate_size_per_partition //
self.quant_config.group_size)
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
w13_scales = Parameter(torch.empty(num_experts,
num_groups_w13,
intermediate_size * 2,
intermediate_size_per_partition * 2,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_scales", w13_scales)
@ -353,12 +355,13 @@ class AWQMoEMethod(FusedMoEMethodBase):
# WEIGHT_ZERO_POINT
# Allocate 2 zero points for w1 and w3 respectively.
w13_qzeros = Parameter(torch.empty(num_experts,
num_groups_w13,
2 * intermediate_size //
self.quant_config.pack_factor,
dtype=torch.int32),
requires_grad=False)
w13_qzeros = Parameter(
torch.empty(num_experts,
num_groups_w13,
2 * intermediate_size_per_partition //
self.quant_config.pack_factor,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w13_qzeros", w13_qzeros)
set_weight_attrs(w13_qzeros, extra_weight_attrs)

View File

@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
WNA16_SUPPORTED_BITS)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs
@ -75,24 +76,26 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
self.static_input_scales = not self.input_quant.dynamic
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
params_dtype = torch.float8_e4m3fn
# WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
2 * intermediate_size,
hidden_size,
dtype=params_dtype),
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(num_experts,
hidden_size,
intermediate_size,
dtype=params_dtype),
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
@ -254,6 +257,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
self.packed_factor = 32 // config.num_bits
self.strategy = config.strategy
self.group_size = config.group_size
self.actorder = config.actorder
assert config.symmetric, (
"Only symmetric quantization is supported for MoE")
@ -266,9 +270,16 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
f"{WNA16_SUPPORTED_BITS}")
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
assert params_dtype == torch.float16, (
"float16 is required for MoE compressed models. Set dtype=torch.float16" # noqa: E501
)
intermediate_size_full = extra_weight_attrs.pop(
"intermediate_size_full")
# Will transpose the loaded weight along the
# intermediate and hidden dim sizes. Will
# shard for TP along the transposed dims
@ -276,35 +287,45 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
"is_transposed": True,
"quant_method": self.strategy
})
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
hidden_size //
self.packed_factor,
2 * intermediate_size,
dtype=torch.int32),
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size // self.packed_factor,
2 * intermediate_size_per_partition,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w13_weight_packed", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(num_experts,
intermediate_size //
self.packed_factor,
hidden_size,
dtype=torch.int32),
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
intermediate_size_per_partition // self.packed_factor,
hidden_size,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w2_weight_packed", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# In the case where we have actorder/g_idx,
# we do not partition the w2 scales
load_full_w2 = self.actorder and self.group_size != -1
w2_scales_size = (intermediate_size_full
if load_full_w2 else intermediate_size_per_partition)
self.is_k_full = (not self.actorder) or (
intermediate_size_per_partition == intermediate_size_full)
if self.strategy == "channel":
num_groups_w2 = num_groups_w13 = 1
self.group_size = -1
else:
num_groups_w2 = intermediate_size // self.group_size
num_groups_w2 = w2_scales_size // self.group_size
num_groups_w13 = hidden_size // self.group_size
w13_scale = torch.nn.Parameter(torch.ones(num_experts,
num_groups_w13,
2 * intermediate_size,
dtype=params_dtype),
w13_scale = torch.nn.Parameter(torch.ones(
num_experts,
num_groups_w13,
2 * intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_scale)
set_weight_attrs(w13_scale, extra_weight_attrs)
@ -316,6 +337,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_scale)
set_weight_attrs(w2_scale, extra_weight_attrs)
set_weight_attrs(w2_scale, {"load_full_w2": load_full_w2})
w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
requires_grad=False)
@ -335,18 +357,18 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
),
requires_grad=False,
)
layer.register_parameter("w13_g_idx", w13_g_idx)
layer.register_parameter("w13_weight_g_idx", w13_g_idx)
set_weight_attrs(w13_g_idx, extra_weight_attrs)
w2_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size,
intermediate_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_g_idx", w2_g_idx)
layer.register_parameter("w2_weight_g_idx", w2_g_idx)
set_weight_attrs(w2_g_idx, extra_weight_attrs)
w13_g_idx_sort_indices = torch.nn.Parameter(
@ -364,7 +386,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size,
intermediate_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
@ -422,24 +444,55 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
size_k2 = layer.w2_weight_packed.shape[2]
size_k13 = layer.w13_weight_packed.shape[2]
num_experts = layer.w13_g_idx.shape[0]
device = layer.w13_g_idx.device
layer.w13_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w2_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
num_experts = layer.w13_weight_g_idx.shape[0]
device = layer.w13_weight_g_idx.device
# when running models with grouped act order,
# resort to g_idx values provided in checkpoint
if self.actorder == "group":
w13_g_idx_sort_indices = torch.empty_like(layer.w13_weight_g_idx)
w2_g_idx_sort_indices = torch.empty_like(layer.w2_weight_g_idx)
w13_sorted_g_idx = torch.empty_like(layer.w13_weight_g_idx)
w2_sorted_g_idx = torch.empty_like(layer.w2_weight_g_idx)
for e in range(num_experts):
w13_g_idx_sort_indices[e] = torch.argsort(
layer.w13_weight_g_idx[e]).to(torch.int32)
w2_g_idx_sort_indices[e] = torch.argsort(
layer.w2_weight_g_idx[e]).to(torch.int32)
w13_sorted_g_idx[e] = layer.w13_weight_g_idx[e][
w13_g_idx_sort_indices[e]]
w2_sorted_g_idx[e] = layer.w2_weight_g_idx[e][
w2_g_idx_sort_indices[e]]
replace_parameter(layer, "w13_weight_g_idx", w13_sorted_g_idx)
replace_parameter(layer, "w2_weight_g_idx", w2_sorted_g_idx)
replace_parameter(layer, "w13_g_idx_sort_indices",
w13_g_idx_sort_indices)
replace_parameter(layer, "w2_g_idx_sort_indices",
w2_g_idx_sort_indices)
else:
layer.w13_weight_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)
layer.w2_weight_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)
marlin_w13_qweight = ops.gptq_marlin_moe_repack(
layer.w13_weight_packed,
@ -511,9 +564,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
router_logits,
topk_weights,
topk_ids,
g_idx1=layer.w13_g_idx,
g_idx2=layer.w2_g_idx,
g_idx1=layer.w13_weight_g_idx,
g_idx2=layer.w2_weight_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
num_bits=self.num_bits,
)
is_k_full=self.is_k_full)

View File

@ -62,7 +62,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
**kwargs):
assert params_dtype == torch.float16, (
"float16 is required for marlin24 compressd models. Set dtype=torch.float16" # noqa: E501
"float16 is required for marlin24 compressed models. Set dtype=torch.float16" # noqa: E501
)
pack_factor = 32 // self.quant_type.size_bits

View File

@ -52,7 +52,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
int8_dtype = torch.int8
@ -64,26 +64,29 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
extra_weight_attrs['weight_loader'] = wrapped_weight_loader
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
2 * intermediate_size,
hidden_size,
dtype=int8_dtype),
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=int8_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(torch.empty(num_experts,
hidden_size,
intermediate_size,
dtype=int8_dtype),
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=int8_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
w13_scale = torch.nn.Parameter(torch.zeros(num_experts,
2 * intermediate_size,
dtype=torch.float32),
w13_scale = torch.nn.Parameter(torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_scale", w13_scale)

View File

@ -386,8 +386,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.block_quant = self.quant_config.weight_block_size is not None
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
@ -402,30 +402,34 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# scales, the output_size of the weights for both the gate and up
# layers must be divisible by block_n.
# Required by column parallel or enabling merged weights
if intermediate_size % block_n != 0:
if intermediate_size_per_partition % block_n != 0:
raise ValueError(
f"The output_size of gate's and up's weight = "
f"{intermediate_size} is not divisible by "
f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_n = {block_n}.")
if (tp_size > 1 and intermediate_size % block_k != 0):
if (tp_size > 1
and intermediate_size_per_partition % block_k != 0):
# Required by row parallel
raise ValueError(f"The input_size of down's weight = "
f"{intermediate_size} is not divisible by "
f"weight quantization block_k = {block_k}.")
raise ValueError(
f"The input_size of down's weight = "
f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_k = {block_k}.")
# WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
2 * intermediate_size,
hidden_size,
dtype=params_dtype),
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(num_experts,
hidden_size,
intermediate_size,
dtype=params_dtype),
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
@ -446,7 +450,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
2 * ((intermediate_size + block_n - 1) // block_n),
2 * ((intermediate_size_per_partition + block_n - 1) //
block_n),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32,
),
@ -456,7 +461,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
torch.ones(
num_experts,
(hidden_size + block_n - 1) // block_n,
(intermediate_size + block_k - 1) // block_k,
(intermediate_size_per_partition + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,

View File

@ -317,7 +317,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
@ -326,7 +326,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
# Supports only sym for now (no zp)
if self.quant_config.group_size != -1:
scales_size13 = hidden_size // self.quant_config.group_size
scales_size2 = intermediate_size // self.quant_config.group_size
scales_size2 = (intermediate_size_per_partition //
self.quant_config.group_size)
strategy = FusedMoeWeightScaleSupported.GROUP.value
else:
scales_size13 = 1
@ -342,7 +343,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
torch.empty(
num_experts,
hidden_size // self.quant_config.pack_factor,
2 * intermediate_size,
2 * intermediate_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
@ -353,7 +354,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
w2_qweight = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size // self.quant_config.pack_factor,
intermediate_size_per_partition //
self.quant_config.pack_factor,
hidden_size,
dtype=torch.int32,
),
@ -365,7 +367,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
w13_scales = torch.nn.Parameter(
torch.empty(num_experts,
scales_size13,
2 * intermediate_size,
2 * intermediate_size_per_partition,
dtype=torch.half),
requires_grad=False,
)
@ -385,7 +387,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
w13_qzeros = torch.nn.Parameter(
torch.empty(num_experts,
scales_size13,
2 * intermediate_size // self.quant_config.pack_factor,
2 * intermediate_size_per_partition //
self.quant_config.pack_factor,
dtype=params_dtype),
requires_grad=False,
)
@ -414,7 +417,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
w2_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size,
intermediate_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
@ -435,7 +438,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size,
intermediate_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,

View File

@ -60,24 +60,26 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
self.static_input_scales = not self.input_quant.get("is_dynamic")
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
params_dtype = torch.float8_e4m3fn
# WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
2 * intermediate_size,
hidden_size,
dtype=params_dtype),
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(num_experts,
hidden_size,
intermediate_size,
dtype=params_dtype),
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)