[Graph Partition] pass tests for decorator (#26831)

Signed-off-by: Boyuan Feng <boyuan@meta.com>
This commit is contained in:
Boyuan Feng 2025-10-14 23:39:48 -07:00 committed by GitHub
parent 8c851f6d04
commit f0862eae43
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from torch import nn
@ -14,6 +15,7 @@ from vllm.config import (
set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import is_torch_equal_or_newer
# This import automatically registers `torch.ops.silly.attention`
from . import silly_attention # noqa: F401
@ -65,19 +67,40 @@ def run_model(
return output.cpu()
def test_ignore_torch_compile_decorator():
# vllmcompile
@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
def test_ignore_torch_compile_decorator(use_inductor_graph_partition, monkeypatch):
# disable compile cache so that we can count the number of compilations
# appropriately
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
# piecewise
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=False, # TODO test both?
use_inductor_graph_partition=use_inductor_graph_partition,
)
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
expected_num_graphs_seen = 1
expected_num_cudagraph_captured = (
4 # num_cudagraph_sizes * num cudagraphs to capture
)
if use_inductor_graph_partition:
expected_num_piecewise_graphs_seen = 1
expected_num_piecewise_capturable_graphs_seen = 1
expected_num_backend_compilations = 1
else:
expected_num_piecewise_graphs_seen = 3
expected_num_piecewise_capturable_graphs_seen = 2
expected_num_backend_compilations = 2
@support_torch_compile
class A(nn.Module):
def __init__(
@ -104,12 +127,11 @@ def test_ignore_torch_compile_decorator():
# A has support_torch_compile
with compilation_counter.expect(
num_graphs_seen=1,
num_piecewise_graphs_seen=3,
num_piecewise_capturable_graphs_seen=2,
num_backend_compilations=2,
num_cudagraph_captured=4,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
num_graphs_seen=expected_num_graphs_seen,
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
num_backend_compilations=expected_num_backend_compilations,
num_cudagraph_captured=expected_num_cudagraph_captured,
):
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
@ -131,12 +153,11 @@ def test_ignore_torch_compile_decorator():
# C's support_torch_compile should override B's ignore_torch_compile
with compilation_counter.expect(
num_graphs_seen=1,
num_piecewise_graphs_seen=3,
num_piecewise_capturable_graphs_seen=2,
num_backend_compilations=2,
num_cudagraph_captured=4,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
num_graphs_seen=expected_num_graphs_seen,
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
num_backend_compilations=expected_num_backend_compilations,
num_cudagraph_captured=expected_num_cudagraph_captured,
):
run_model(vllm_config, mod_C, cudagraph_runtime_mode)
@ -179,7 +200,15 @@ class A(nn.Module):
return x
def test_conditional_compile_enable_if():
@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch):
# disable compile cache so that we can count the number of compilations
# appropriately
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
vllm_config = VllmConfig(
cache_config=CacheConfig(
kv_sharing_fast_prefill=True,
@ -189,7 +218,7 @@ def test_conditional_compile_enable_if():
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=False, # TODO test both
use_inductor_graph_partition=use_inductor_graph_partition,
),
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
@ -197,17 +226,26 @@ def test_conditional_compile_enable_if():
with set_current_vllm_config(vllm_config):
mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
if use_inductor_graph_partition:
expected_num_piecewise_graphs_seen = 2
expected_num_piecewise_capturable_graphs_seen = 2
expected_num_backend_compilations = 2
else:
expected_num_piecewise_graphs_seen = 6
expected_num_piecewise_capturable_graphs_seen = 4
expected_num_backend_compilations = 4
# A has support_torch_compile but enable_if fn returns False
# enalbe_if will be True for B, so we expect mod1 and mod2
# to be compiled
with compilation_counter.expect(
num_graphs_seen=2,
num_piecewise_graphs_seen=6,
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
# 3 piecewise graphs per instance of B()
num_piecewise_capturable_graphs_seen=4,
num_backend_compilations=4,
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
num_backend_compilations=expected_num_backend_compilations,
num_cudagraph_captured=8,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
# num_cudagraph_sizes * num cudagraphable graphs to capture
):
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
@ -222,20 +260,30 @@ def test_conditional_compile_enable_if():
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=False, # TODO test both?
use_inductor_graph_partition=use_inductor_graph_partition,
),
)
with set_current_vllm_config(vllm_config):
mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
if use_inductor_graph_partition:
expected_num_piecewise_graphs_seen = 1
expected_num_piecewise_capturable_graphs_seen = 1
expected_num_backend_compilations = 1
else:
# 3 attn ops and 4 non-attn ops
expected_num_piecewise_graphs_seen = 7
expected_num_piecewise_capturable_graphs_seen = 4
expected_num_backend_compilations = 4
with compilation_counter.expect(
num_graphs_seen=1,
num_piecewise_graphs_seen=7,
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
# 3 attn ops and 4 non-attn ops
num_piecewise_capturable_graphs_seen=4,
num_backend_compilations=4,
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
num_backend_compilations=expected_num_backend_compilations,
num_cudagraph_captured=8,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
# num_cudagraph_sizes * num cudagraphable graphs to capture
):
run_model(vllm_config, mod_A, cudagraph_runtime_mode)