[BUGFIX] Fix crash in Eagle Speculative Decoding models when exceedin… (#24662)

Signed-off-by: AlonKejzman <alonkeizman@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
AlonKejzman 2025-09-25 18:40:14 +03:00 committed by yewentao256
parent 2655d7ab83
commit 252a0ff8c3

View File

@ -2310,7 +2310,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
use_padded_batch_for_eagle = self.speculative_config and \
self.speculative_config.use_eagle() and \
not self.speculative_config.disable_padded_drafter_batch
if use_padded_batch_for_eagle:
effective_drafter_max_model_len = self.max_model_len
if effective_drafter_max_model_len is None:
effective_drafter_max_model_len = self.model_config.max_model_len
if (self.speculative_config
and self.speculative_config.draft_model_config is not None
and self.speculative_config.draft_model_config.max_model_len
is not None):
effective_drafter_max_model_len = (
self.speculative_config.draft_model_config.max_model_len)
input_fits_in_drafter = spec_decode_common_attn_metadata and (
spec_decode_common_attn_metadata.seq_lens.max() +
self.speculative_config.num_speculative_tokens
<= effective_drafter_max_model_len)
if use_padded_batch_for_eagle and input_fits_in_drafter:
# EAGLE speculative decoding can use the GPU sampled tokens
# as inputs, and does not need to wait for bookkeeping to finish.
propose_draft_token_ids(sampler_output.sampled_token_ids)
@ -2328,7 +2341,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits, hidden_states,
num_scheduled_tokens)
if self.speculative_config and not use_padded_batch_for_eagle:
if (self.speculative_config and not use_padded_batch_for_eagle
and input_fits_in_drafter):
# ngram and other speculative decoding methods use the sampled
# tokens on the CPU, so they are run after bookkeeping.
propose_draft_token_ids(valid_sampled_token_ids)