[Fix][Spec Decode] Fix llama4 draft loading with different quantization (#27136)

Signed-off-by: linzebing <linzebing1995@gmail.com>
Zebing Lin 2025-10-21 02:19:00 -04:00 committed by GitHub
parent f381cf2302
commit be4445072c


@@ -60,6 +60,10 @@ class LlamaModel(nn.Module):
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )
        # Temporarily modify vllm_config.quant_config for draft model layers
        original_quant_config = vllm_config.quant_config
        vllm_config.quant_config = quant_config
        try:
            self.layers = nn.ModuleList(
                [
                    Llama4DecoderLayer(
@@ -70,6 +74,9 @@
                    for i in range(self.config.num_hidden_layers)
                ]
            )
        finally:
            # Restore original quant_config
            vllm_config.quant_config = original_quant_config
        self.fc = torch.nn.Linear(
            self.config.hidden_size * 2, self.config.hidden_size, bias=False
        )
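
The change above is a temporary attribute override guarded by try/finally: `vllm_config` is shared state, so the draft model's `Llama4DecoderLayer` instances must see the draft's `quant_config` while they are being constructed (presumably because the layer reads its quantization settings from `vllm_config` rather than taking them as a separate argument), and the target model's value must be restored afterwards even if layer construction raises. Below is a minimal, self-contained sketch of the same pattern; `override_attr`, `Config`, and `build_draft_layers` are hypothetical names for illustration, not vLLM APIs.

```python
from contextlib import contextmanager


@contextmanager
def override_attr(obj, name, value):
    """Temporarily set obj.<name> to value, restoring the original on exit.

    Same spirit as the try/finally in the diff above: the restore runs
    even if the body raises.
    """
    original = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield obj
    finally:
        setattr(obj, name, original)


# Hypothetical stand-ins for vllm_config and the draft layers (illustration only).
class Config:
    quant_config = "fp8"  # target model's quantization


def build_draft_layers(cfg):
    # Layers constructed here see whatever quant_config is set on cfg.
    return [cfg.quant_config for _ in range(2)]


cfg = Config()
with override_attr(cfg, "quant_config", "bf16"):
    layers = build_draft_layers(cfg)

assert layers == ["bf16", "bf16"]
assert cfg.quant_config == "fp8"  # restored for the rest of the engine
```

A context manager and an inline try/finally are interchangeable here; the key point either way is that the override is scoped to draft-layer construction and cannot leak into the target model's configuration.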