[Speculators][Speculative Decoding] Fix gpt-oss eagle3 accuracy issue (#25406)
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
commit d5944d5146
parent 24fab45d96
@@ -534,6 +534,8 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
     proposer.runner.attn_groups.append([mock.MagicMock()])
     proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
         attn_metadata_builder
+    proposer._get_attention_metadata_builder = mock.MagicMock(
+        return_value=attn_metadata_builder)
 
     result = proposer.propose(target_token_ids=target_token_ids,
                               target_positions=target_positions,
@@ -660,6 +662,8 @@ def test_propose_tree(spec_token_tree):
     proposer.runner.attn_groups.append([mock.MagicMock()])
     proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
         attn_metadata_builder
+    proposer._get_attention_metadata_builder = mock.MagicMock(
+        return_value=attn_metadata_builder)
 
     # Setup inputs for the proposer.
     target_token_ids = torch.randint(0,
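Both test hunks patch the proposer the same way: assigning a `mock.MagicMock(return_value=...)` to the instance shadows the new `_get_attention_metadata_builder` method, so the proposer receives the test's builder no matter how the attention groups are wired. A minimal standalone illustration of that mocking pattern (the `Proposer` class below is a stand-in, not vLLM's):

```python
from unittest import mock


class Proposer:
    """Stand-in for EagleProposer; only the lookup method matters here."""

    def _get_attention_metadata_builder(self):
        raise RuntimeError("would normally search self.runner.attn_groups")

    def propose(self):
        # Mirrors the committed code path: call the lookup when no
        # cached builder is available.
        return self._get_attention_metadata_builder()


attn_metadata_builder = mock.MagicMock(name="attn_metadata_builder")
proposer = Proposer()
# Assigning on the instance shadows the class method, exactly as in the test.
proposer._get_attention_metadata_builder = mock.MagicMock(
    return_value=attn_metadata_builder)

assert proposer.propose() is attn_metadata_builder
```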
@@ -1003,6 +1003,7 @@ class ModelConfig:
                     self.quantization = quantization_override
                     break
 
+        quant_method = quant_method if quant_method != "" else None
         # Verify quantization configurations.
         if self.quantization is None:
             self.quantization = quant_method
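This one-line config change normalizes an empty-string `quant_method` to `None`, so the `if self.quantization is None` fallback below it does not adopt `""` as a real quantization method. A self-contained sketch of the behavior, with a hypothetical `resolve_quantization` helper standing in for `ModelConfig`'s logic:

```python
from typing import Optional


def resolve_quantization(quantization: Optional[str],
                         quant_method: str) -> Optional[str]:
    # The committed fix: treat "" as "no method detected" rather than
    # as a truthy-but-empty method name.
    quant_method = quant_method if quant_method != "" else None
    if quantization is None:
        quantization = quant_method
    return quantization


assert resolve_quantization(None, "") is None          # without the fix: ""
assert resolve_quantization(None, "mxfp4") == "mxfp4"  # unaffected
assert resolve_quantization("awq", "") == "awq"        # explicit setting wins
```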
@@ -134,6 +134,11 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
         nn.Module.__init__(self)
         self.config = vllm_config. \
             speculative_config.draft_model_config.hf_config
+        # Ensure draft_vocab_size is set
+        # default to the base vocab size when absent
+        if getattr(self.config, "draft_vocab_size", None) is None:
+            base_vocab_size = getattr(self.config, "vocab_size", None)
+            self.config.draft_vocab_size = base_vocab_size
         target_layer_num = vllm_config.model_config.get_num_layers(
             vllm_config.parallel_config)
         self.model = LlamaModel(vllm_config=vllm_config,
@@ -203,6 +203,11 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
         nn.Module.__init__(self)
         self.config = vllm_config. \
             speculative_config.draft_model_config.hf_config
+        # Ensure draft_vocab_size is set
+        # default to the base vocab size when absent
+        if getattr(self.config, "draft_vocab_size", None) is None:
+            base_vocab_size = getattr(self.config, "vocab_size", None)
+            self.config.draft_vocab_size = base_vocab_size
         target_layer_num = vllm_config.model_config.get_num_layers(
             vllm_config.parallel_config)
 
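Both Llama EAGLE heads (the hunks above) get the same guard: if the draft checkpoint's HF config omits `draft_vocab_size`, it is defaulted to the base `vocab_size` instead of remaining `None`. The pattern in isolation, with `SimpleNamespace` standing in for the HF config object (the vocab size is an illustrative value):

```python
from types import SimpleNamespace

# A draft config that, like the problematic checkpoint, lacks
# draft_vocab_size entirely.
config = SimpleNamespace(vocab_size=201088)

# Ensure draft_vocab_size is set; default to the base vocab size when absent.
if getattr(config, "draft_vocab_size", None) is None:
    base_vocab_size = getattr(config, "vocab_size", None)
    config.draft_vocab_size = base_vocab_size

assert config.draft_vocab_size == 201088
```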
@@ -9,6 +9,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 
+from vllm.attention.backends.abstract import AttentionMetadataBuilder
 from vllm.attention.layer import Attention
 from vllm.config import (CompilationLevel, VllmConfig,
                          get_layers_from_vllm_config)
@@ -77,6 +78,8 @@ class EagleProposer:
         self.is_multimodal_model = vllm_config.model_config \
             .is_multimodal_model
 
+        self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
+
         self.use_cuda_graph = (self.vllm_config.compilation_config.level
                                == CompilationLevel.PIECEWISE and
                                not self.vllm_config.model_config.enforce_eager)
@@ -117,7 +120,7 @@ class EagleProposer:
                                              with_numpy=True)
 
         # Determine allowed attention backends once during initialization.
-        self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...]
+        self.allowed_attn_types: tuple[type, ...]
         if current_platform.is_rocm():
             rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
             # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
@@ -190,10 +193,12 @@ class EagleProposer:
 
         assert self.runner is not None
 
-        # FIXME: need to consider multiple kv_cache_groups
-        attn_metadata_builder = \
-            self.runner.attn_groups[0][0].get_metadata_builder()
-        attn_metadata = attn_metadata_builder.build_for_drafting(
+        # Select the correct attention metadata builders for EAGLE layers.
+        # Get the attention metadata builders once and reuse for later.
+        builder = (self._get_attention_metadata_builder()
+                   if self.attn_metadata_builder is None else
+                   self.attn_metadata_builder)
+        attn_metadata = builder.build_for_drafting(
             common_attn_metadata=common_attn_metadata, draft_index=0)
 
         # At this moment, we assume all eagle layers belong to the same KV
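This hunk drops the hard-coded `attn_groups[0][0]` indexing (the old `FIXME`) in favor of a cached lookup: `self.attn_metadata_builder` starts as `None` (see the `@@ -77,6 +78,8 @@` hunk) and the first call resolves the builder that actually owns the EAGLE layers. A minimal sketch of that cache-on-first-use shape, using placeholder types, and with the simplifying assumption that the lookup also populates the cache attribute:

```python
from typing import Optional


class Builder:
    def build_for_drafting(self, draft_index: int) -> str:
        return f"metadata(draft_index={draft_index})"


class ProposerSketch:
    def __init__(self) -> None:
        # Populated lazily; mirrors the attn_metadata_builder = None
        # added to EagleProposer.__init__.
        self.attn_metadata_builder: Optional[Builder] = None
        self.lookups = 0

    def _get_attention_metadata_builder(self) -> Builder:
        self.lookups += 1  # stands in for the attn_groups search
        self.attn_metadata_builder = Builder()
        return self.attn_metadata_builder

    def propose(self) -> str:
        # Same conditional shape as the committed code.
        builder = (self._get_attention_metadata_builder()
                   if self.attn_metadata_builder is None else
                   self.attn_metadata_builder)
        return builder.build_for_drafting(draft_index=0)


p = ProposerSketch()
p.propose()
p.propose()
assert p.lookups == 1  # the search runs once; later calls reuse the cache
```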
@@ -327,11 +332,9 @@ class EagleProposer:
                                       exceeds_max_model_len, PADDING_SLOT_ID)
 
             # Rebuild attention metadata
-            attn_metadata_builder = \
-                self.runner.attn_groups[0][0].get_metadata_builder()
-            attn_metadata = attn_metadata_builder\
-                .build_for_drafting(common_attn_metadata=common_attn_metadata,
-                                    draft_index=token_index + 1)
+            attn_metadata = builder.build_for_drafting(
+                common_attn_metadata=common_attn_metadata,
+                draft_index=token_index + 1)
             for layer_name in self.attn_layer_names:
                 per_layer_attn_metadata[layer_name] = attn_metadata
 
@@ -851,10 +854,24 @@ class EagleProposer:
         # share lm_head with the target model if needed
         # some model definition do not define lm_head explicitly
         # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
-        if self.vllm_config.speculative_config.method != "eagle3" and \
-                hasattr(target_language_model, "lm_head"):
-            logger.info("Loading EAGLE LM head weights from the target model.")
-            self.model.lm_head = target_language_model.lm_head
+        if self.vllm_config.speculative_config.method != "eagle3":
+            if hasattr(target_language_model, "lm_head"):
+                logger.info(
+                    "Loading EAGLE LM head weights from the target model.")
+                self.model.lm_head = target_language_model.lm_head
+        else:
+            if (hasattr(self.model, "lm_head")
+                    and hasattr(target_language_model, "lm_head")
+                    and self.model.lm_head.weight.shape
+                    == target_language_model.lm_head.weight.shape):
+                logger.info("Assuming the EAGLE head shares the same lm_head"
+                            " with the target model.")
+                del self.model.lm_head
+                self.model.lm_head = target_language_model.lm_head
+            else:
+                logger.info(
+                    "The EAGLE head's lm_head will be loaded separately"
+                    " from the target model.")
 
     @torch.inference_mode()
     def dummy_run(
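The new `else` branch lets an EAGLE3 head reuse the target model's `lm_head` when the weight shapes line up, where previously the EAGLE3 path never shared it; when the shapes differ, the head's own `lm_head` is loaded separately. A runnable sketch of the shape-guarded sharing, with a hypothetical `maybe_share_lm_head` helper and toy dimensions:

```python
import torch.nn as nn


def maybe_share_lm_head(draft: nn.Module, target: nn.Module) -> bool:
    """Share the target lm_head with the draft head when shapes match.

    Mirrors the shape guard in the committed else-branch; returns True
    when sharing happened.
    """
    if (hasattr(draft, "lm_head") and hasattr(target, "lm_head")
            and draft.lm_head.weight.shape
            == target.lm_head.weight.shape):
        del draft.lm_head  # drop the draft's own module first
        draft.lm_head = target.lm_head
        return True
    return False


class Head(nn.Module):
    def __init__(self, vocab: int, hidden: int) -> None:
        super().__init__()
        self.lm_head = nn.Linear(hidden, vocab, bias=False)


target = Head(vocab=128, hidden=16)
draft = Head(vocab=128, hidden=16)
assert maybe_share_lm_head(draft, target)
assert draft.lm_head.weight is target.lm_head.weight  # one set of weights
```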
@@ -877,6 +894,31 @@ class EagleProposer:
             inputs_embeds=inputs_embeds,
         )
 
+    def _get_attention_metadata_builder(
+            self) -> list[AttentionMetadataBuilder]:
+        """Find and return the attention metadata builders for EAGLE layers.
+
+        Returns:
+            The metadata builders for EAGLE layers.
+
+        Raises:
+            AssertionError: If no metadata builders are found for EAGLE layers.
+        """
+        builder = None
+        chosen_layer = self.attn_layer_names[0]
+
+        for kv_cache_group in self.runner.attn_groups:
+            for attn_group in kv_cache_group:
+                if chosen_layer in attn_group.layer_names:
+                    builder = attn_group.get_metadata_builder()
+                    break
+            if builder is not None:
+                break
+
+        assert builder is not None, (
+            "Failed to find attention metadata builder for EAGLE layers.")
+        return builder
+
     def validate_same_kv_cache_group(self,
                                      kv_cache_config: KVCacheConfig) -> None:
         """
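The new helper walks every KV-cache group and every attention group inside it, returning the metadata builder of the group that contains the first EAGLE attention layer (note the committed annotation reads `list[AttentionMetadataBuilder]` even though a single builder is returned). A runnable sketch of the same search over stand-in group objects:

```python
from dataclasses import dataclass


@dataclass
class AttnGroup:
    """Stand-in for an attention group: a builder plus its layer names."""
    layer_names: list[str]
    builder: str = "builder"

    def get_metadata_builder(self) -> str:
        return self.builder


def find_builder(attn_groups: list[list[AttnGroup]],
                 chosen_layer: str) -> str:
    builder = None
    for kv_cache_group in attn_groups:
        for attn_group in kv_cache_group:
            if chosen_layer in attn_group.layer_names:
                builder = attn_group.get_metadata_builder()
                break
        if builder is not None:
            break
    assert builder is not None, (
        "Failed to find attention metadata builder for EAGLE layers.")
    return builder


groups = [
    [AttnGroup(["model.layers.0.self_attn.attn"], builder="target")],
    [AttnGroup(["drafter.layers.0.self_attn.attn"], builder="eagle")],
]
assert find_builder(groups, "drafter.layers.0.self_attn.attn") == "eagle"
```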
@@ -1177,9 +1177,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 encoder_seq_lens=encoder_seq_lens,
             )
 
-            if self.speculative_config and \
-                    spec_decode_common_attn_metadata is None:
-                spec_decode_common_attn_metadata = common_attn_metadata
+            if (self.speculative_config
+                    and spec_decode_common_attn_metadata is None):
+                if isinstance(self.drafter, EagleProposer):
+                    if (self.drafter.attn_layer_names[0]
+                            in kv_cache_group_spec.layer_names):
+                        spec_decode_common_attn_metadata = common_attn_metadata
+                else:
+                    spec_decode_common_attn_metadata = common_attn_metadata
 
             for attn_group in self.attn_groups[kv_cache_group_id]:
                 # Prepare for cascade attention if enabled & beneficial.
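Previously the runner captured `spec_decode_common_attn_metadata` from whichever KV-cache group it processed first; with an EAGLE drafter whose layers sit in a different group, that metadata could describe the wrong group. The fix captures it only from the group that actually contains the drafter's first attention layer, keeping the old behavior for non-EAGLE drafters. A condensed sketch of the selection loop, with stand-in types and illustrative layer names:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class KVCacheGroupSpec:
    layer_names: list[str]
    common_attn_metadata: str  # stand-in for the real metadata object


def pick_spec_decode_metadata(groups: list[KVCacheGroupSpec],
                              drafter_layer: Optional[str]) -> Optional[str]:
    spec_decode_common_attn_metadata = None
    for spec in groups:
        if spec_decode_common_attn_metadata is None:
            if drafter_layer is not None:  # EagleProposer case
                if drafter_layer in spec.layer_names:
                    spec_decode_common_attn_metadata = \
                        spec.common_attn_metadata
            else:  # non-EAGLE drafters keep the first-group behavior
                spec_decode_common_attn_metadata = spec.common_attn_metadata
    return spec_decode_common_attn_metadata


groups = [
    KVCacheGroupSpec(["model.layers.0.attn"], "target-group-metadata"),
    KVCacheGroupSpec(["drafter.layers.0.attn"], "eagle-group-metadata"),
]
# With the fix, the EAGLE drafter gets metadata from its own group.
result = pick_spec_decode_metadata(groups, "drafter.layers.0.attn")
assert result == "eagle-group-metadata"
```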