# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch

import pytest
import torch
import torch.nn as nn

from tests.utils import create_new_process_for_each_test
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    CUDAGraphMode,
    ParallelConfig,
    SchedulerConfig,
    VllmConfig,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.platforms import current_platform
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
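

# This module exercises two parts of the v1 cudagraph stack:
# * CudagraphDispatcher: which cudagraph keys are initialized for a given
#   CompilationConfig, and which runtime mode/key a batch dispatches to.
# * CUDAGraphWrapper: capture, replay, and bypass behavior, including a
#   PIECEWISE wrapper nested inside a FULL one.
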
# Helper MLP for testing
class SimpleMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 10)

    def forward(self, x):
        return self.fc2(self.fc1(x))


def _create_vllm_config(
    compilation_config: CompilationConfig,
    max_num_seqs: int = 8,
    lora_config: bool = False,
) -> MagicMock:
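    """Build a MagicMock that stands in for a real VllmConfig.

    Only the attributes the dispatcher and wrapper tests touch are populated;
    the cudagraph-size bookkeeping normally done in VllmConfig.__post_init__()
    is reproduced by hand below.
    """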
    mock_config = MagicMock(spec=VllmConfig)
    mock_config.compilation_config = compilation_config
    mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs)
    mock_config.parallel_config = ParallelConfig()
    mock_config.speculative_config = None  # No speculative decoding
    if not lora_config:
        mock_config.lora_config = None
    # Mimic the behavior of VllmConfig.__post_init__()
    if compilation_config.mode == CompilationMode.VLLM_COMPILE:
        compilation_config.set_splitting_ops_for_v1()

    # mimic VllmConfig.__post_init__
    if compilation_config.cudagraph_capture_sizes:
        compilation_config.max_cudagraph_capture_size = (
            compilation_config.cudagraph_capture_sizes[-1]
        )

    compilation_config.post_init_cudagraph_sizes()
    mock_config.pad_for_cudagraph = (
        lambda batch_size: compilation_config.bs_to_padded_graph_size[batch_size]
    )

    return mock_config


class TestCudagraphDispatcher:
    @pytest.mark.parametrize(
        "cudagraph_mode_str,compilation_mode,lora_config",
        [
            # Test case 0: Full CG for mixed batches, no separate routine
            ("FULL", CompilationMode.NONE, False),
            # Test case 1: Full CG for uniform batches, piecewise for mixed
            ("FULL_AND_PIECEWISE", CompilationMode.NONE, False),
            # Test case 2: Full CG for uniform batches, no CG for mixed
            ("FULL_DECODE_ONLY", CompilationMode.NONE, False),
            # Test case 3: PIECEWISE for all
            ("PIECEWISE", CompilationMode.VLLM_COMPILE, False),
            # Test case 4: PIECEWISE for all, specialize LoRA cases
            ("PIECEWISE", CompilationMode.VLLM_COMPILE, True),
        ],
    )
    def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config):
        # Setup dispatcher
        comp_config = CompilationConfig(
            cudagraph_mode=cudagraph_mode_str,
            mode=compilation_mode,
            cudagraph_capture_sizes=[1, 8],
        )

        config = _create_vllm_config(
            comp_config, max_num_seqs=8, lora_config=lora_config
        )
        if (
            cudagraph_mode_str == "FULL_AND_PIECEWISE"
            and compilation_mode == CompilationMode.NONE
        ):
            with pytest.raises(AssertionError):
                dispatcher = CudagraphDispatcher(config)
            return

        dispatcher = CudagraphDispatcher(config)
        dispatcher.initialize_cudagraph_keys(
            cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1
        )
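
        # Keys were initialized for capture sizes [1, 8]: expect one key per
        # size for each active runtime mode, and twice that (4) when a LoRA
        # config makes the dispatcher specialize keys on has_lora.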

        # Verify the keys are initialized correctly
        if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == (
                4 if lora_config else 2
            )
        else:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
        if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == (
                4 if lora_config else 2
            )
        else:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0

        # Test dispatch logic
        # 1. non-uniform batch, size in cudagraph size list
        desc_full_exact = BatchDescriptor(
            num_tokens=8,
            uniform=False,
        )
        rt_mode, key = dispatcher.dispatch(
            num_tokens=8, uniform_decode=False, has_lora=False
        )
        if cudagraph_mode_str == "FULL":
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_full_exact
        elif cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_full_exact
        else:
            assert rt_mode == CUDAGraphMode.NONE

        # 2. uniform decode batch, size in cudagraph size list
        desc_uniform_exact = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True)
        rt_mode, key = dispatcher.dispatch(
            num_tokens=8, uniform_decode=True, has_lora=False
        )
        if cudagraph_mode_str == "FULL":
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
        elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_uniform_exact
        elif cudagraph_mode_str == "PIECEWISE":
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
        else:
            assert rt_mode == CUDAGraphMode.NONE

        # 3. No key match
        rt_mode, key = dispatcher.dispatch(
            num_tokens=15, uniform_decode=False, has_lora=False
        )
        assert rt_mode == CUDAGraphMode.NONE
        assert key == BatchDescriptor(num_tokens=15)

        # 4. Cascade attention should have a fallback mode
        desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
        rt_mode, key = dispatcher.dispatch(
            num_tokens=8, uniform_decode=False, has_lora=False, use_cascade_attn=True
        )
        if "PIECEWISE" in cudagraph_mode_str:  # string containment check
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_full_exact.relax_for_mixed_batch_cudagraphs()
        else:
            assert rt_mode == CUDAGraphMode.NONE


@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCUDAGraphWrapper:
    def setup_method(self):
        self.vllm_config = _create_vllm_config(CompilationConfig())
        self.model = SimpleMLP().to("cuda")
        self.persistent_input_buffer = torch.zeros(1, 10, device="cuda")
        self.input_tensor = torch.randn(1, 10, device="cuda")

    def test_capture_and_replay(self):
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)
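        # The wrapper keys captured graphs by the BatchDescriptor it reads
        # from the forward context, so capture and replay below must run
        # under the same descriptor.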

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            wrapper(self.input_tensor)

        # 1. Capture
        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.FULL,
                batch_descriptor=batch_descriptor,
            ),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
        ):
            output1 = wrapper(self.input_tensor)
            # the capture pass should produce a zero-filled output
            assert torch.allclose(output1, torch.zeros_like(output1))
            mock_cuda_graph.assert_called_once()

        assert batch_descriptor in wrapper.concrete_cudagraph_entries
        entry = wrapper.concrete_cudagraph_entries[batch_descriptor]
        assert entry.cudagraph is not None

        # 2. Replay
        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.FULL,
                batch_descriptor=batch_descriptor,
            ),
            patch.object(
                entry.cudagraph, "replay", wraps=entry.cudagraph.replay
            ) as mock_replay,
        ):
            output2 = wrapper(self.input_tensor)
            mock_replay.assert_called_once()

        # Compare with eager output
        eager_output = self.model(self.input_tensor)
        torch.testing.assert_close(eager_output, output2)

    def test_bypass_on_mode_mismatch(self):
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
                batch_descriptor=batch_descriptor,
            ),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
            patch.object(
                self.model, "forward", wraps=self.model.forward
            ) as mock_forward,
        ):
            wrapper(self.input_tensor)
            mock_cuda_graph.assert_not_called()
            mock_forward.assert_called_once()
        assert not wrapper.concrete_cudagraph_entries

    def test_bypass_on_mode_none(self):
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
                batch_descriptor=batch_descriptor,
            ),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
        ):
            wrapper(self.input_tensor)
            mock_cuda_graph.assert_not_called()
        assert not wrapper.concrete_cudagraph_entries


@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCudagraphIntegration:
    def setup_method(self):
        # FULL mode: full cudagraphs are used even for non-uniform batches
        self.comp_config = CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            cudagraph_mode="FULL",
            cudagraph_capture_sizes=[10, 20],
        )
        self.vllm_config = _create_vllm_config(self.comp_config)
        self.dispatcher = CudagraphDispatcher(self.vllm_config)
        self.dispatcher.initialize_cudagraph_keys(
            self.comp_config.cudagraph_mode, uniform_decode_query_len=1
        )

    def _run_and_monitor_call(
        self, wrapper, input_tensor, runtime_mode, batch_descriptor
    ):
        """Helper to run a single call and monitor the action."""

        with (
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_graph_context,
            patch.object(wrapper, "runnable", wraps=wrapper.runnable) as mock_runnable,
        ):
            entry = wrapper.concrete_cudagraph_entries.get(batch_descriptor, None)

            context = set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=runtime_mode,
                batch_descriptor=batch_descriptor,
            )
            mock_replay = MagicMock()
            if entry and entry.cudagraph:
                with (
                    context,
                    patch.object(
                        entry.cudagraph, "replay", new_callable=MagicMock
                    ) as mock_replay,
                ):
                    wrapper(input_tensor)
            else:
                with context:
                    wrapper(input_tensor)

        if mock_graph_context.called:
            # note that torch.cuda.graph is globally mocked, so a capture is
            # detected whether it is triggered by the inner or the outer wrapper
            return "capture_global"
        if mock_replay.called:
            # only for the outer wrapper
            return "replay"
        if mock_runnable.call_count > 0:
            # only for the outer wrapper
            return "bypass"
        return "unknown"

    @create_new_process_for_each_test("spawn")
    def test_capture_replay_bypass_logic(self):
        model = SimpleMLP().to("cuda")
        full_wrapper = CUDAGraphWrapper(model, self.vllm_config, CUDAGraphMode.FULL)
        max_bs = 16
        persistent_input_buffer = torch.zeros(max_bs, 10, device="cuda")
        input_1 = persistent_input_buffer[:1]
        input_2 = persistent_input_buffer[:2]
        input_3 = persistent_input_buffer[:3]
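        # All inputs are slices of one persistent buffer: cudagraph replay
        # reads from the device addresses recorded at capture time, so the
        # input memory must stay at a fixed address across calls.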

        desc_1 = BatchDescriptor(num_tokens=1)
        desc_2 = BatchDescriptor(num_tokens=2)
        desc_3_unseen = BatchDescriptor(num_tokens=3)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            full_wrapper(input_1)

        rt_mode, key = self.dispatcher.dispatch(desc_1)
        # 1. Capture first shape
        action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
        assert action == "capture_global"

        # 2. Replay first shape
        action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
        assert action == "replay"

        rt_mode, key = self.dispatcher.dispatch(desc_2)
        # 3. Capture second shape
        action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, key)
        assert action == "capture_global"

        # 4. Replay second shape
        action = self._run_and_monitor_call(
            full_wrapper, input_2, CUDAGraphMode.FULL, desc_2
        )
        assert action == "replay"

        # 5. Bypass if no key match
        rt_mode, key = self.dispatcher.dispatch(desc_3_unseen)
        assert rt_mode == CUDAGraphMode.NONE
        action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode, key)
        assert action == "bypass"

        # capturing an unseen shape is not allowed once capturing is disabled
        set_cudagraph_capturing_enabled(False)
        with pytest.raises(RuntimeError):
            self._run_and_monitor_call(
                full_wrapper, input_3, CUDAGraphMode.FULL, desc_3_unseen
            )
        set_cudagraph_capturing_enabled(True)

    @create_new_process_for_each_test("spawn")
    def test_nested_wrappers(self):
        """Tests a scenario with a PIECEWISE wrapper inside a FULL one."""
        model = SimpleMLP().to("cuda")
        full_wrapper = CUDAGraphWrapper(model, self.vllm_config, CUDAGraphMode.FULL)
        input_1 = torch.randn(1, 10, device="cuda")

        # Setup: Inner model is wrapped with PIECEWISE, outer with FULL
        inner_model = SimpleMLP().to("cuda")
        piecewise_wrapper = CUDAGraphWrapper(
            inner_model, self.vllm_config, CUDAGraphMode.PIECEWISE
        )
        inner_model.forward = MagicMock(wraps=inner_model.forward)
        outer_model = SimpleMLP().to("cuda")
        # When the outer model is called, it calls the piecewise_wrapper
        outer_model.forward = MagicMock(
            wraps=outer_model.forward, side_effect=piecewise_wrapper
        )
        # Rebind full_wrapper to wrap outer_model, replacing the plain-model
        # wrapper created above
        full_wrapper = CUDAGraphWrapper(
            outer_model, self.vllm_config, CUDAGraphMode.FULL
        )
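        # Resulting call chain: full_wrapper -> outer_model.forward (mocked,
        # via side_effect) -> piecewise_wrapper -> inner_model.forward.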

        desc_1 = BatchDescriptor(num_tokens=1)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            full_wrapper(input_1)

        # --- Test runtime mode FULL ---
        # Run with FULL mode context. Expect outer wrapper to capture.
        # The inner mock should be called once inside the graph capture.
        outer_model.forward.reset_mock()
        inner_model.forward.reset_mock()
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.FULL, desc_1
        )
        assert action == "capture_global"
        assert outer_model.forward.call_count == 1
        assert inner_model.forward.call_count == 1

        # Run again. Expect outer wrapper to replay.
        # The outer model should NOT be called because the whole graph
        # is replayed.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.FULL, desc_1
        )
        assert action == "replay"
        assert outer_model.forward.call_count == 1  # No new call
        assert inner_model.forward.call_count == 1

        # --- Test runtime mode PIECEWISE ---
        outer_model.forward.reset_mock()
        inner_model.forward.reset_mock()
        # Run with PIECEWISE mode context.
        # Expect outer wrapper to bypass and call inner wrapper.
        # Inner wrapper should capture.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1
        )
        assert action == "capture_global"
        assert outer_model.forward.call_count == 1
        assert inner_model.forward.call_count == 1

        # Run again with PIECEWISE.
        # Outer bypasses, inner replays.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1
        )
        assert action == "bypass"
        assert outer_model.forward.call_count == 2
        assert inner_model.forward.call_count == 1