diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 6e88664f007d..393a4d964ee3 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -41,7 +41,6 @@
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
-from vllm.v1.worker.ubatching import dbo_current_ubatch_id
 
 logger = init_logger(__name__)
@@ -234,11 +233,11 @@ class EagleProposer:
 
         assert self.runner is not None
 
-        # FIXME: need to consider multiple kv_cache_groups
-        ubatch_id = dbo_current_ubatch_id()
-        attn_metadata_builder = self.runner.attn_groups[0][0].metadata_builders[
-            ubatch_id
-        ]
+        if self.attn_metadata_builder is None:
+            attn_metadata_builder = self._get_attention_metadata_builder()
+        else:
+            attn_metadata_builder = self.attn_metadata_builder
+
         attn_metadata = attn_metadata_builder.build_for_drafting(
             common_attn_metadata=common_attn_metadata, draft_index=0
         )
@@ -1076,7 +1075,7 @@ class EagleProposer:
             inputs_embeds=inputs_embeds,
         )
 
-    def _get_attention_metadata_builder(self) -> list[AttentionMetadataBuilder]:
+    def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
         """Find and return the attention metadata builders for EAGLE layers.
 
         Returns:
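
For context, below is a minimal, self-contained sketch of the builder-selection pattern this change introduces: use the proposer's preselected attn_metadata_builder when it is set, and fall back to a lookup helper only when it is not. This is an illustration, not vLLM's actual implementation; the Runner and AttentionMetadataBuilder stand-ins, and the attn_groups[0][0] lookup inside _get_attention_metadata_builder, are assumptions made for the sketch (the old code indexed attn_groups similarly, but the real helper may resolve the builder differently).

# Sketch only: stand-ins for vLLM classes, not the real EagleProposer.
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class AttentionMetadataBuilder:
    """Stand-in for a per-backend attention metadata builder."""
    name: str

    def build_for_drafting(self, common_attn_metadata, draft_index: int):
        # Real builders return backend-specific attention metadata.
        return {"builder": self.name, "draft_index": draft_index,
                "common": common_attn_metadata}


@dataclass
class Runner:
    """Stand-in for the GPU model runner holding attention groups."""
    attn_groups: list = field(default_factory=list)


class EagleProposerSketch:
    def __init__(self, runner: Runner):
        self.runner = runner
        # May be populated ahead of time (e.g. during initialization).
        self.attn_metadata_builder: AttentionMetadataBuilder | None = None

    def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
        # Hypothetical lookup mirroring the removed attn_groups[0][0] access.
        return self.runner.attn_groups[0][0]

    def propose(self, common_attn_metadata):
        # The pattern from the diff: prefer the preselected builder,
        # otherwise resolve one via the lookup helper.
        if self.attn_metadata_builder is None:
            attn_metadata_builder = self._get_attention_metadata_builder()
        else:
            attn_metadata_builder = self.attn_metadata_builder

        return attn_metadata_builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata, draft_index=0
        )


if __name__ == "__main__":
    runner = Runner(attn_groups=[[AttentionMetadataBuilder("flash_attn")]])
    proposer = EagleProposerSketch(runner)
    print(proposer.propose(common_attn_metadata={"num_tokens": 8}))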