[BugFix][Spec Decode] Fix hidden size mismatch between target and eagle head (#17740)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-05-06 19:51:26 -07:00 committed by GitHub
parent 950b71186f
commit 8d84d836d1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -28,23 +28,25 @@ class EagleProposer:
device: torch.device,
):
self.vllm_config = vllm_config
self.method = self.vllm_config.speculative_config.method
self.num_speculative_tokens = (
vllm_config.speculative_config.num_speculative_tokens)
self.max_model_len = vllm_config.model_config.max_model_len
self.block_size = vllm_config.cache_config.block_size
self.speculative_config = vllm_config.speculative_config
self.draft_model_config = self.speculative_config.draft_model_config
self.method = self.speculative_config.method
self.dtype = vllm_config.model_config.dtype
self.max_num_tokens = vllm_config.scheduler_config \
.max_num_batched_tokens
self.hidden_size = vllm_config.model_config.get_hidden_size()
self.max_model_len = vllm_config.model_config.max_model_len
self.block_size = vllm_config.cache_config.block_size
self.num_speculative_tokens = (
self.speculative_config.num_speculative_tokens)
self.max_num_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens)
# We need to get the hidden size from the draft model config because
# the draft model's hidden size can be different from the target model's
# hidden size (e.g., Llama 3.3 70B).
self.hidden_size = self.draft_model_config.get_hidden_size()
self.use_cuda_graph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE and
not self.vllm_config.model_config.enforce_eager)
self.cudagraph_batch_sizes = list(
reversed(
self.vllm_config.compilation_config.cudagraph_capture_sizes))
@ -56,7 +58,6 @@ class EagleProposer:
self.positions = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=device)
self.hidden_states = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
@ -131,7 +132,6 @@ class EagleProposer:
num_input_tokens = num_tokens
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states
with set_forward_context(attn_metadata,
@ -209,7 +209,6 @@ class EagleProposer:
# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions
self.hidden_states[:batch_size] = hidden_states
# Run the model.