[Speculators][Speculative Decoding] Fix gpt-oss eagle3 accuracy issue (#25406)

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
jiahanc 2025-09-23 12:44:35 -07:00 committed by GitHub
parent 24fab45d96
commit d5944d5146
6 changed files with 79 additions and 17 deletions

View File

@ -534,6 +534,8 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
attn_metadata_builder
proposer._get_attention_metadata_builder = mock.MagicMock(
return_value=attn_metadata_builder)
result = proposer.propose(target_token_ids=target_token_ids,
target_positions=target_positions,
@ -660,6 +662,8 @@ def test_propose_tree(spec_token_tree):
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
attn_metadata_builder
proposer._get_attention_metadata_builder = mock.MagicMock(
return_value=attn_metadata_builder)
# Setup inputs for the proposer.
target_token_ids = torch.randint(0,
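The test updates stub the proposer's new builder lookup directly rather than relying only on `runner.attn_groups`. Below is a minimal standalone sketch of that mocking pattern; the `Proposer` class here is a hypothetical stand-in, not vLLM's actual test fixture.

```python
from unittest import mock


class Proposer:
    """Hypothetical stand-in for the object under test (not vLLM's class)."""

    def _get_attention_metadata_builder(self):
        # In the real proposer this scans runner.attn_groups; stubbed here.
        raise NotImplementedError

    def propose(self):
        builder = self._get_attention_metadata_builder()
        return builder.build_for_drafting(draft_index=0)


proposer = Proposer()
fake_builder = mock.MagicMock()
fake_builder.build_for_drafting.return_value = "attn_metadata"
# Patch the lookup so propose() receives the prebuilt builder, mirroring the
# mock.MagicMock(return_value=...) lines added to the tests.
proposer._get_attention_metadata_builder = mock.MagicMock(
    return_value=fake_builder)
assert proposer.propose() == "attn_metadata"
```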

View File

@ -1003,6 +1003,7 @@ class ModelConfig:
self.quantization = quantization_override
break
quant_method = quant_method if quant_method != "" else None
# Verify quantization configurations.
if self.quantization is None:
self.quantization = quant_method
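A small sketch of the normalization added in this hunk, with a plain dict standing in for the HF quantization config; the helper name and example values are illustrative only, not vLLM's API.

```python
from typing import Optional


def normalize_quant_method(hf_quant_config: dict) -> Optional[str]:
    # Treat an empty quant_method string as "nothing detected" so that later
    # `if self.quantization is None` style checks behave consistently.
    quant_method = hf_quant_config.get("quant_method", "")
    return quant_method if quant_method != "" else None


assert normalize_quant_method({"quant_method": ""}) is None
assert normalize_quant_method({"quant_method": "mxfp4"}) == "mxfp4"
```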

View File

@ -134,6 +134,11 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
nn.Module.__init__(self)
self.config = vllm_config. \
speculative_config.draft_model_config.hf_config
# Ensure draft_vocab_size is set
# default to the base vocab size when absent
if getattr(self.config, "draft_vocab_size", None) is None:
base_vocab_size = getattr(self.config, "vocab_size", None)
self.config.draft_vocab_size = base_vocab_size
target_layer_num = vllm_config.model_config.get_num_layers(
vllm_config.parallel_config)
self.model = LlamaModel(vllm_config=vllm_config,

View File

@ -203,6 +203,11 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
nn.Module.__init__(self)
self.config = vllm_config. \
speculative_config.draft_model_config.hf_config
# Ensure draft_vocab_size is set
# default to the base vocab size when absent
if getattr(self.config, "draft_vocab_size", None) is None:
base_vocab_size = getattr(self.config, "vocab_size", None)
self.config.draft_vocab_size = base_vocab_size
target_layer_num = vllm_config.model_config.get_num_layers(
vllm_config.parallel_config)
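Both the EAGLE and EAGLE3 Llama heads above gain the same `draft_vocab_size` fallback. A minimal sketch of the pattern, with `SimpleNamespace` standing in for the HF draft config and an arbitrary vocabulary size:

```python
from types import SimpleNamespace


def ensure_draft_vocab_size(hf_config):
    # When a draft checkpoint omits draft_vocab_size, fall back to the base
    # vocab_size so the draft head's vocabulary matches the target model's.
    if getattr(hf_config, "draft_vocab_size", None) is None:
        hf_config.draft_vocab_size = getattr(hf_config, "vocab_size", None)
    return hf_config


cfg = ensure_draft_vocab_size(SimpleNamespace(vocab_size=32000))
assert cfg.draft_vocab_size == 32000
```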

View File

@ -9,6 +9,7 @@ import numpy as np
import torch
import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadataBuilder
from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config)
@ -77,6 +78,8 @@ class EagleProposer:
self.is_multimodal_model = vllm_config.model_config \
.is_multimodal_model
self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
self.use_cuda_graph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE and
not self.vllm_config.model_config.enforce_eager)
@ -117,7 +120,7 @@ class EagleProposer:
with_numpy=True)
# Determine allowed attention backends once during initialization.
self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...]
self.allowed_attn_types: tuple[type, ...]
if current_platform.is_rocm():
rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
# vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
@ -190,10 +193,12 @@ class EagleProposer:
assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata_builder = \
self.runner.attn_groups[0][0].get_metadata_builder()
attn_metadata = attn_metadata_builder.build_for_drafting(
# Select the correct attention metadata builder for the EAGLE layers.
# Resolve the builder once and reuse it for later drafting steps.
builder = (self._get_attention_metadata_builder()
if self.attn_metadata_builder is None else
self.attn_metadata_builder)
attn_metadata = builder.build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0)
# At this moment, we assume all eagle layers belong to the same KV
@ -327,11 +332,9 @@ class EagleProposer:
exceeds_max_model_len, PADDING_SLOT_ID)
# Rebuild attention metadata
attn_metadata_builder = \
self.runner.attn_groups[0][0].get_metadata_builder()
attn_metadata = attn_metadata_builder\
.build_for_drafting(common_attn_metadata=common_attn_metadata,
draft_index=token_index + 1)
attn_metadata = builder.build_for_drafting(
common_attn_metadata=common_attn_metadata,
draft_index=token_index + 1)
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
@ -851,10 +854,24 @@ class EagleProposer:
# share lm_head with the target model if needed
# some model definitions do not define lm_head explicitly
# and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
if self.vllm_config.speculative_config.method != "eagle3" and \
hasattr(target_language_model, "lm_head"):
logger.info("Loading EAGLE LM head weights from the target model.")
self.model.lm_head = target_language_model.lm_head
if self.vllm_config.speculative_config.method != "eagle3":
if hasattr(target_language_model, "lm_head"):
logger.info(
"Loading EAGLE LM head weights from the target model.")
self.model.lm_head = target_language_model.lm_head
else:
if (hasattr(self.model, "lm_head")
and hasattr(target_language_model, "lm_head")
and self.model.lm_head.weight.shape
== target_language_model.lm_head.weight.shape):
logger.info("Assuming the EAGLE head shares the same lm_head"
" with the target model.")
del self.model.lm_head
self.model.lm_head = target_language_model.lm_head
else:
logger.info(
"The EAGLE head's lm_head will be loaded separately"
" from the target model.")
@torch.inference_mode()
def dummy_run(
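The lm_head handling added in the hunk above shares the target model's lm_head with an EAGLE3 draft head only when the projection shapes match. A minimal sketch of that check with toy `nn.Module`s; the module layout, helper name, and shapes are assumptions, not vLLM's actual models.

```python
import torch.nn as nn


def maybe_share_lm_head(draft: nn.Module, target: nn.Module) -> bool:
    # Share the target's lm_head only when both modules define one and the
    # weight shapes line up; otherwise the draft head keeps (and later loads)
    # its own lm_head, e.g. when draft_vocab_size differs from the base vocab.
    if (hasattr(draft, "lm_head") and hasattr(target, "lm_head")
            and draft.lm_head.weight.shape == target.lm_head.weight.shape):
        del draft.lm_head
        draft.lm_head = target.lm_head
        return True
    return False


draft, target = nn.Module(), nn.Module()
draft.lm_head = nn.Linear(64, 1000, bias=False)
target.lm_head = nn.Linear(64, 1000, bias=False)
assert maybe_share_lm_head(draft, target)
assert draft.lm_head is target.lm_head
```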
@ -877,6 +894,31 @@ class EagleProposer:
inputs_embeds=inputs_embeds,
)
def _get_attention_metadata_builder(
self) -> AttentionMetadataBuilder:
"""Find and return the attention metadata builder for the EAGLE layers.
Returns:
The metadata builder that owns the EAGLE attention layers.
Raises:
AssertionError: If no metadata builder is found for the EAGLE layers.
"""
builder = None
chosen_layer = self.attn_layer_names[0]
for kv_cache_group in self.runner.attn_groups:
for attn_group in kv_cache_group:
if chosen_layer in attn_group.layer_names:
builder = attn_group.get_metadata_builder()
break
if builder is not None:
break
assert builder is not None, (
"Failed to find attention metadata builder for EAGLE layers.")
return builder
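A standalone usage sketch of the same group scan, with a toy dataclass standing in for vLLM's attention groups; the class name and layer names are assumptions.

```python
from dataclasses import dataclass


@dataclass
class FakeAttnGroup:
    """Toy stand-in for a vLLM attention group (names assumed)."""
    layer_names: list[str]
    builder: str

    def get_metadata_builder(self) -> str:
        return self.builder


def find_builder_for_layer(attn_groups, chosen_layer):
    # Scan each KV-cache group's attention groups and return the metadata
    # builder that owns the chosen (EAGLE) attention layer.
    for kv_cache_group in attn_groups:
        for attn_group in kv_cache_group:
            if chosen_layer in attn_group.layer_names:
                return attn_group.get_metadata_builder()
    raise AssertionError(
        "Failed to find attention metadata builder for EAGLE layers.")


groups = [[FakeAttnGroup(["model.layers.0.self_attn.attn"], "full")],
          [FakeAttnGroup(["drafter.layers.0.self_attn.attn"], "draft")]]
assert find_builder_for_layer(groups,
                              "drafter.layers.0.self_attn.attn") == "draft"
```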
def validate_same_kv_cache_group(self,
kv_cache_config: KVCacheConfig) -> None:
"""

View File

@ -1177,9 +1177,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
encoder_seq_lens=encoder_seq_lens,
)
if self.speculative_config and \
spec_decode_common_attn_metadata is None:
spec_decode_common_attn_metadata = common_attn_metadata
if (self.speculative_config
and spec_decode_common_attn_metadata is None):
if isinstance(self.drafter, EagleProposer):
if (self.drafter.attn_layer_names[0]
in kv_cache_group_spec.layer_names):
spec_decode_common_attn_metadata = common_attn_metadata
else:
spec_decode_common_attn_metadata = common_attn_metadata
for attn_group in self.attn_groups[kv_cache_group_id]:
# Prepare for cascade attention if enabled & beneficial.
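The runner change above gates which KV-cache group supplies the drafter's common attention metadata. A minimal sketch with plain dicts in place of vLLM's group objects; all names here are assumptions.

```python
def pick_spec_decode_metadata(kv_cache_groups, drafter_layer_name):
    # Only the KV-cache group that actually contains the EAGLE drafter's
    # attention layer should supply the common metadata used for drafting;
    # other groups are skipped.
    for group in kv_cache_groups:
        if drafter_layer_name in group["layer_names"]:
            return group["common_attn_metadata"]
    return None


groups = [
    {"layer_names": {"model.layers.0.self_attn.attn"},
     "common_attn_metadata": "target-group-metadata"},
    {"layer_names": {"drafter.layers.0.self_attn.attn"},
     "common_attn_metadata": "draft-group-metadata"},
]
assert pick_spec_decode_metadata(
    groups, "drafter.layers.0.self_attn.attn") == "draft-group-metadata"
```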