mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 12:25:01 +08:00
EAGLE Support DP>1 (#26086)
Signed-off-by: Rémi Delacourt <remi@mistral.ai> Signed-off-by: Rémi Delacourt <54138269+Flechman@users.noreply.github.com> Signed-off-by: remi <remi@mistral.ai>
This commit is contained in:
parent
f242cfcdd5
commit
12c007e288
@ -192,6 +192,7 @@ steps:
|
||||
# test with internal dp
|
||||
- python3 ../examples/offline_inference/data_parallel.py --enforce-eager
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_internal_lb_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_hybrid_lb_dp.py
|
||||
@ -1116,6 +1117,7 @@ steps:
|
||||
# https://github.com/NVIDIA/nccl/issues/1838
|
||||
- export NCCL_CUMEM_HOST_ENABLE=0
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
|
||||
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
|
||||
- pytest -v -s entrypoints/llm/test_collective_rpc.py
|
||||
|
||||
77
tests/v1/distributed/test_eagle_dp.py
Normal file
77
tests/v1/distributed/test_eagle_dp.py
Normal file
@ -0,0 +1,77 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import os
|
||||
from contextlib import AsyncExitStack
|
||||
from dataclasses import replace
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
|
||||
DP_SIZE = int(os.getenv("DP_SIZE", 2))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_eagle_dp():
|
||||
target_model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
draft_model = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
|
||||
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=target_model,
|
||||
tokenizer_mode="auto",
|
||||
enforce_eager=False,
|
||||
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
|
||||
data_parallel_size=DP_SIZE,
|
||||
data_parallel_backend="mp", # ray takes more time
|
||||
trust_remote_code=True,
|
||||
max_model_len=16384,
|
||||
)
|
||||
|
||||
eagle_engine_args = replace(
|
||||
engine_args,
|
||||
speculative_config={
|
||||
"model": draft_model,
|
||||
"method": "eagle",
|
||||
"num_speculative_tokens": 3,
|
||||
},
|
||||
)
|
||||
|
||||
prompt = "This is a test of data parallel with eagle"
|
||||
num_expected_tokens = 100
|
||||
sampling_params = SamplingParams(
|
||||
min_tokens=num_expected_tokens,
|
||||
max_tokens=num_expected_tokens,
|
||||
ignore_eos=True,
|
||||
output_kind=RequestOutputKind.FINAL_ONLY,
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
async def generate_with_timeout(given_engine: AsyncLLM):
|
||||
async for out in given_engine.generate(
|
||||
request_id="test-eagle-dp", prompt=prompt, sampling_params=sampling_params
|
||||
):
|
||||
token_ids = out.outputs[0].token_ids
|
||||
assert len(token_ids) == num_expected_tokens
|
||||
return token_ids
|
||||
|
||||
async def engine_create_and_generate(engine_args: AsyncEngineArgs):
|
||||
async with AsyncExitStack() as after:
|
||||
engine = AsyncLLM.from_engine_args(engine_args)
|
||||
after.callback(engine.shutdown)
|
||||
|
||||
token_ids = await asyncio.wait_for(
|
||||
generate_with_timeout(engine), timeout=30
|
||||
)
|
||||
|
||||
assert not engine.output_processor.has_unfinished_requests()
|
||||
return token_ids
|
||||
|
||||
token_ids_with_eagle = await engine_create_and_generate(eagle_engine_args)
|
||||
token_ids_no_eagle = await engine_create_and_generate(engine_args)
|
||||
|
||||
# Test for correctness
|
||||
assert token_ids_with_eagle == token_ids_no_eagle
|
||||
@ -40,6 +40,7 @@ from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.sampler import _SAMPLING_EPS
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -65,6 +66,7 @@ class EagleProposer:
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
|
||||
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
||||
self.token_arange_np = np.arange(self.max_num_tokens)
|
||||
@ -271,15 +273,24 @@ class EagleProposer:
|
||||
assert draft_indexer_metadata is not None
|
||||
per_layer_attn_metadata[layer_name] = draft_indexer_metadata
|
||||
|
||||
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
|
||||
num_tokens_unpadded=num_tokens,
|
||||
num_tokens_padded=num_tokens,
|
||||
)
|
||||
|
||||
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
||||
if (
|
||||
self.use_cuda_graph
|
||||
and num_tokens <= self.compilation_config.max_cudagraph_capture_size
|
||||
and num_tokens_dp_padded
|
||||
<= self.compilation_config.max_cudagraph_capture_size
|
||||
):
|
||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
|
||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens_dp_padded)
|
||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||
else:
|
||||
num_input_tokens = num_tokens
|
||||
num_input_tokens = num_tokens_dp_padded
|
||||
if num_tokens_across_dp is not None:
|
||||
num_tokens_across_dp[self.dp_rank] = num_input_tokens
|
||||
|
||||
# copy inputs to buffer for cudagraph
|
||||
self._set_positions(num_tokens, target_positions)
|
||||
self.hidden_states[:num_tokens] = target_hidden_states
|
||||
@ -303,6 +314,7 @@ class EagleProposer:
|
||||
per_layer_attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
):
|
||||
ret_hidden_states = self.model(
|
||||
@ -365,15 +377,23 @@ class EagleProposer:
|
||||
# Generate the remaining draft tokens.
|
||||
draft_token_ids_list = [draft_token_ids]
|
||||
|
||||
batch_size_dp_padded, batch_size_across_dp = self._pad_batch_across_dp(
|
||||
num_tokens_unpadded=batch_size,
|
||||
num_tokens_padded=batch_size,
|
||||
)
|
||||
|
||||
if (
|
||||
self.use_cuda_graph
|
||||
and batch_size <= self.compilation_config.max_cudagraph_capture_size
|
||||
and batch_size_dp_padded
|
||||
<= self.compilation_config.max_cudagraph_capture_size
|
||||
):
|
||||
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
|
||||
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size_dp_padded)
|
||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||
else:
|
||||
input_batch_size = batch_size
|
||||
input_batch_size = batch_size_dp_padded
|
||||
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
||||
if batch_size_across_dp is not None:
|
||||
batch_size_across_dp[self.dp_rank] = input_batch_size
|
||||
|
||||
common_attn_metadata.num_actual_tokens = batch_size
|
||||
common_attn_metadata.max_query_len = 1
|
||||
@ -474,6 +494,7 @@ class EagleProposer:
|
||||
per_layer_attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=input_batch_size,
|
||||
num_tokens_across_dp=batch_size_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
):
|
||||
ret_hidden_states = self.model(
|
||||
@ -1116,36 +1137,56 @@ class EagleProposer:
|
||||
self,
|
||||
num_tokens: int,
|
||||
use_cudagraphs=True,
|
||||
is_graph_capturing=False,
|
||||
) -> None:
|
||||
# Determine if CUDA graphs should be used for this run.
|
||||
cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph
|
||||
if (
|
||||
cudagraphs_enabled
|
||||
and num_tokens <= self.compilation_config.max_cudagraph_capture_size
|
||||
):
|
||||
num_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
|
||||
|
||||
with set_forward_context(
|
||||
None,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
cudagraph_runtime_mode=(
|
||||
CUDAGraphMode.PIECEWISE if cudagraphs_enabled else CUDAGraphMode.NONE
|
||||
),
|
||||
# FIXME: when using tree-based specdec, adjust number of forward-passes
|
||||
# according to the depth of the tree.
|
||||
for fwd_idx in range(
|
||||
self.num_speculative_tokens if not is_graph_capturing else 1
|
||||
):
|
||||
if self.supports_mm_inputs:
|
||||
input_ids = None
|
||||
inputs_embeds = self.inputs_embeds[:num_tokens]
|
||||
else:
|
||||
input_ids = self.input_ids[:num_tokens]
|
||||
inputs_embeds = None
|
||||
if fwd_idx <= 1:
|
||||
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
|
||||
num_tokens_unpadded=num_tokens,
|
||||
num_tokens_padded=num_tokens,
|
||||
)
|
||||
if (
|
||||
cudagraphs_enabled
|
||||
and num_tokens_dp_padded
|
||||
<= self.compilation_config.max_cudagraph_capture_size
|
||||
):
|
||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(
|
||||
num_tokens_dp_padded
|
||||
)
|
||||
else:
|
||||
num_input_tokens = num_tokens_dp_padded
|
||||
if num_tokens_across_dp is not None:
|
||||
num_tokens_across_dp[self.dp_rank] = num_input_tokens
|
||||
|
||||
self.model(
|
||||
input_ids=input_ids,
|
||||
positions=self._get_positions(num_tokens),
|
||||
hidden_states=self.hidden_states[:num_tokens],
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
with set_forward_context(
|
||||
None,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE
|
||||
if cudagraphs_enabled
|
||||
else CUDAGraphMode.NONE,
|
||||
):
|
||||
if self.supports_mm_inputs:
|
||||
input_ids = None
|
||||
inputs_embeds = self.inputs_embeds[:num_input_tokens]
|
||||
else:
|
||||
input_ids = self.input_ids[:num_input_tokens]
|
||||
inputs_embeds = None
|
||||
|
||||
self.model(
|
||||
input_ids=input_ids,
|
||||
positions=self._get_positions(num_input_tokens),
|
||||
hidden_states=self.hidden_states[:num_input_tokens],
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
|
||||
"""Find and return the attention metadata builders for EAGLE layers.
|
||||
@ -1211,6 +1252,28 @@ class EagleProposer:
|
||||
== 1
|
||||
), "All eagle layers should belong to the same kv cache group"
|
||||
|
||||
def _pad_batch_across_dp(
|
||||
self,
|
||||
num_tokens_unpadded: int,
|
||||
num_tokens_padded: int,
|
||||
) -> tuple[int, torch.Tensor]:
|
||||
# TODO(Flechman): support DBO ubatching
|
||||
ubatch_slices, num_toks_across_dp = coordinate_batch_across_dp(
|
||||
num_tokens_unpadded=num_tokens_unpadded,
|
||||
parallel_config=self.vllm_config.parallel_config,
|
||||
allow_microbatching=False,
|
||||
allow_dp_padding=self.use_cuda_graph,
|
||||
num_tokens_padded=num_tokens_padded,
|
||||
uniform_decode=None,
|
||||
num_scheduled_tokens_per_request=None,
|
||||
)
|
||||
assert ubatch_slices is None, "DBO ubatching not implemented for EAGLE"
|
||||
|
||||
num_tokens_dp_padded = num_tokens_padded
|
||||
if num_toks_across_dp is not None:
|
||||
num_tokens_dp_padded = int(num_toks_across_dp[self.dp_rank].item())
|
||||
return num_tokens_dp_padded, num_toks_across_dp
|
||||
|
||||
|
||||
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
|
||||
# to sample the draft tokens. We will use this after we find a way to manage
|
||||
|
||||
@ -3746,6 +3746,7 @@ class GPUModelRunner(
|
||||
create_mixed_batch: bool = False,
|
||||
remove_lora: bool = True,
|
||||
activate_lora: bool = False,
|
||||
is_graph_capturing: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Run a dummy forward pass to warm up/profile run or capture the
|
||||
@ -3981,7 +3982,7 @@ class GPUModelRunner(
|
||||
if self.speculative_config and self.speculative_config.use_eagle():
|
||||
assert isinstance(self.drafter, EagleProposer)
|
||||
use_cudagraphs = (
|
||||
cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE
|
||||
cudagraph_runtime_mode.has_mode(CUDAGraphMode.PIECEWISE)
|
||||
and not self.speculative_config.enforce_eager
|
||||
)
|
||||
|
||||
@ -3995,6 +3996,7 @@ class GPUModelRunner(
|
||||
self.drafter.dummy_run(
|
||||
num_tokens,
|
||||
use_cudagraphs=use_cudagraphs,
|
||||
is_graph_capturing=is_graph_capturing,
|
||||
)
|
||||
|
||||
# This is necessary to avoid blocking DP.
|
||||
@ -4427,6 +4429,7 @@ class GPUModelRunner(
|
||||
skip_eplb=True,
|
||||
remove_lora=False,
|
||||
activate_lora=activate_lora,
|
||||
is_graph_capturing=True,
|
||||
)
|
||||
self.maybe_remove_all_loras(self.lora_config)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user