vllm/tests/v1/cudagraph/test_cudagraph_dispatch.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch

import pytest
import torch
import torch.nn as nn

from tests.utils import create_new_process_for_each_test
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (
CompilationConfig,
CompilationMode,
CUDAGraphMode,
ParallelConfig,
SchedulerConfig,
VllmConfig,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.platforms import current_platform
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher


# Helper MLP for testing
class SimpleMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 10)

    def forward(self, x):
        return self.fc2(self.fc1(x))


def _create_vllm_config(
compilation_config: CompilationConfig,
max_num_seqs: int = 8,
lora_config: bool = False,
) -> MagicMock:
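    """Build a mock VllmConfig with real compilation and scheduler configs."""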
mock_config = MagicMock(spec=VllmConfig)
mock_config.compilation_config = compilation_config
mock_config.scheduler_config = SchedulerConfig.default_factory(
max_num_seqs=max_num_seqs,
)
mock_config.parallel_config = ParallelConfig()
mock_config.speculative_config = None # No speculative decoding
if not lora_config:
mock_config.lora_config = None
# Mimic the behavior of VllmConfig.__post_init__()
if compilation_config.mode == CompilationMode.VLLM_COMPILE:
compilation_config.set_splitting_ops_for_v1()
# mimic VllmConfig.__post_init__
if compilation_config.cudagraph_capture_sizes:
compilation_config.max_cudagraph_capture_size = (
compilation_config.cudagraph_capture_sizes[-1]
)
compilation_config.post_init_cudagraph_sizes()
mock_config.pad_for_cudagraph = (
lambda batch_size: compilation_config.bs_to_padded_graph_size[batch_size]
)
return mock_config


class TestCudagraphDispatcher:
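    """Unit tests for CudagraphDispatcher key setup and dispatch decisions."""
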
@pytest.mark.parametrize(
"cudagraph_mode_str,compilation_mode,lora_config",
[
# Test case 0: Full CG for mixed batches, no separate routine
("FULL", CompilationMode.NONE, False),
# Test case 1: Full CG for uniform batches, piecewise for mixed
("FULL_AND_PIECEWISE", CompilationMode.NONE, False),
# Test case 2: Full CG for uniform batches, no CG for mixed
("FULL_DECODE_ONLY", CompilationMode.NONE, False),
# Test case 3: PIECEWISE for all
("PIECEWISE", CompilationMode.VLLM_COMPILE, False),
# Test case 4: PIECEWISE for all, specialize LoRA cases
("PIECEWISE", CompilationMode.VLLM_COMPILE, True),
],
)
def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config):
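        """Check cudagraph key counts and dispatch results for each mode."""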
# Setup dispatcher
comp_config = CompilationConfig(
cudagraph_mode=cudagraph_mode_str,
mode=compilation_mode,
cudagraph_capture_sizes=[1, 8],
)
config = _create_vllm_config(
comp_config, max_num_seqs=8, lora_config=lora_config
)
if (
cudagraph_mode_str == "FULL_AND_PIECEWISE"
and compilation_mode == CompilationMode.NONE
):
with pytest.raises(AssertionError):
dispatcher = CudagraphDispatcher(config)
return
dispatcher = CudagraphDispatcher(config)
dispatcher.initialize_cudagraph_keys(
cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1
)
        # Verify the keys are initialized correctly
if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == (
4 if lora_config else 2
)
else:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == (
4 if lora_config else 2
)
else:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0
# Test dispatch logic
# 1. non-uniform batch, size in cudagraph size list
desc_full_exact = BatchDescriptor(
num_tokens=8,
uniform=False,
)
rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=False, has_lora=False
)
if cudagraph_mode_str == "FULL":
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_full_exact
elif cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_full_exact
else:
assert rt_mode == CUDAGraphMode.NONE
# 2. uniform decode batch, size in cudagraph size list
desc_uniform_exact = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True)
rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=True, has_lora=False
)
if cudagraph_mode_str == "FULL":
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_uniform_exact
elif cudagraph_mode_str == "PIECEWISE":
assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
else:
assert rt_mode == CUDAGraphMode.NONE
# 3. No key match
rt_mode, key = dispatcher.dispatch(
num_tokens=15, uniform_decode=False, has_lora=False
)
assert rt_mode == CUDAGraphMode.NONE
assert key == BatchDescriptor(num_tokens=15)
        # 4. Cascade attention should fall back from FULL cudagraphs
desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=False, has_lora=False, use_cascade_attn=True
)
if "PIECEWISE" in cudagraph_mode_str: # string contains check
assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_full_exact.relax_for_mixed_batch_cudagraphs()
else:
assert rt_mode == CUDAGraphMode.NONE


@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCUDAGraphWrapper:
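    """Tests for CUDAGraphWrapper capture, replay, and bypass on a toy MLP."""
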
def setup_method(self):
self.vllm_config = _create_vllm_config(CompilationConfig())
self.model = SimpleMLP().to("cuda")
self.persistent_input_buffer = torch.zeros(1, 10, device="cuda")
self.input_tensor = torch.randn(1, 10, device="cuda")

    def test_capture_and_replay(self):
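        """The first FULL-mode call captures a graph; the second call replays it."""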
wrapper = CUDAGraphWrapper(
self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
)
batch_descriptor = BatchDescriptor(num_tokens=10)
# 0. global warmup
with set_forward_context(
attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
batch_descriptor=None,
):
wrapper(self.input_tensor)
# 1. Capture
with (
set_forward_context(
attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.FULL,
batch_descriptor=batch_descriptor,
),
patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
):
output1 = wrapper(self.input_tensor)
            # The capture phase should produce a zero output
assert torch.allclose(output1, torch.zeros_like(output1))
mock_cuda_graph.assert_called_once()
assert batch_descriptor in wrapper.concrete_cudagraph_entries
entry = wrapper.concrete_cudagraph_entries[batch_descriptor]
assert entry.cudagraph is not None
# 2. Replay
with (
set_forward_context(
attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.FULL,
batch_descriptor=batch_descriptor,
),
patch.object(
entry.cudagraph, "replay", wraps=entry.cudagraph.replay
) as mock_replay,
):
output2 = wrapper(self.input_tensor)
mock_replay.assert_called_once()
# Compare with eager output
eager_output = self.model(self.input_tensor)
torch.testing.assert_close(eager_output, output2)

    def test_bypass_on_mode_mismatch(self):
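        """A PIECEWISE context should not trigger capture in a FULL-mode wrapper."""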
wrapper = CUDAGraphWrapper(
self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
)
batch_descriptor = BatchDescriptor(num_tokens=10)
with (
set_forward_context(
attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=batch_descriptor,
),
patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
patch.object(
self.model, "forward", wraps=self.model.forward
) as mock_forward,
):
wrapper(self.input_tensor)
mock_cuda_graph.assert_not_called()
mock_forward.assert_called_once()
assert not wrapper.concrete_cudagraph_entries

    def test_bypass_on_mode_none(self):
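        """A NONE forward context should bypass cudagraph capture entirely."""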
wrapper = CUDAGraphWrapper(
self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
)
batch_descriptor = BatchDescriptor(num_tokens=10)
with (
set_forward_context(
attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
batch_descriptor=batch_descriptor,
),
patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
):
wrapper(self.input_tensor)
mock_cuda_graph.assert_not_called()
assert not wrapper.concrete_cudagraph_entries


@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCudagraphIntegration:
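    """Integration tests combining CudagraphDispatcher and CUDAGraphWrapper."""
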
def setup_method(self):
# only FULL mode for non-uniform batches
self.comp_config = CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
cudagraph_mode="FULL",
cudagraph_capture_sizes=[10, 20],
)
self.vllm_config = _create_vllm_config(self.comp_config)
self.dispatcher = CudagraphDispatcher(self.vllm_config)
self.dispatcher.initialize_cudagraph_keys(
self.comp_config.cudagraph_mode, uniform_decode_query_len=1
)

    def _run_and_monitor_call(
self, wrapper, input_tensor, runtime_mode, batch_descriptor
):
"""Helper to run a single call and monitor the action."""
with (
patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_graph_context,
patch.object(wrapper, "runnable", wraps=wrapper.runnable) as mock_runnable,
):
entry = wrapper.concrete_cudagraph_entries.get(batch_descriptor, None)
context = set_forward_context(
attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=runtime_mode,
batch_descriptor=batch_descriptor,
)
mock_replay = MagicMock()
if entry and entry.cudagraph:
with (
context,
patch.object(
entry.cudagraph, "replay", new_callable=MagicMock
) as mock_replay,
):
wrapper(input_tensor)
else:
with context:
wrapper(input_tensor)
if mock_graph_context.called:
                # Note: torch.cuda.graph is patched globally, so a capture is
                # detected whether it came from the inner or the outer wrapper
return "capture_global"
if mock_replay.called:
# only for outer wrapper
return "replay"
if mock_runnable.call_count > 0:
# only for outer wrapper
return "bypass"
return "unknown"
@create_new_process_for_each_test("spawn")
def test_capture_replay_bypass_logic(self):
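        """Capture two shapes, replay them, then bypass and forbid unseen shapes."""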
model = SimpleMLP().to("cuda")
full_wrapper = CUDAGraphWrapper(model, self.vllm_config, CUDAGraphMode.FULL)
max_bs = 16
persistent_input_buffer = torch.zeros(max_bs, 10, device="cuda")
input_1 = persistent_input_buffer[:1]
input_2 = persistent_input_buffer[:2]
input_3 = persistent_input_buffer[:3]
desc_1 = BatchDescriptor(num_tokens=1)
desc_2 = BatchDescriptor(num_tokens=2)
desc_3_unseen = BatchDescriptor(num_tokens=3)
# 0. global warmup
with set_forward_context(
attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
batch_descriptor=None,
):
full_wrapper(input_1)
rt_mode, key = self.dispatcher.dispatch(desc_1)
# 1. Capture first shape
action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
assert action == "capture_global"
# 2. Replay first shape
action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
assert action == "replay"
rt_mode, key = self.dispatcher.dispatch(desc_2)
# 3. Capture second shape
action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, key)
assert action == "capture_global"
# 4. Replay second shape
action = self._run_and_monitor_call(
full_wrapper, input_2, CUDAGraphMode.FULL, desc_2
)
assert action == "replay"
# 5. Bypass if no key match
rt_mode, key = self.dispatcher.dispatch(desc_3_unseen)
assert rt_mode == CUDAGraphMode.NONE
action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode, key)
assert action == "bypass"
        # Capturing an unseen shape is not allowed once capturing is disabled
set_cudagraph_capturing_enabled(False)
with pytest.raises(RuntimeError):
self._run_and_monitor_call(
full_wrapper, input_3, CUDAGraphMode.FULL, desc_3_unseen
)
set_cudagraph_capturing_enabled(True)
@create_new_process_for_each_test("spawn")
def test_nested_wrappers(self):
"""Tests a scenario with a PIECEWISE wrapper inside a FULL one."""
model = SimpleMLP().to("cuda")
full_wrapper = CUDAGraphWrapper(model, self.vllm_config, CUDAGraphMode.FULL)
input_1 = torch.randn(1, 10, device="cuda")
# Setup: Inner model is wrapped with PIECEWISE, outer with FULL
inner_model = SimpleMLP().to("cuda")
piecewise_wrapper = CUDAGraphWrapper(
inner_model, self.vllm_config, CUDAGraphMode.PIECEWISE
)
inner_model.forward = MagicMock(wraps=inner_model.forward)
outer_model = SimpleMLP().to("cuda")
# When outer model is called, it calls the piecewise_wrapper
outer_model.forward = MagicMock(
wraps=outer_model.forward, side_effect=piecewise_wrapper
)
full_wrapper = CUDAGraphWrapper(
outer_model, self.vllm_config, CUDAGraphMode.FULL
)
desc_1 = BatchDescriptor(num_tokens=1)
# 0. global warmup
with set_forward_context(
attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
batch_descriptor=None,
):
full_wrapper(input_1)
        # --- Test runtime mode FULL ---
# Run with FULL mode context. Expect outer wrapper to capture.
# The inner mock should be called once inside the graph capture.
outer_model.forward.reset_mock()
inner_model.forward.reset_mock()
action = self._run_and_monitor_call(
full_wrapper, input_1, CUDAGraphMode.FULL, desc_1
)
assert action == "capture_global"
assert outer_model.forward.call_count == 1
assert inner_model.forward.call_count == 1
# Run again. Expect outer wrapper to replay.
# The outer model should NOT be called because the whole graph
# is replayed.
action = self._run_and_monitor_call(
full_wrapper, input_1, CUDAGraphMode.FULL, desc_1
)
assert action == "replay"
assert outer_model.forward.call_count == 1 # No new call
assert inner_model.forward.call_count == 1
# --- Test runtime mode PIECEWISE ---
outer_model.forward.reset_mock()
inner_model.forward.reset_mock()
# Run with PIECEWISE mode context.
# Expect outer wrapper to bypass and call inner wrapper.
# Inner wrapper should capture.
action = self._run_and_monitor_call(
full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1
)
assert action == "capture_global"
assert outer_model.forward.call_count == 1
assert inner_model.forward.call_count == 1
# Run again with PIECEWISE.
# Outer bypasses, inner replays.
action = self._run_and_monitor_call(
full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1
)
assert action == "bypass"
assert outer_model.forward.call_count == 2
assert inner_model.forward.call_count == 1