From c1e4a4052d65d72d45e39db1edb6b7deb4ffd426 Mon Sep 17 00:00:00 2001
From: qizixi <22851944+zixi-qi@users.noreply.github.com>
Date: Sat, 24 May 2025 02:45:34 -0700
Subject: [PATCH] [V1][Spec Decode] Support multi-layer eagle draft model
 (#18030)

Signed-off-by: qizixi
---
 tests/v1/spec_decode/test_eagle.py |  3 +++
 vllm/v1/spec_decode/eagle.py       | 33 ++++++++++++++++++++++++++----
 vllm/v1/worker/gpu_model_runner.py | 18 +++++++++++-----
 3 files changed, 45 insertions(+), 9 deletions(-)

diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py
index 7be1c5b89938e..b49ac45f3129b 100644
--- a/tests/v1/spec_decode/test_eagle.py
+++ b/tests/v1/spec_decode/test_eagle.py
@@ -246,6 +246,9 @@ def test_propose(num_speculative_tokens):
     # Assign the mock to the proposer
     proposer.model = model_mock
 
+    # Assign draft attn_layer_names since load_model is not invoked
+    proposer.attn_layer_names = ["layer.0"]
+
     # Create input tensors
     cu_num_tokens = torch.tensor([0, seq_len_1, total_tokens],
                                  dtype=torch.int32,
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 876e1ddd14a6c..971b06758c214 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -12,6 +12,7 @@
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
                                                    FlashAttentionMetadata)
+from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel
 
@@ -150,6 +151,11 @@ class EagleProposer:
         else:
             raise ValueError(f"Unsupported method: {self.method}")
 
+        # At this moment, we assume all eagle layers belong to the same KV
+        # cache group, thus using the same attention metadata.
+        per_layer_attn_metadata = {}
+        for layer_name in self.attn_layer_names:
+            per_layer_attn_metadata[layer_name] = attn_metadata
         if self.use_cuda_graph and \
                 num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
@@ -159,7 +165,7 @@ class EagleProposer:
         self.positions[:num_tokens] = target_positions
         self.hidden_states[:num_tokens] = target_hidden_states
 
-        with set_forward_context(attn_metadata,
+        with set_forward_context(per_layer_attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
             ret_hidden_states = self.model(
@@ -245,7 +251,7 @@ class EagleProposer:
             self.hidden_states[:batch_size] = hidden_states
 
             # Run the model.
-            with set_forward_context(attn_metadata,
+            with set_forward_context(per_layer_attn_metadata,
                                      self.vllm_config,
                                      num_tokens=input_batch_size):
                 last_hidden_states, hidden_states = self.model(
@@ -318,8 +324,8 @@ class EagleProposer:
         draft_attn_layer_names = (
             get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
             target_attn_layer_names)
-        assert len(draft_attn_layer_names) == 1
-        self.attn_layer_name = next(iter(draft_attn_layer_names))
+
+        self.attn_layer_names = list(draft_attn_layer_names)
 
         # share embed_tokens with the target model if needed
         if get_pp_group().world_size == 1:
@@ -355,6 +361,25 @@ class EagleProposer:
             self.hidden_states[:num_tokens],
         )
 
+    def validate_same_kv_cache_group(self,
+                                     kv_cache_config: KVCacheConfig) -> None:
+        """
+        Validate that all eagle layers belong to the same KVCacheGroup.
+        Need this assumption to ensure all eagle layers can use the
+        same AttentionMetadata.
+        May extend to multiple AttentionMetadata in the future.
+        """
+        kv_cache_groups: dict[str, int] = {}
+        for id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
+            for layer_name in kv_cache_group.layer_names:
+                kv_cache_groups[layer_name] = id
+        assert len(
+            set([
+                kv_cache_groups[layer_name]
+                for layer_name in self.attn_layer_names
+            ])
+        ) == 1, "All eagle layers should belong to the same kv cache group"
+
 
 # NOTE(woosuk): Currently, the below code is not used and we always use argmax
 # to sample the draft tokens. We will use this after we find a way to manage
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 5120495dbb9b7..aa47ac253bb93 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1360,11 +1360,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     scheduler_output.num_scheduled_tokens[req_id])
                 next_token_id = req_state.get_token_id(seq_len)
                 next_token_ids.append(next_token_id)
-            next_token_ids = async_tensor_h2d(next_token_ids,
-                                              dtype=torch.int32,
-                                              target_device=self.device,
-                                              pin_memory=True)
-            eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
+            next_token_ids = torch.tensor(next_token_ids,
+                                          dtype=torch.int32,
+                                          device=self.device)
+            # At this moment, we assume all eagle layers belong to the same KV
+            # cache group, thus using the same attention metadata.
+            eagle_attn_metadata = attn_metadata[
+                self.drafter.attn_layer_names[0]]
 
             # NOTE: deepseek_mtp uses MLA which does not have `block_table`
             if hasattr(eagle_attn_metadata, "block_table"):
@@ -2018,6 +2020,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # KV cache specs.
             raise ValueError("Unknown KV cache spec type.")
 
+        if self.speculative_config and self.speculative_config.use_eagle():
+            assert isinstance(self.drafter, EagleProposer)
+            # validate all draft model layers belong to the same kv cache
+            # group
+            self.drafter.validate_same_kv_cache_group(kv_cache_config)
+
         bind_kv_cache(
             kv_caches,
             self.vllm_config.compilation_config.static_forward_context,