Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 06:55:00 +08:00)
[Speculators][Speculative Decoding] Add Qwen Eagle3 Support (#21835)

Signed-off-by: Dipika Sikka <dipikasikka1@gmail.com>

parent: a65f46be5e
commit: 9f9c38c392
@@ -6,11 +6,21 @@ import torch
 @pytest.mark.parametrize(
     "model_path",
-    [("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717"),
-     ("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")])
+    [("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")])
 def test_llama(vllm_runner, example_prompts, model_path):
     with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts,
                                                   max_tokens=20)
         print(vllm_outputs)
         assert vllm_outputs
+
+
+@pytest.mark.parametrize(
+    "model_path",
+    [("nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized")])
+def test_qwen(vllm_runner, example_prompts, model_path):
+    with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy(example_prompts,
+                                                  max_tokens=20)
+        print(vllm_outputs)
+        assert vllm_outputs
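For context, a minimal offline sketch of what test_qwen exercises: the speculators-converted Eagle3 draft checkpoint is passed directly as the model path and vLLM resolves the target from the speculators config, as the vllm_runner fixture does above. The prompt and output indexing are illustrative, not part of this diff.

# Offline analogue of test_qwen (sketch, not part of the commit).
from vllm import LLM, SamplingParams

llm = LLM(
    model="nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized",
    dtype="bfloat16",
)
# Greedy decoding with 20 new tokens, mirroring generate_greedy(..., max_tokens=20).
params = SamplingParams(temperature=0.0, max_tokens=20)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)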
@@ -3175,10 +3175,19 @@ class SpeculativeConfig:
             "speculative decoding is > 1, but got "
             f"{self.disable_by_batch_size=}")
 
-        if self.method == "eagle3" and self.target_model_config and \
-            "llama" not in self.target_model_config.hf_text_config.model_type:
+        from vllm.transformers_utils.configs import SpeculatorsConfig
+
+        eagle3_target_supported = ["llama"]
+        if self.draft_model_config and isinstance(
+                self.draft_model_config.hf_config, SpeculatorsConfig):
+            eagle3_target_supported.append("qwen")
+
+        if self.method == "eagle3" and self.target_model_config and not any(
+                supported_model in
+                self.target_model_config.hf_text_config.model_type
+                for supported_model in eagle3_target_supported):
             raise ValueError(
-                "Eagle3 is only supported for Llama models. "
+                f"Eagle3 is only supported for {eagle3_target_supported} models. "  # noqa: E501
                 f"Got {self.target_model_config.hf_text_config.model_type=}")
 
         return self
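The validation above is substring-based: the target's hf_text_config.model_type passes if it contains any supported family name, and "qwen" only joins the list when the draft checkpoint carries a SpeculatorsConfig. A standalone illustration of that logic (not vLLM code; the helper name and flag are made up for this sketch):

def eagle3_target_ok(model_type: str, draft_is_speculators_format: bool) -> bool:
    # Mirrors the check in SpeculativeConfig: start with Llama, add Qwen only
    # for speculators-format draft checkpoints, then substring-match model_type.
    eagle3_target_supported = ["llama"]
    if draft_is_speculators_format:
        eagle3_target_supported.append("qwen")
    return any(supported in model_type for supported in eagle3_target_supported)

assert eagle3_target_ok("llama", draft_is_speculators_format=False)
assert eagle3_target_ok("qwen3", draft_is_speculators_format=True)   # "qwen" in "qwen3"
assert not eagle3_target_ok("qwen3", draft_is_speculators_format=False)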
@@ -330,6 +330,8 @@ class Qwen2Model(nn.Module):
         else:
             self.norm = PPMissingLayer()
 
+        self.aux_hidden_state_layers: tuple[int] = tuple()
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
@@ -350,18 +352,25 @@ class Qwen2Model(nn.Module):
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
-        for layer in self.layers[self.start_layer:self.end_layer]:
-            hidden_states, residual = layer(
-                positions,
-                hidden_states,
-                residual,
-            )
+
+        aux_hidden_states = []
+        for idx, layer in enumerate(
+                self.layers[self.start_layer:self.end_layer]):
+            if idx in self.aux_hidden_state_layers:
+                aux_hidden_states.append(hidden_states + residual)
+            hidden_states, residual = layer(positions, hidden_states, residual)
+
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
                 "residual": residual
             })
+
         hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) > 0:
+            return hidden_states, aux_hidden_states
+
         return hidden_states
 
     def load_weights(self, weights: Iterable[tuple[str,
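The pattern added to Qwen2Model.forward: before running each decoder layer whose index is listed in aux_hidden_state_layers, the model records hidden_states + residual so the Eagle3 draft head can consume intermediate target features; if anything was captured, forward returns (final_hidden_states, aux_hidden_states). A self-contained toy sketch of that capture flow (not vLLM code; the layer and function here are placeholders):

import torch

def toy_layer(hidden_states, residual):
    # Placeholder "decoder layer": folds the residual in and starts a fresh one.
    return hidden_states + residual, torch.zeros_like(residual)

def run_layers(layers, hidden_states, residual, aux_hidden_state_layers=()):
    # Capture the pre-layer (hidden_states + residual) at the tapped indices,
    # then run the layer as usual.
    aux_hidden_states = []
    for idx, layer in enumerate(layers):
        if idx in aux_hidden_state_layers:
            aux_hidden_states.append(hidden_states + residual)
        hidden_states, residual = layer(hidden_states, residual)
    return hidden_states, residual, aux_hidden_states

h, r = torch.randn(1, 4), torch.randn(1, 4)
out, _, aux = run_layers([toy_layer] * 4, h, r, aux_hidden_state_layers=(1, 3))
assert len(aux) == 2  # layers 1 and 3 were tapped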
@@ -288,6 +288,13 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
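The two methods above give Qwen3ForCausalLM the same Eagle3 hooks the Llama target already exposes: set_aux_hidden_state_layers wires the tap indices into the underlying Qwen2Model, and get_eagle3_aux_hidden_state_layers picks one early, one middle, and one late layer by default. Worked out for a hypothetical 36-layer target (the layer count here is assumed for illustration only):

num_layers = 36  # assumed target depth, not taken from the commit
default_taps = (2, num_layers // 2, num_layers - 3)
print(default_taps)  # (2, 18, 33)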