# tests/v1/distributed/test_eagle_dp.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import os
from contextlib import AsyncExitStack
from dataclasses import replace

import pytest

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
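
# Data-parallel world size for the test; override via the DP_SIZE env var.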
DP_SIZE = int(os.getenv("DP_SIZE", 2))


@pytest.mark.asyncio
async def test_run_eagle_dp():
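    # Target (verifier) model and the EAGLE draft model published for it.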
    target_model = "meta-llama/Llama-3.1-8B-Instruct"
    draft_model = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"

    engine_args = AsyncEngineArgs(
        model=target_model,
        tokenizer_mode="auto",
        enforce_eager=False,
        tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
        data_parallel_size=DP_SIZE,
        data_parallel_backend="mp",  # ray takes more time
        trust_remote_code=True,
        max_model_len=16384,
    )
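
    # Same engine configuration, with an EAGLE speculative-decoding config
    # layered on top via dataclasses.replace.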
    eagle_engine_args = replace(
        engine_args,
        speculative_config={
            "model": draft_model,
            "method": "eagle",
            "num_speculative_tokens": 3,
        },
    )
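
    # Greedy sampling (temperature=0) with a fixed 100-token budget, so each
    # run produces a deterministic output of known length.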
    prompt = "This is a test of data parallel with eagle"
    num_expected_tokens = 100
    sampling_params = SamplingParams(
        min_tokens=num_expected_tokens,
        max_tokens=num_expected_tokens,
        ignore_eos=True,
        output_kind=RequestOutputKind.FINAL_ONLY,
        temperature=0,
    )
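
    # With RequestOutputKind.FINAL_ONLY, generate() yields a single output
    # once the request has finished.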
    async def generate_with_timeout(given_engine: AsyncLLM):
        async for out in given_engine.generate(
            request_id="test-eagle-dp", prompt=prompt, sampling_params=sampling_params
        ):
            token_ids = out.outputs[0].token_ids
        assert len(token_ids) == num_expected_tokens
        return token_ids
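
    # Start an engine, run a single request with a timeout, and register
    # shutdown on the exit stack so the engine is cleaned up either way.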
    async def engine_create_and_generate(engine_args: AsyncEngineArgs):
        async with AsyncExitStack() as after:
            engine = AsyncLLM.from_engine_args(engine_args)
            after.callback(engine.shutdown)
            token_ids = await asyncio.wait_for(
                generate_with_timeout(engine), timeout=30
            )
            assert not engine.output_processor.has_unfinished_requests()
            return token_ids
    token_ids_with_eagle = await engine_create_and_generate(eagle_engine_args)
    token_ids_no_eagle = await engine_create_and_generate(engine_args)

    # Correctness: speculative decoding must be lossless, so the EAGLE run
    # should reproduce the baseline greedy output token-for-token.
    assert token_ids_with_eagle == token_ids_no_eagle
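
# Note: running this test needs at least DP_SIZE * TP_SIZE GPUs, e.g.:
#   DP_SIZE=2 TP_SIZE=1 pytest tests/v1/distributed/test_eagle_dp.py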