[Fix][Spec Decode] Fix llama4 draft loading with different quantization (#27136)

Signed-off-by: linzebing <linzebing1995@gmail.com>
Zebing Lin 2025-10-21 02:19:00 -04:00 committed by GitHub
parent f381cf2302
commit be4445072c


@@ -60,16 +60,23 @@ class LlamaModel(nn.Module):
             prefix=maybe_prefix(prefix, "embed_tokens"),
         )
-        self.layers = nn.ModuleList(
-            [
-                Llama4DecoderLayer(
-                    vllm_config=vllm_config,
-                    prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
-                    config=self.config,
-                )
-                for i in range(self.config.num_hidden_layers)
-            ]
-        )
+        # Temporarily modify vllm_config.quant_config for draft model layers
+        original_quant_config = vllm_config.quant_config
+        vllm_config.quant_config = quant_config
+        try:
+            self.layers = nn.ModuleList(
+                [
+                    Llama4DecoderLayer(
+                        vllm_config=vllm_config,
+                        prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
+                        config=self.config,
+                    )
+                    for i in range(self.config.num_hidden_layers)
+                ]
+            )
+        finally:
+            # Restore original quant_config
+            vllm_config.quant_config = original_quant_config
         self.fc = torch.nn.Linear(
             self.config.hidden_size * 2, self.config.hidden_size, bias=False
         )
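
The change lets the Llama4 draft (speculative) layers be built with the draft model's own quantization config: vllm_config.quant_config may reflect the target (verifier) model's quantization, so it is swapped to the draft's quant_config while the Llama4DecoderLayer stack is constructed, and restored in a finally block so an exception during construction cannot leave the config mutated.

Below is a minimal sketch (not part of the patch) of the same temporary-override pattern written as a context manager; `override_quant_config` is a hypothetical helper name, and `vllm_config` / `quant_config` mirror the names used in the diff.

    from contextlib import contextmanager

    @contextmanager
    def override_quant_config(vllm_config, quant_config):
        # Temporarily point vllm_config.quant_config at the draft model's
        # quantization config, restoring the original on exit.
        original = vllm_config.quant_config
        vllm_config.quant_config = quant_config
        try:
            yield vllm_config
        finally:
            vllm_config.quant_config = original

    # Usage sketch (hypothetical), equivalent to the try/finally in the patch:
    # with override_quant_config(vllm_config, quant_config):
    #     self.layers = nn.ModuleList(
    #         [Llama4DecoderLayer(vllm_config=vllm_config, ...)
    #          for i in range(self.config.num_hidden_layers)]
    #     )

Either form (try/finally as in the patch, or a context manager) behaves the same; the essential point is that the original quant_config is always restored after the draft layers are created.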