[bugfix][quantization] fix quark qwen3 kv_cache quantization (#30308)

Signed-off-by: Haoyang Li <lihaoyang0109@gmail.com>
haoyangli-amd 2025-12-10 11:24:12 +08:00 committed by GitHub
parent 7d80c73d42
commit 06462392e4


@@ -403,6 +403,7 @@ class Qwen3MoeModel(nn.Module):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.config = config
self.quant_config = quant_config
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
@@ -505,6 +506,19 @@ class Qwen3MoeModel(nn.Module):
loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights:
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
assert loaded_weight.numel() == 1, (
f"KV scale numel {loaded_weight.numel()} != 1"
)
loaded_weight = loaded_weight.squeeze()
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
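
Below is a minimal, self-contained sketch of the loading pattern this commit adds: during weight loading, checkpoint tensors that the quantization config recognizes as kv-cache scales are diverted to the matching attention scale parameter instead of falling through to the regular stacked-weight path. ToyQuantConfig, its name-mapping regex, and the parameter names here are hypothetical stand-ins; in vLLM the real get_cache_scale lives on the quantization config (e.g. QuarkConfig).

import re
import torch

class ToyQuantConfig:
    # Hypothetical stand-in for the quant config's get_cache_scale:
    # maps a checkpoint scale name to the model's kv-cache scale param name.
    def get_cache_scale(self, name: str) -> str | None:
        # e.g. "...self_attn.k_proj.output_scale" -> "...self_attn.attn.k_scale"
        m = re.fullmatch(r"(.*self_attn\.)([kv])_proj\.output_scale", name)
        return f"{m.group(1)}attn.{m.group(2)}_scale" if m else None

def default_weight_loader(param: torch.nn.Parameter, w: torch.Tensor) -> None:
    param.data.copy_(w)

quant_config = ToyQuantConfig()
params_dict = {
    "model.layers.0.self_attn.attn.k_scale": torch.nn.Parameter(torch.zeros(())),
}
weights = [("model.layers.0.self_attn.k_proj.output_scale", torch.tensor([0.02]))]

loaded_params: set[str] = set()
for name, loaded_weight in weights:
    # Same guard as the diff: only divert names the config recognizes.
    if quant_config is not None and (scale_name := quant_config.get_cache_scale(name)):
        param = params_dict[scale_name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        assert loaded_weight.numel() == 1, f"KV scale numel {loaded_weight.numel()} != 1"
        weight_loader(param, loaded_weight.squeeze())
        loaded_params.add(scale_name)
        continue

print(loaded_params)  # {'model.layers.0.self_attn.attn.k_scale'}

Storing quant_config on the model (the first hunk) is what makes this guard possible: before the fix, Qwen3MoeModel never consulted the quantization config in load_weights, so quark kv-cache scale tensors presumably had no path to the attn k_scale/v_scale parameters.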