mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 14:27:19 +08:00
Add: SupportsEagle3 interface for explicit EAGLE3 support (#22642)
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
This commit is contained in:
parent
e5d3d63c42
commit
5a4b4b3729
@ -3,12 +3,20 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.interfaces import supports_eagle3
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "model_path",
    [("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")])
def test_llama(vllm_runner, example_prompts, model_path, monkeypatch):
    """Verify a converted EAGLE3 Llama checkpoint is detected as supporting
    the EAGLE3 interface and can run greedy generation end-to-end.

    Args:
        vllm_runner: Fixture that constructs a vLLM model runner context.
        example_prompts: Fixture supplying a list of test prompts.
        model_path: HF model id under test (parametrized above).
        monkeypatch: pytest fixture used to set env vars for this test only.
    """
    # Set environment variable for V1 engine serialization
    # (apply_model ships a callable across processes, which requires
    # allowing insecure serialization).
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

    with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
        # supports_eagle3 is an isinstance check against the
        # SupportsEagle3 runtime-checkable protocol.
        eagle3_supported = vllm_model.apply_model(supports_eagle3)
        assert eagle3_supported

        # Smoke-check generation still works for the EAGLE3-enabled model.
        vllm_outputs = vllm_model.generate_greedy(example_prompts,
                                                  max_tokens=20)
        print(vllm_outputs)
|
||||
@ -18,8 +26,14 @@ def test_llama(vllm_runner, example_prompts, model_path):
|
||||
@pytest.mark.parametrize(
    "model_path",
    [("nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized")])
def test_qwen(vllm_runner, example_prompts, model_path, monkeypatch):
    """Verify a converted EAGLE3 Qwen3 checkpoint is detected as supporting
    the EAGLE3 interface and can run greedy generation end-to-end.

    Args:
        vllm_runner: Fixture that constructs a vLLM model runner context.
        example_prompts: Fixture supplying a list of test prompts.
        model_path: HF model id under test (parametrized above).
        monkeypatch: pytest fixture used to set env vars for this test only.
    """
    # Set environment variable for V1 engine serialization
    # (apply_model ships a callable across processes, which requires
    # allowing insecure serialization).
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

    with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
        # supports_eagle3 is an isinstance check against the
        # SupportsEagle3 runtime-checkable protocol.
        eagle3_supported = vllm_model.apply_model(supports_eagle3)
        assert eagle3_supported

        # Smoke-check generation still works for the EAGLE3-enabled model.
        vllm_outputs = vllm_model.generate_greedy(example_prompts,
                                                  max_tokens=20)
        print(vllm_outputs)
|
||||
|
||||
@ -823,3 +823,56 @@ def supports_v0_only(
|
||||
model: Union[type[object], object],
|
||||
) -> Union[TypeIs[type[SupportsV0Only]], TypeIs[SupportsV0Only]]:
|
||||
return getattr(model, "supports_v0_only", False)
|
||||
|
||||
|
||||
@runtime_checkable
class SupportsEagle3(Protocol):
    """The interface required for models that support
    EAGLE3 speculative decoding.

    Marked ``@runtime_checkable`` so ``isinstance(model, SupportsEagle3)``
    (used by :func:`supports_eagle3`) works as a structural check.
    """

    supports_eagle3: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports EAGLE3
    speculative decoding.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """
        Set which layers should output auxiliary
        hidden states for EAGLE3.

        Args:
            layers: Tuple of layer indices that should output auxiliary
                hidden states.
        """
        ...

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """
        Get the layer indices that should output auxiliary hidden states
        for EAGLE3.

        Returns:
            Tuple of layer indices for auxiliary hidden state outputs.
        """
        ...
|
||||
|
||||
|
||||
@overload
def supports_eagle3(model: type[object]) -> TypeIs[type[SupportsEagle3]]:
    ...


@overload
def supports_eagle3(model: object) -> TypeIs[SupportsEagle3]:
    ...


def supports_eagle3(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsEagle3]], TypeIs[SupportsEagle3]]:
    """Return whether *model* (a class or an instance) supports EAGLE3
    speculative decoding.

    Implemented as a structural ``isinstance`` check against the
    runtime-checkable :class:`SupportsEagle3` protocol; the ``TypeIs``
    overloads let static type checkers narrow the argument accordingly.
    """
    return isinstance(model, SupportsEagle3)
|
||||
|
||||
@ -49,7 +49,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
@ -463,7 +463,7 @@ class LlamaModel(nn.Module):
|
||||
return loaded_params
|
||||
|
||||
|
||||
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"]
|
||||
|
||||
@ -44,7 +44,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
|
||||
from .qwen2 import Qwen2MLP as Qwen3MLP
|
||||
from .qwen2 import Qwen2Model
|
||||
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
|
||||
@ -261,7 +261,7 @@ class Qwen3Model(Qwen2Model):
|
||||
decoder_layer_type=Qwen3DecoderLayer)
|
||||
|
||||
|
||||
class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
|
||||
@ -35,6 +35,7 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
|
||||
from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
|
||||
supports_eagle3,
|
||||
supports_transcription)
|
||||
from vllm.model_executor.models.interfaces_base import (
|
||||
VllmModelForPooling, is_pooling_model, is_text_generation_model)
|
||||
@ -1981,8 +1982,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
logger.info("Loading drafter model...")
|
||||
self.drafter.load_model(self.model)
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
self.model.set_aux_hidden_state_layers(
|
||||
self.model.get_eagle3_aux_hidden_state_layers())
|
||||
if supports_eagle3(self.model):
|
||||
self.model.set_aux_hidden_state_layers(
|
||||
self.model.get_eagle3_aux_hidden_state_layers())
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Model does not support EAGLE3 interface but "
|
||||
"aux_hidden_state_outputs was requested")
|
||||
time_after_load = time.perf_counter()
|
||||
self.model_memory_usage = m.consumed_memory
|
||||
logger.info("Model loading took %.4f GiB and %.6f seconds",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user