[Bugfix]: Fix glm46 awq marlin moe wna16 compatibility (#30210)

Signed-off-by: baonudesifeizhai <baonudesifeizhai@gmail.com>
Dongjie Zou 2025-12-09 07:20:21 -05:00 committed by GitHub
parent 03416eada6
commit 1166c31cc7
2 changed files with 50 additions and 4 deletions

View File

@@ -895,6 +895,48 @@ def get_moe_configs(
     return None
 
 
+def _ensure_block_size_k_divisible(
+    size_k: int, block_size_k: int, group_size: int
+) -> int:
+    """Ensure block_size_k is a divisor of size_k and divisible by group_size.
+
+    This keeps BLOCK_SIZE_K compatible with the MoeWNA16 CUDA kernel, which
+    requires size_k % BLOCK_SIZE_K == 0 and BLOCK_SIZE_K % group_size == 0.
+
+    Args:
+        size_k: The size_k dimension that must be divisible by the result.
+        block_size_k: Preferred block size (adjusted if needed).
+        group_size: The result must be divisible by this.
+
+    Returns:
+        A valid BLOCK_SIZE_K that divides size_k and is divisible by group_size.
+    """
+    # Fast path: already valid
+    if size_k % block_size_k == 0 and block_size_k % group_size == 0:
+        return block_size_k
+
+    # Find the largest value that:
+    # 1. Divides size_k (size_k % candidate == 0)
+    # 2. Is divisible by group_size (candidate % group_size == 0)
+    # 3. Is <= block_size_k (prefer the largest value closest to block_size_k)
+    #
+    # Strategy: search from min(block_size_k, size_k) down to group_size,
+    # stepping by group_size to preserve divisibility by group_size.
+    max_search = min(block_size_k, size_k)
+    start = (max_search // group_size) * group_size
+    for candidate in range(start, group_size - 1, -group_size):
+        if size_k % candidate == 0:
+            return candidate
+
+    # Fallback: if group_size divides size_k, use it. This should always
+    # hold with a correct group_size configuration.
+    if size_k % group_size == 0:
+        return group_size
+
+    # Should not happen with a correct group_size; size_k divides itself.
+    return size_k
+
+
 def get_moe_wna16_block_config(
     config: dict[str, int],
     use_moe_wna16_cuda: bool,
@@ -960,6 +1002,9 @@ def get_moe_wna16_block_config(
         # at the same time.
         block_size_n = 1024
 
+    # Ensure BLOCK_SIZE_K is a divisor of size_k for CUDA kernel compatibility
+    block_size_k = _ensure_block_size_k_divisible(size_k, block_size_k, group_size)
+
     return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k}

View File

@@ -60,7 +60,7 @@ class MoeWNA16Config(QuantizationConfig):
         if self.linear_quant_method == "gptq":
             self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(full_config)
-        elif self.linear_quant_method == "awq":
+        elif self.linear_quant_method in ("awq", "awq_marlin"):
             capability_tuple = current_platform.get_device_capability()
             device_capability = (
                 -1 if capability_tuple is None else capability_tuple.to_int()
             )
@@ -107,7 +107,7 @@ class MoeWNA16Config(QuantizationConfig):
         if linear_quant_method == "gptq":
             has_zp = not cls.get_from_keys(config, ["sym"])
             modules_to_not_convert = []
-        elif linear_quant_method == "awq":
+        elif linear_quant_method in ("awq", "awq_marlin"):
             has_zp = cls.get_from_keys(config, ["zero_point"])
             modules_to_not_convert = cls.get_from_keys_or(
                 config, ["modules_to_not_convert"], None
@@ -184,7 +184,7 @@ class MoeWNA16Config(QuantizationConfig):
             return GPTQConfig.from_config(self.full_config).get_quant_method(
                 layer, prefix
             )
-        elif self.linear_quant_method == "awq":
+        elif self.linear_quant_method in ("awq", "awq_marlin"):
             if self.use_marlin and check_marlin_supports_layer(
                 layer, self.group_size
             ):
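The net effect: awq_marlin layers that pass check_marlin_supports_layer keep using the Marlin path, while everything else falls back to the plain MoeWNA16 kernels. A hedged sketch of the dispatch, with the delegation target assumed rather than shown in this hunk:

if self.use_marlin and check_marlin_supports_layer(layer, self.group_size):
    # Assumed delegation target: the AWQ Marlin linear quant method.
    return AWQMarlinConfig.from_config(self.full_config).get_quant_method(layer, prefix)
# Otherwise fall through to the MoeWNA16 method.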
@@ -468,7 +468,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
         shard_size = layer.intermediate_size_per_partition
 
         # convert gptq and awq weight to a standard format
-        if layer.quant_config.linear_quant_method == "awq":
+        # awq_marlin uses the same weight format as awq
+        if layer.quant_config.linear_quant_method in ("awq", "awq_marlin"):
             assert layer.quant_config.weight_bits == 4
             if "weight" in weight_name:
                 loaded_weight = convert_awq_tensor(loaded_weight, "qweight")
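For readers unfamiliar with the shared format: AWQ packs eight 4-bit values into each int32, and convert_awq_tensor unpacks them into a standard layout. Below is a simplified, assumption-laden sketch of just the bit-level unpacking; the real helper also undoes AWQ's interleaved packing order and transposes:

import torch

def unpack_int4(qweight: torch.Tensor) -> torch.Tensor:
    # Shift each packed int32 right by 0, 4, ..., 28 bits and keep the low
    # nibble, yielding eight 4-bit values per packed element.
    shifts = torch.arange(0, 32, 4, device=qweight.device)
    nibbles = (qweight.unsqueeze(-1) >> shifts) & 0xF
    return nibbles.reshape(*qweight.shape[:-1], -1)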