EAGLE Support DP>1 (#26086)

Signed-off-by: Rémi Delacourt <remi@mistral.ai>
Signed-off-by: Rémi Delacourt <54138269+Flechman@users.noreply.github.com>
Signed-off-by: remi <remi@mistral.ai>
Authored by Rémi Delacourt on 2025-11-25 08:32:21 +01:00, committed by GitHub
parent f242cfcdd5
commit 12c007e288
4 changed files with 176 additions and 31 deletions

View File: CI test pipeline configuration

@@ -192,6 +192,7 @@ steps:
# test with internal dp
- python3 ../examples/offline_inference/data_parallel.py --enforce-eager
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_internal_lb_dp.py
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_hybrid_lb_dp.py
@@ -1116,6 +1117,7 @@ steps:
# https://github.com/NVIDIA/nccl/issues/1838
- export NCCL_CUMEM_HOST_ENABLE=0
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
- pytest -v -s entrypoints/llm/test_collective_rpc.py
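The two new test_eagle_dp.py entries above exercise EAGLE speculative decoding with data parallelism in both the multi-GPU job (TP_SIZE=2, DP_SIZE=2) and the two-GPU job (TP_SIZE=1, DP_SIZE=2). Since the pipeline invokes these tests from the vLLM tests/ directory (see the ../examples/... path above), the same commands can be run locally with TP_SIZE and DP_SIZE sized to the available GPUs.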

View File: tests/v1/distributed/test_eagle_dp.py (new file)

@@ -0,0 +1,77 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import os
from contextlib import AsyncExitStack
from dataclasses import replace

import pytest

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM

DP_SIZE = int(os.getenv("DP_SIZE", 2))


@pytest.mark.asyncio
async def test_run_eagle_dp():
    target_model = "meta-llama/Llama-3.1-8B-Instruct"
    draft_model = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"

    engine_args = AsyncEngineArgs(
        model=target_model,
        tokenizer_mode="auto",
        enforce_eager=False,
        tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
        data_parallel_size=DP_SIZE,
        data_parallel_backend="mp",  # ray takes more time
        trust_remote_code=True,
        max_model_len=16384,
    )
    eagle_engine_args = replace(
        engine_args,
        speculative_config={
            "model": draft_model,
            "method": "eagle",
            "num_speculative_tokens": 3,
        },
    )

    prompt = "This is a test of data parallel with eagle"
    num_expected_tokens = 100
    sampling_params = SamplingParams(
        min_tokens=num_expected_tokens,
        max_tokens=num_expected_tokens,
        ignore_eos=True,
        output_kind=RequestOutputKind.FINAL_ONLY,
        temperature=0,
    )

    async def generate_with_timeout(given_engine: AsyncLLM):
        async for out in given_engine.generate(
            request_id="test-eagle-dp", prompt=prompt, sampling_params=sampling_params
        ):
            token_ids = out.outputs[0].token_ids
            assert len(token_ids) == num_expected_tokens
            return token_ids

    async def engine_create_and_generate(engine_args: AsyncEngineArgs):
        async with AsyncExitStack() as after:
            engine = AsyncLLM.from_engine_args(engine_args)
            after.callback(engine.shutdown)
            token_ids = await asyncio.wait_for(
                generate_with_timeout(engine), timeout=30
            )
            assert not engine.output_processor.has_unfinished_requests()
            return token_ids

    token_ids_with_eagle = await engine_create_and_generate(eagle_engine_args)
    token_ids_no_eagle = await engine_create_and_generate(engine_args)

    # Test for correctness
    assert token_ids_with_eagle == token_ids_no_eagle
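Because the test pins temperature=0 and forces exactly num_expected_tokens tokens, the EAGLE engine and the plain engine must produce identical greedy token sequences, i.e. speculative decoding under DP>1 has to be lossless. Outside the test harness, the same configuration can be driven as a standalone script; the following is a minimal sketch assembled from the engine arguments above (model names, DP/TP sizes, and the speculative_config dict are taken from the test; the script itself is illustrative and not part of this commit):

# Illustrative sketch only, mirroring the test above.
import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM


async def main() -> None:
    engine = AsyncLLM.from_engine_args(
        AsyncEngineArgs(
            model="meta-llama/Llama-3.1-8B-Instruct",
            tensor_parallel_size=1,
            data_parallel_size=2,  # DP>1 with EAGLE is what this commit enables
            data_parallel_backend="mp",
            speculative_config={
                "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
                "method": "eagle",
                "num_speculative_tokens": 3,
            },
            max_model_len=16384,
        )
    )
    try:
        async for out in engine.generate(
            request_id="demo",
            prompt="This is a test of data parallel with eagle",
            sampling_params=SamplingParams(temperature=0, max_tokens=100),
        ):
            final = out  # default output kind streams cumulative outputs; keep the last
        print(final.outputs[0].text)
    finally:
        engine.shutdown()


if __name__ == "__main__":
    asyncio.run(main())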

View File: vllm/v1/spec_decode/eagle.py

@@ -40,6 +40,7 @@ from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import _SAMPLING_EPS
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
logger = init_logger(__name__)
@@ -65,6 +66,7 @@ class EagleProposer:
self.dtype = vllm_config.model_config.dtype
self.max_model_len = vllm_config.model_config.max_model_len
self.block_size = vllm_config.cache_config.block_size
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self.token_arange_np = np.arange(self.max_num_tokens)
@@ -271,15 +273,24 @@ class EagleProposer:
assert draft_indexer_metadata is not None
per_layer_attn_metadata[layer_name] = draft_indexer_metadata
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=num_tokens,
num_tokens_padded=num_tokens,
)
cudagraph_runtime_mode = CUDAGraphMode.NONE
if (
self.use_cuda_graph
and num_tokens <= self.compilation_config.max_cudagraph_capture_size
and num_tokens_dp_padded
<= self.compilation_config.max_cudagraph_capture_size
):
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens_dp_padded)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
else:
num_input_tokens = num_tokens
num_input_tokens = num_tokens_dp_padded
if num_tokens_across_dp is not None:
num_tokens_across_dp[self.dp_rank] = num_input_tokens
# copy inputs to buffer for cudagraph
self._set_positions(num_tokens, target_positions)
self.hidden_states[:num_tokens] = target_hidden_states
@@ -303,6 +314,7 @@ class EagleProposer:
per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
):
ret_hidden_states = self.model(
@@ -365,15 +377,23 @@ class EagleProposer:
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
batch_size_dp_padded, batch_size_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=batch_size,
num_tokens_padded=batch_size,
)
if (
self.use_cuda_graph
and batch_size <= self.compilation_config.max_cudagraph_capture_size
and batch_size_dp_padded
<= self.compilation_config.max_cudagraph_capture_size
):
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size_dp_padded)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
else:
input_batch_size = batch_size
input_batch_size = batch_size_dp_padded
cudagraph_runtime_mode = CUDAGraphMode.NONE
if batch_size_across_dp is not None:
batch_size_across_dp[self.dp_rank] = input_batch_size
common_attn_metadata.num_actual_tokens = batch_size
common_attn_metadata.max_query_len = 1
@@ -474,6 +494,7 @@ class EagleProposer:
per_layer_attn_metadata,
self.vllm_config,
num_tokens=input_batch_size,
num_tokens_across_dp=batch_size_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
):
ret_hidden_states = self.model(
@@ -1116,36 +1137,56 @@ class EagleProposer:
self,
num_tokens: int,
use_cudagraphs=True,
is_graph_capturing=False,
) -> None:
# Determine if CUDA graphs should be used for this run.
cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph
if (
cudagraphs_enabled
and num_tokens <= self.compilation_config.max_cudagraph_capture_size
):
num_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
with set_forward_context(
None,
self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=(
CUDAGraphMode.PIECEWISE if cudagraphs_enabled else CUDAGraphMode.NONE
),
# FIXME: when using tree-based specdec, adjust number of forward-passes
# according to the depth of the tree.
for fwd_idx in range(
self.num_speculative_tokens if not is_graph_capturing else 1
):
if self.supports_mm_inputs:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
else:
input_ids = self.input_ids[:num_tokens]
inputs_embeds = None
if fwd_idx <= 1:
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=num_tokens,
num_tokens_padded=num_tokens,
)
if (
cudagraphs_enabled
and num_tokens_dp_padded
<= self.compilation_config.max_cudagraph_capture_size
):
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_tokens_dp_padded
)
else:
num_input_tokens = num_tokens_dp_padded
if num_tokens_across_dp is not None:
num_tokens_across_dp[self.dp_rank] = num_input_tokens
self.model(
input_ids=input_ids,
positions=self._get_positions(num_tokens),
hidden_states=self.hidden_states[:num_tokens],
inputs_embeds=inputs_embeds,
)
with set_forward_context(
None,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE
if cudagraphs_enabled
else CUDAGraphMode.NONE,
):
if self.supports_mm_inputs:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_input_tokens]
else:
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None
self.model(
input_ids=input_ids,
positions=self._get_positions(num_input_tokens),
hidden_states=self.hidden_states[:num_input_tokens],
inputs_embeds=inputs_embeds,
)
def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
"""Find and return the attention metadata builders for EAGLE layers.
@@ -1211,6 +1252,28 @@ class EagleProposer:
== 1
), "All eagle layers should belong to the same kv cache group"
    def _pad_batch_across_dp(
        self,
        num_tokens_unpadded: int,
        num_tokens_padded: int,
    ) -> tuple[int, torch.Tensor | None]:
        # TODO(Flechman): support DBO ubatching
        ubatch_slices, num_toks_across_dp = coordinate_batch_across_dp(
            num_tokens_unpadded=num_tokens_unpadded,
            parallel_config=self.vllm_config.parallel_config,
            allow_microbatching=False,
            allow_dp_padding=self.use_cuda_graph,
            num_tokens_padded=num_tokens_padded,
            uniform_decode=None,
            num_scheduled_tokens_per_request=None,
        )
        assert ubatch_slices is None, "DBO ubatching not implemented for EAGLE"
        num_tokens_dp_padded = num_tokens_padded
        if num_toks_across_dp is not None:
            num_tokens_dp_padded = int(num_toks_across_dp[self.dp_rank].item())
        return num_tokens_dp_padded, num_toks_across_dp
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
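Taken together, the EagleProposer changes above make every data-parallel rank agree on a padded token count before each draft-model forward pass: _pad_batch_across_dp() defers to coordinate_batch_across_dp(), and the resulting per-rank tensor is handed to set_forward_context(), so CUDA graph sizes and any collectives inside the draft model stay in lockstep even when ranks hold different (or zero) numbers of draft tokens. As a rough illustration of the coordination idea only (this is not vLLM's implementation; the helper name is invented and an initialized gloo/NCCL process group is assumed):

# Illustrative sketch, not part of the commit: each DP rank reports its local
# token count, learns every other rank's count, and pads its batch up to the
# largest one so all ranks launch the (CUDA-graph) forward with one shape.
import torch
import torch.distributed as dist


def pad_tokens_across_dp(
    local_num_tokens: int, dp_group=None
) -> tuple[int, torch.Tensor]:
    world_size = dist.get_world_size(dp_group)
    rank = dist.get_rank(dp_group)
    counts = torch.zeros(world_size, dtype=torch.int64)
    counts[rank] = local_num_tokens
    # Each rank fills only its own slot, so a sum all-reduce is equivalent
    # to an all-gather of the per-rank token counts.
    dist.all_reduce(counts, group=dp_group)
    padded = int(counts.max().item())  # pad every rank to the largest batch
    return padded, counts

In the diff above, the agreed count additionally goes through pad_for_cudagraph() and is written back into num_tokens_across_dp[self.dp_rank], so the shape each rank actually launches with is also the one the other ranks see in set_forward_context(). dummy_run() applies the same per-pass padding, which keeps ranks running dummy batches in step with ranks doing real speculative decoding.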

View File: vllm/v1/worker/gpu_model_runner.py

@@ -3746,6 +3746,7 @@ class GPUModelRunner(
create_mixed_batch: bool = False,
remove_lora: bool = True,
activate_lora: bool = False,
is_graph_capturing: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Run a dummy forward pass to warm up/profile run or capture the
@@ -3981,7 +3982,7 @@ class GPUModelRunner(
if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
use_cudagraphs = (
cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE
cudagraph_runtime_mode.has_mode(CUDAGraphMode.PIECEWISE)
and not self.speculative_config.enforce_eager
)
@@ -3995,6 +3996,7 @@ class GPUModelRunner(
self.drafter.dummy_run(
num_tokens,
use_cudagraphs=use_cudagraphs,
is_graph_capturing=is_graph_capturing,
)
# This is necessary to avoid blocking DP.
@@ -4427,6 +4429,7 @@ class GPUModelRunner(
skip_eplb=True,
remove_lora=False,
activate_lora=activate_lora,
is_graph_capturing=True,
)
self.maybe_remove_all_loras(self.lora_config)