EAGLE Support DP>1 (#26086)

Signed-off-by: Rémi Delacourt <remi@mistral.ai>
Signed-off-by: Rémi Delacourt <54138269+Flechman@users.noreply.github.com>
Signed-off-by: remi <remi@mistral.ai>

parent f242cfcdd5
commit 12c007e288

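What this change enables, as a minimal sketch: EAGLE speculative decoding can now run with data_parallel_size > 1. The engine arguments below mirror the new test added in this commit (tests/v1/distributed/test_eagle_dp.py); the model names and sizes are taken from that test and are illustrative, not a recommended configuration.

import os

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM

# EAGLE drafting combined with data parallelism (DP > 1), as exercised by the
# new test: a draft model proposes tokens while batches are coordinated and
# padded across DP ranks.
engine_args = AsyncEngineArgs(
    model="meta-llama/Llama-3.1-8B-Instruct",
    tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
    data_parallel_size=int(os.getenv("DP_SIZE", 2)),
    data_parallel_backend="mp",
    speculative_config={
        "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
        "method": "eagle",
        "num_speculative_tokens": 3,
    },
    max_model_len=16384,
)
engine = AsyncLLM.from_engine_args(engine_args)
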
@@ -192,6 +192,7 @@ steps:
     # test with internal dp
     - python3 ../examples/offline_inference/data_parallel.py --enforce-eager
     - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
+    - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py
     - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
     - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_internal_lb_dp.py
     - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_hybrid_lb_dp.py

@@ -1116,6 +1117,7 @@ steps:
     # https://github.com/NVIDIA/nccl/issues/1838
     - export NCCL_CUMEM_HOST_ENABLE=0
     - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
+    - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py
     - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
     - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
     - pytest -v -s entrypoints/llm/test_collective_rpc.py

tests/v1/distributed/test_eagle_dp.py (new file, 77 lines)
@@ -0,0 +1,77 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import asyncio
+import os
+from contextlib import AsyncExitStack
+from dataclasses import replace
+
+import pytest
+
+from vllm import SamplingParams
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.sampling_params import RequestOutputKind
+from vllm.v1.engine.async_llm import AsyncLLM
+
+DP_SIZE = int(os.getenv("DP_SIZE", 2))
+
+
+@pytest.mark.asyncio
+async def test_run_eagle_dp():
+    target_model = "meta-llama/Llama-3.1-8B-Instruct"
+    draft_model = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
+
+    engine_args = AsyncEngineArgs(
+        model=target_model,
+        tokenizer_mode="auto",
+        enforce_eager=False,
+        tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
+        data_parallel_size=DP_SIZE,
+        data_parallel_backend="mp",  # ray takes more time
+        trust_remote_code=True,
+        max_model_len=16384,
+    )
+
+    eagle_engine_args = replace(
+        engine_args,
+        speculative_config={
+            "model": draft_model,
+            "method": "eagle",
+            "num_speculative_tokens": 3,
+        },
+    )
+
+    prompt = "This is a test of data parallel with eagle"
+    num_expected_tokens = 100
+    sampling_params = SamplingParams(
+        min_tokens=num_expected_tokens,
+        max_tokens=num_expected_tokens,
+        ignore_eos=True,
+        output_kind=RequestOutputKind.FINAL_ONLY,
+        temperature=0,
+    )
+
+    async def generate_with_timeout(given_engine: AsyncLLM):
+        async for out in given_engine.generate(
+            request_id="test-eagle-dp", prompt=prompt, sampling_params=sampling_params
+        ):
+            token_ids = out.outputs[0].token_ids
+            assert len(token_ids) == num_expected_tokens
+            return token_ids
+
+    async def engine_create_and_generate(engine_args: AsyncEngineArgs):
+        async with AsyncExitStack() as after:
+            engine = AsyncLLM.from_engine_args(engine_args)
+            after.callback(engine.shutdown)
+
+            token_ids = await asyncio.wait_for(
+                generate_with_timeout(engine), timeout=30
+            )
+
+            assert not engine.output_processor.has_unfinished_requests()
+            return token_ids
+
+    token_ids_with_eagle = await engine_create_and_generate(eagle_engine_args)
+    token_ids_no_eagle = await engine_create_and_generate(engine_args)
+
+    # Test for correctness
+    assert token_ids_with_eagle == token_ids_no_eagle

@@ -40,6 +40,7 @@ from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.sampler import _SAMPLING_EPS
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.utils import CpuGpuBuffer
+from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
 logger = init_logger(__name__)

@@ -65,6 +66,7 @@ class EagleProposer:
         self.dtype = vllm_config.model_config.dtype
         self.max_model_len = vllm_config.model_config.max_model_len
         self.block_size = vllm_config.cache_config.block_size
+        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
         self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
         self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
         self.token_arange_np = np.arange(self.max_num_tokens)

@@ -271,15 +273,24 @@ class EagleProposer:
             assert draft_indexer_metadata is not None
             per_layer_attn_metadata[layer_name] = draft_indexer_metadata
 
+        num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
+            num_tokens_unpadded=num_tokens,
+            num_tokens_padded=num_tokens,
+        )
+
         cudagraph_runtime_mode = CUDAGraphMode.NONE
         if (
             self.use_cuda_graph
-            and num_tokens <= self.compilation_config.max_cudagraph_capture_size
+            and num_tokens_dp_padded
+            <= self.compilation_config.max_cudagraph_capture_size
         ):
-            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
+            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens_dp_padded)
             cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
         else:
-            num_input_tokens = num_tokens
+            num_input_tokens = num_tokens_dp_padded
+        if num_tokens_across_dp is not None:
+            num_tokens_across_dp[self.dp_rank] = num_input_tokens
+
         # copy inputs to buffer for cudagraph
         self._set_positions(num_tokens, target_positions)
         self.hidden_states[:num_tokens] = target_hidden_states

@@ -303,6 +314,7 @@ class EagleProposer:
             per_layer_attn_metadata,
             self.vllm_config,
             num_tokens=num_input_tokens,
+            num_tokens_across_dp=num_tokens_across_dp,
             cudagraph_runtime_mode=cudagraph_runtime_mode,
         ):
             ret_hidden_states = self.model(

@@ -365,15 +377,23 @@ class EagleProposer:
         # Generate the remaining draft tokens.
         draft_token_ids_list = [draft_token_ids]
 
+        batch_size_dp_padded, batch_size_across_dp = self._pad_batch_across_dp(
+            num_tokens_unpadded=batch_size,
+            num_tokens_padded=batch_size,
+        )
+
         if (
             self.use_cuda_graph
-            and batch_size <= self.compilation_config.max_cudagraph_capture_size
+            and batch_size_dp_padded
+            <= self.compilation_config.max_cudagraph_capture_size
         ):
-            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
+            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size_dp_padded)
             cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
         else:
-            input_batch_size = batch_size
+            input_batch_size = batch_size_dp_padded
             cudagraph_runtime_mode = CUDAGraphMode.NONE
+        if batch_size_across_dp is not None:
+            batch_size_across_dp[self.dp_rank] = input_batch_size
 
         common_attn_metadata.num_actual_tokens = batch_size
         common_attn_metadata.max_query_len = 1

@@ -474,6 +494,7 @@ class EagleProposer:
             per_layer_attn_metadata,
             self.vllm_config,
             num_tokens=input_batch_size,
+            num_tokens_across_dp=batch_size_across_dp,
             cudagraph_runtime_mode=cudagraph_runtime_mode,
         ):
             ret_hidden_states = self.model(

@@ -1116,36 +1137,56 @@ class EagleProposer:
         self,
         num_tokens: int,
         use_cudagraphs=True,
+        is_graph_capturing=False,
     ) -> None:
         # Determine if CUDA graphs should be used for this run.
         cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph
-        if (
-            cudagraphs_enabled
-            and num_tokens <= self.compilation_config.max_cudagraph_capture_size
-        ):
-            num_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
-
-        with set_forward_context(
-            None,
-            self.vllm_config,
-            num_tokens=num_tokens,
-            cudagraph_runtime_mode=(
-                CUDAGraphMode.PIECEWISE if cudagraphs_enabled else CUDAGraphMode.NONE
-            ),
-        ):
-            if self.supports_mm_inputs:
-                input_ids = None
-                inputs_embeds = self.inputs_embeds[:num_tokens]
-            else:
-                input_ids = self.input_ids[:num_tokens]
-                inputs_embeds = None
-
-            self.model(
-                input_ids=input_ids,
-                positions=self._get_positions(num_tokens),
-                hidden_states=self.hidden_states[:num_tokens],
-                inputs_embeds=inputs_embeds,
-            )
+
+        # FIXME: when using tree-based specdec, adjust number of forward-passes
+        # according to the depth of the tree.
+        for fwd_idx in range(
+            self.num_speculative_tokens if not is_graph_capturing else 1
+        ):
+            if fwd_idx <= 1:
+                num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
+                    num_tokens_unpadded=num_tokens,
+                    num_tokens_padded=num_tokens,
+                )
+                if (
+                    cudagraphs_enabled
+                    and num_tokens_dp_padded
+                    <= self.compilation_config.max_cudagraph_capture_size
+                ):
+                    num_input_tokens = self.vllm_config.pad_for_cudagraph(
+                        num_tokens_dp_padded
+                    )
+                else:
+                    num_input_tokens = num_tokens_dp_padded
+                if num_tokens_across_dp is not None:
+                    num_tokens_across_dp[self.dp_rank] = num_input_tokens
+
+            with set_forward_context(
+                None,
+                self.vllm_config,
+                num_tokens=num_input_tokens,
+                num_tokens_across_dp=num_tokens_across_dp,
+                cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE
+                if cudagraphs_enabled
+                else CUDAGraphMode.NONE,
+            ):
+                if self.supports_mm_inputs:
+                    input_ids = None
+                    inputs_embeds = self.inputs_embeds[:num_input_tokens]
+                else:
+                    input_ids = self.input_ids[:num_input_tokens]
+                    inputs_embeds = None
+
+                self.model(
+                    input_ids=input_ids,
+                    positions=self._get_positions(num_input_tokens),
+                    hidden_states=self.hidden_states[:num_input_tokens],
+                    inputs_embeds=inputs_embeds,
+                )
 
     def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
         """Find and return the attention metadata builders for EAGLE layers.

@@ -1211,6 +1252,28 @@ class EagleProposer:
             == 1
         ), "All eagle layers should belong to the same kv cache group"
 
+    def _pad_batch_across_dp(
+        self,
+        num_tokens_unpadded: int,
+        num_tokens_padded: int,
+    ) -> tuple[int, torch.Tensor]:
+        # TODO(Flechman): support DBO ubatching
+        ubatch_slices, num_toks_across_dp = coordinate_batch_across_dp(
+            num_tokens_unpadded=num_tokens_unpadded,
+            parallel_config=self.vllm_config.parallel_config,
+            allow_microbatching=False,
+            allow_dp_padding=self.use_cuda_graph,
+            num_tokens_padded=num_tokens_padded,
+            uniform_decode=None,
+            num_scheduled_tokens_per_request=None,
+        )
+        assert ubatch_slices is None, "DBO ubatching not implemented for EAGLE"
+
+        num_tokens_dp_padded = num_tokens_padded
+        if num_toks_across_dp is not None:
+            num_tokens_dp_padded = int(num_toks_across_dp[self.dp_rank].item())
+        return num_tokens_dp_padded, num_toks_across_dp
+
 
 # NOTE(woosuk): Currently, the below code is not used and we always use argmax
 # to sample the draft tokens. We will use this after we find a way to manage

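For intuition, the helper above asks coordinate_batch_across_dp for a per-rank token-count tensor and then reads its own entry as this rank's padded size, so every DP rank runs the drafter's forward pass with a shape the whole group agreed on. Below is a standalone sketch of that idea using a hypothetical pure-Python stand-in (pad_to_group_max), not the real vLLM helper, and assuming a "pad every rank to the largest batch" policy.

from typing import Optional

import torch


def pad_to_group_max(
    num_tokens_per_rank: list[int], dp_rank: int, allow_dp_padding: bool
) -> tuple[int, Optional[torch.Tensor]]:
    """Hypothetical stand-in for the DP coordination step: pad every rank to
    the largest batch in the group so forward-pass shapes (and CUDA graph
    sizes) stay uniform across data-parallel ranks."""
    if not allow_dp_padding:
        return num_tokens_per_rank[dp_rank], None
    padded = max(num_tokens_per_rank)
    num_tokens_across_dp = torch.tensor(
        [padded] * len(num_tokens_per_rank), dtype=torch.int32
    )
    # In the EagleProposer code above, the caller then overwrites its own slot
    # with the final cudagraph-padded size:
    #     num_tokens_across_dp[self.dp_rank] = num_input_tokens
    return padded, num_tokens_across_dp


# Example: DP ranks scheduled 5, 9, and 7 tokens each run a 9-token pass.
print(pad_to_group_max([5, 9, 7], dp_rank=0, allow_dp_padding=True))
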
@@ -3746,6 +3746,7 @@ class GPUModelRunner(
         create_mixed_batch: bool = False,
         remove_lora: bool = True,
         activate_lora: bool = False,
+        is_graph_capturing: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Run a dummy forward pass to warm up/profile run or capture the

@@ -3981,7 +3982,7 @@ class GPUModelRunner(
         if self.speculative_config and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
             use_cudagraphs = (
-                cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE
+                cudagraph_runtime_mode.has_mode(CUDAGraphMode.PIECEWISE)
                 and not self.speculative_config.enforce_eager
             )
 
@@ -3995,6 +3996,7 @@ class GPUModelRunner(
             self.drafter.dummy_run(
                 num_tokens,
                 use_cudagraphs=use_cudagraphs,
+                is_graph_capturing=is_graph_capturing,
             )
 
         # This is necessary to avoid blocking DP.

@@ -4427,6 +4429,7 @@ class GPUModelRunner(
             skip_eplb=True,
             remove_lora=False,
             activate_lora=activate_lora,
+            is_graph_capturing=True,
         )
         self.maybe_remove_all_loras(self.lora_config)
 