mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-09 08:10:34 +08:00
[bugfix][quantization] fix quark qwen3 kv_cache quantization (#30308)
Signed-off-by: Haoyang Li <lihaoyang0109@gmail.com>
This commit is contained in:
parent
7d80c73d42
commit
06462392e4
@ -403,6 +403,7 @@ class Qwen3MoeModel(nn.Module):
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
@ -505,6 +506,19 @@ class Qwen3MoeModel(nn.Module):
|
||||
loaded_params: set[str] = set()
|
||||
expert_params_mapping = self.get_expert_mapping()
|
||||
for name, loaded_weight in weights:
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
assert loaded_weight.numel() == 1, (
|
||||
f"KV scale numel {loaded_weight.numel()} != 1"
|
||||
)
|
||||
loaded_weight = loaded_weight.squeeze()
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
if weight_name not in name:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user