[Speculators][Speculative Decoding] Fix gpt-oss eagle3 accuracy issue (#25406)

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Authored by jiahanc on 2025-09-23 12:44:35 -07:00; committed by GitHub.
parent 24fab45d96
commit d5944d5146
6 changed files with 79 additions and 17 deletions


@@ -534,6 +534,8 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
     proposer.runner.attn_groups.append([mock.MagicMock()])
     proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
         attn_metadata_builder
+    proposer._get_attention_metadata_builder = mock.MagicMock(
+        return_value=attn_metadata_builder)
 
     result = proposer.propose(target_token_ids=target_token_ids,
                               target_positions=target_positions,
@@ -660,6 +662,8 @@ def test_propose_tree(spec_token_tree):
     proposer.runner.attn_groups.append([mock.MagicMock()])
     proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
         attn_metadata_builder
+    proposer._get_attention_metadata_builder = mock.MagicMock(
+        return_value=attn_metadata_builder)
 
     # Setup inputs for the proposer.
     target_token_ids = torch.randint(0,
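For orientation, here is a minimal, self-contained sketch of the stubbing pattern these two test hunks set up. `FakeProposer` is a hypothetical stand-in for `EagleProposer`; only the `attn_groups` shape and the `_get_attention_metadata_builder` name come from this commit. The tests patch both the legacy `attn_groups[0][0]` path and the new helper so either lookup resolves the same mock builder:

```python
from unittest import mock

# Hypothetical stand-in for EagleProposer; not vLLM's real class.
class FakeProposer:
    def __init__(self):
        self.runner = mock.MagicMock()
        self.runner.attn_groups = []

    def _get_attention_metadata_builder(self):
        raise NotImplementedError  # the test replaces this below

proposer = FakeProposer()
attn_metadata_builder = mock.MagicMock(name="attn_metadata_builder")

# Patch both lookup paths, mirroring the test hunks above.
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
    attn_metadata_builder
proposer._get_attention_metadata_builder = mock.MagicMock(
    return_value=attn_metadata_builder)

# Either path now resolves the same builder.
assert proposer._get_attention_metadata_builder() is attn_metadata_builder
assert (proposer.runner.attn_groups[0][0].get_metadata_builder()
        is attn_metadata_builder)
```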


@@ -1003,6 +1003,7 @@ class ModelConfig:
                     self.quantization = quantization_override
                     break
 
+        quant_method = quant_method if quant_method != "" else None
         # Verify quantization configurations.
         if self.quantization is None:
             self.quantization = quant_method
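The added line guards the fallback right below it: an empty-string `quant_method`, as can come out of checkpoint config parsing, is normalized to `None` so that `""` is never adopted as a quantization method name. A standalone sketch of that normalization; `resolve_quantization` is a hypothetical helper, not vLLM's API:

```python
# Hypothetical helper isolating the normalization added above.
def resolve_quantization(quantization, quant_method):
    # An empty string means "no method detected", not a real method name;
    # coerce it to None before the fallback below.
    quant_method = quant_method if quant_method != "" else None
    if quantization is None:
        quantization = quant_method
    return quantization

assert resolve_quantization(None, "") is None      # "" no longer leaks through
assert resolve_quantization(None, "mxfp4") == "mxfp4"
assert resolve_quantization("fp8", "") == "fp8"    # explicit setting wins
```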


@@ -134,6 +134,11 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
         nn.Module.__init__(self)
         self.config = vllm_config. \
             speculative_config.draft_model_config.hf_config
+        # Ensure draft_vocab_size is set;
+        # default to the base vocab size when absent.
+        if getattr(self.config, "draft_vocab_size", None) is None:
+            base_vocab_size = getattr(self.config, "vocab_size", None)
+            self.config.draft_vocab_size = base_vocab_size
         target_layer_num = vllm_config.model_config.get_num_layers(
             vllm_config.parallel_config)
         self.model = LlamaModel(vllm_config=vllm_config,


@@ -203,6 +203,11 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
         nn.Module.__init__(self)
         self.config = vllm_config. \
             speculative_config.draft_model_config.hf_config
+        # Ensure draft_vocab_size is set;
+        # default to the base vocab size when absent.
+        if getattr(self.config, "draft_vocab_size", None) is None:
+            base_vocab_size = getattr(self.config, "vocab_size", None)
+            self.config.draft_vocab_size = base_vocab_size
         target_layer_num = vllm_config.model_config.get_num_layers(
             vllm_config.parallel_config)
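Both EAGLE heads now apply the same guard: when the draft `hf_config` omits `draft_vocab_size`, it inherits the base `vocab_size`. A standalone sketch of that fallback; `ensure_draft_vocab_size` is a hypothetical helper, and the vocabulary size is only an illustrative gpt-oss-sized value:

```python
from types import SimpleNamespace

# Hypothetical helper isolating the fallback added to both EAGLE heads.
def ensure_draft_vocab_size(config):
    if getattr(config, "draft_vocab_size", None) is None:
        base_vocab_size = getattr(config, "vocab_size", None)
        config.draft_vocab_size = base_vocab_size
    return config

cfg = SimpleNamespace(vocab_size=201088)  # illustrative gpt-oss-sized vocab
ensure_draft_vocab_size(cfg)
assert cfg.draft_vocab_size == 201088     # inherited from the base model

cfg2 = SimpleNamespace(vocab_size=201088, draft_vocab_size=32000)
ensure_draft_vocab_size(cfg2)
assert cfg2.draft_vocab_size == 32000     # an explicit value is preserved
```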


@@ -9,6 +9,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 
+from vllm.attention.backends.abstract import AttentionMetadataBuilder
 from vllm.attention.layer import Attention
 from vllm.config import (CompilationLevel, VllmConfig,
                          get_layers_from_vllm_config)
@@ -77,6 +78,8 @@ class EagleProposer:
         self.is_multimodal_model = vllm_config.model_config \
             .is_multimodal_model
 
+        self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
+
         self.use_cuda_graph = (self.vllm_config.compilation_config.level
                                == CompilationLevel.PIECEWISE and
                                not self.vllm_config.model_config.enforce_eager)
@@ -117,7 +120,7 @@ class EagleProposer:
             with_numpy=True)
 
         # Determine allowed attention backends once during initialization.
-        self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...]
+        self.allowed_attn_types: tuple[type, ...]
         if current_platform.is_rocm():
             rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
             # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
@@ -190,10 +193,12 @@ class EagleProposer:
         assert self.runner is not None
 
-        # FIXME: need to consider multiple kv_cache_groups
-        attn_metadata_builder = \
-            self.runner.attn_groups[0][0].get_metadata_builder()
-        attn_metadata = attn_metadata_builder.build_for_drafting(
+        # Select the correct attention metadata builder for the EAGLE
+        # layers once, and reuse it for the later drafting steps.
+        builder = (self._get_attention_metadata_builder()
+                   if self.attn_metadata_builder is None else
+                   self.attn_metadata_builder)
+        attn_metadata = builder.build_for_drafting(
             common_attn_metadata=common_attn_metadata, draft_index=0)
 
         # At this moment, we assume all eagle layers belong to the same KV
@@ -327,11 +332,9 @@ class EagleProposer:
                 exceeds_max_model_len, PADDING_SLOT_ID)
 
             # Rebuild attention metadata
-            attn_metadata_builder = \
-                self.runner.attn_groups[0][0].get_metadata_builder()
-            attn_metadata = attn_metadata_builder\
-                .build_for_drafting(common_attn_metadata=common_attn_metadata,
-                                    draft_index=token_index + 1)
+            attn_metadata = builder.build_for_drafting(
+                common_attn_metadata=common_attn_metadata,
+                draft_index=token_index + 1)
             for layer_name in self.attn_layer_names:
                 per_layer_attn_metadata[layer_name] = attn_metadata
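The two hunks above share one idea: resolve the metadata builder for the EAGLE layers once, then reuse that `builder` when rebuilding attention metadata for every subsequent drafting step, instead of re-reading `attn_groups[0][0]` (which can be the wrong group) on each iteration. A toy sketch of the pattern; `Builder` and `Proposer` are illustrative stand-ins, not vLLM's classes:

```python
from typing import Optional

class Builder:  # stand-in for an attention metadata builder
    def build_for_drafting(self, common_attn_metadata, draft_index):
        return (common_attn_metadata, draft_index)

class Proposer:  # stand-in for EagleProposer
    def __init__(self, builder: Builder):
        self._builder = builder
        self.attn_metadata_builder: Optional[Builder] = None

    def _get_attention_metadata_builder(self) -> Builder:
        return self._builder

    def propose(self, common_attn_metadata, num_steps=3):
        # Resolve the builder once (draft_index=0)...
        builder = (self._get_attention_metadata_builder()
                   if self.attn_metadata_builder is None else
                   self.attn_metadata_builder)
        outs = [builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata, draft_index=0)]
        # ...then reuse the same builder for each rebuilt drafting step.
        for token_index in range(num_steps - 1):
            outs.append(builder.build_for_drafting(
                common_attn_metadata=common_attn_metadata,
                draft_index=token_index + 1))
        return outs

print(Proposer(Builder()).propose({"seq_lens": [3]}))
```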
@@ -851,10 +854,24 @@ class EagleProposer:
         # share lm_head with the target model if needed
         # some model definitions do not define lm_head explicitly
         # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
-        if self.vllm_config.speculative_config.method != "eagle3" and \
-                hasattr(target_language_model, "lm_head"):
-            logger.info("Loading EAGLE LM head weights from the target model.")
-            self.model.lm_head = target_language_model.lm_head
+        if self.vllm_config.speculative_config.method != "eagle3":
+            if hasattr(target_language_model, "lm_head"):
+                logger.info(
+                    "Loading EAGLE LM head weights from the target model.")
+                self.model.lm_head = target_language_model.lm_head
+        else:
+            if (hasattr(self.model, "lm_head")
+                    and hasattr(target_language_model, "lm_head")
+                    and self.model.lm_head.weight.shape
+                    == target_language_model.lm_head.weight.shape):
+                logger.info("Assuming the EAGLE head shares the same lm_head"
+                            " with the target model.")
+                del self.model.lm_head
+                self.model.lm_head = target_language_model.lm_head
+            else:
+                logger.info(
+                    "The EAGLE head's lm_head will be loaded separately"
+                    " from the target model.")
 
     @torch.inference_mode()
     def dummy_run(
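This branch carries the weight-sharing side of the accuracy fix: the eagle3 head aliases the target model's `lm_head` only when the weight shapes actually match, and otherwise keeps its own separately loaded head. A toy sketch of that shape guard; `maybe_share_lm_head` and `Head` are hypothetical, and the sizes are illustrative:

```python
import torch.nn as nn

# Hypothetical helper mirroring the shape-guarded sharing added above.
def maybe_share_lm_head(draft: nn.Module, target: nn.Module) -> bool:
    """Share target.lm_head with the draft head only when shapes match."""
    if (hasattr(draft, "lm_head") and hasattr(target, "lm_head")
            and draft.lm_head.weight.shape
            == target.lm_head.weight.shape):
        del draft.lm_head
        draft.lm_head = target.lm_head
        return True
    return False  # draft keeps (or separately loads) its own lm_head

class Head(nn.Module):  # toy module with only an lm_head
    def __init__(self, vocab, hidden):
        super().__init__()
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

draft, target = Head(32000, 64), Head(32000, 64)
assert maybe_share_lm_head(draft, target)            # same shape: shared
assert draft.lm_head.weight is target.lm_head.weight

mismatched = Head(201088, 64)                        # different vocab size
assert not maybe_share_lm_head(Head(32000, 64), mismatched)  # kept separate
```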
@@ -877,6 +894,31 @@ class EagleProposer:
             inputs_embeds=inputs_embeds,
         )
 
+    def _get_attention_metadata_builder(
+            self) -> AttentionMetadataBuilder:
+        """Find and return the attention metadata builder for the EAGLE
+        layers.
+
+        Returns:
+            The metadata builder for the EAGLE layers.
+
+        Raises:
+            AssertionError: If no metadata builder is found for the
+                EAGLE layers.
+        """
+        builder = None
+        chosen_layer = self.attn_layer_names[0]
+        for kv_cache_group in self.runner.attn_groups:
+            for attn_group in kv_cache_group:
+                if chosen_layer in attn_group.layer_names:
+                    builder = attn_group.get_metadata_builder()
+                    break
+            if builder is not None:
+                break
+        assert builder is not None, (
+            "Failed to find attention metadata builder for EAGLE layers.")
+        return builder
+
     def validate_same_kv_cache_group(self,
                                      kv_cache_config: KVCacheConfig) -> None:
         """


@@ -1177,9 +1177,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 encoder_seq_lens=encoder_seq_lens,
             )
 
-            if self.speculative_config and \
-                    spec_decode_common_attn_metadata is None:
-                spec_decode_common_attn_metadata = common_attn_metadata
+            if (self.speculative_config
+                    and spec_decode_common_attn_metadata is None):
+                if isinstance(self.drafter, EagleProposer):
+                    if (self.drafter.attn_layer_names[0]
+                            in kv_cache_group_spec.layer_names):
+                        spec_decode_common_attn_metadata = common_attn_metadata
+                else:
+                    spec_decode_common_attn_metadata = common_attn_metadata
 
             for attn_group in self.attn_groups[kv_cache_group_id]:
                 # Prepare for cascade attention if enabled & beneficial.
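Finally, the runner-side change: for an EAGLE drafter, `spec_decode_common_attn_metadata` is captured only from the KV-cache group that contains the drafter's first attention layer, rather than from whichever group is iterated first. A toy sketch of that guard; `EagleProposerLike`, `GroupSpec`, and `pick_common_metadata` are hypothetical stand-ins for vLLM's types:

```python
from dataclasses import dataclass, field

@dataclass
class EagleProposerLike:  # stand-in for EagleProposer
    attn_layer_names: list = field(default_factory=list)

@dataclass
class GroupSpec:  # stand-in for a KV-cache group spec
    layer_names: list

def pick_common_metadata(drafter, group_specs, metadata_per_group):
    spec_decode_common_attn_metadata = None
    for spec, common_attn_metadata in zip(group_specs, metadata_per_group):
        if spec_decode_common_attn_metadata is None:
            if isinstance(drafter, EagleProposerLike):
                # Only take metadata from the group holding the EAGLE layer.
                if drafter.attn_layer_names[0] in spec.layer_names:
                    spec_decode_common_attn_metadata = common_attn_metadata
            else:
                spec_decode_common_attn_metadata = common_attn_metadata
    return spec_decode_common_attn_metadata

drafter = EagleProposerLike(attn_layer_names=["layers.0.attn"])
groups = [GroupSpec(["layers.1.attn"]), GroupSpec(["layers.0.attn"])]
# With two groups (e.g. sliding-window plus full attention), the drafter
# gets the metadata of its own group, not the first one encountered.
assert pick_common_metadata(drafter, groups, ["sliding", "full"]) == "full"
```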