mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:15:01 +08:00
[Graph Partition] pass tests for decorator (#26831)
Signed-off-by: Boyuan Feng <boyuan@meta.com>
This commit is contained in:
parent
8c851f6d04
commit
f0862eae43
@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@ -14,6 +15,7 @@ from vllm.config import (
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.utils import is_torch_equal_or_newer
|
||||
|
||||
# This import automatically registers `torch.ops.silly.attention`
|
||||
from . import silly_attention # noqa: F401
|
||||
@ -65,19 +67,40 @@ def run_model(
|
||||
return output.cpu()
|
||||
|
||||
|
||||
def test_ignore_torch_compile_decorator():
|
||||
# vllmcompile
|
||||
@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
|
||||
def test_ignore_torch_compile_decorator(use_inductor_graph_partition, monkeypatch):
|
||||
# disable compile cache so that we can count the number of compilations
|
||||
# appropriately
|
||||
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
|
||||
|
||||
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
||||
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||
|
||||
# piecewise
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
use_cudagraph=True,
|
||||
splitting_ops=["silly::attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
use_inductor_graph_partition=False, # TODO test both?
|
||||
use_inductor_graph_partition=use_inductor_graph_partition,
|
||||
)
|
||||
)
|
||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||
|
||||
expected_num_graphs_seen = 1
|
||||
expected_num_cudagraph_captured = (
|
||||
4 # num_cudagraph_sizes * num cudagraphs to capture
|
||||
)
|
||||
if use_inductor_graph_partition:
|
||||
expected_num_piecewise_graphs_seen = 1
|
||||
expected_num_piecewise_capturable_graphs_seen = 1
|
||||
expected_num_backend_compilations = 1
|
||||
else:
|
||||
expected_num_piecewise_graphs_seen = 3
|
||||
expected_num_piecewise_capturable_graphs_seen = 2
|
||||
expected_num_backend_compilations = 2
|
||||
|
||||
@support_torch_compile
|
||||
class A(nn.Module):
|
||||
def __init__(
|
||||
@ -104,12 +127,11 @@ def test_ignore_torch_compile_decorator():
|
||||
|
||||
# A has support_torch_compile
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1,
|
||||
num_piecewise_graphs_seen=3,
|
||||
num_piecewise_capturable_graphs_seen=2,
|
||||
num_backend_compilations=2,
|
||||
num_cudagraph_captured=4,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
num_graphs_seen=expected_num_graphs_seen,
|
||||
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
|
||||
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
|
||||
num_backend_compilations=expected_num_backend_compilations,
|
||||
num_cudagraph_captured=expected_num_cudagraph_captured,
|
||||
):
|
||||
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
|
||||
|
||||
@ -131,12 +153,11 @@ def test_ignore_torch_compile_decorator():
|
||||
|
||||
# C's support_torch_compile should override B's ignore_torch_compile
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1,
|
||||
num_piecewise_graphs_seen=3,
|
||||
num_piecewise_capturable_graphs_seen=2,
|
||||
num_backend_compilations=2,
|
||||
num_cudagraph_captured=4,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
num_graphs_seen=expected_num_graphs_seen,
|
||||
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
|
||||
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
|
||||
num_backend_compilations=expected_num_backend_compilations,
|
||||
num_cudagraph_captured=expected_num_cudagraph_captured,
|
||||
):
|
||||
run_model(vllm_config, mod_C, cudagraph_runtime_mode)
|
||||
|
||||
@ -179,7 +200,15 @@ class A(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def test_conditional_compile_enable_if():
|
||||
@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
|
||||
def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch):
|
||||
# disable compile cache so that we can count the number of compilations
|
||||
# appropriately
|
||||
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
|
||||
|
||||
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
||||
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
cache_config=CacheConfig(
|
||||
kv_sharing_fast_prefill=True,
|
||||
@ -189,7 +218,7 @@ def test_conditional_compile_enable_if():
|
||||
use_cudagraph=True,
|
||||
splitting_ops=["silly::attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
use_inductor_graph_partition=False, # TODO test both
|
||||
use_inductor_graph_partition=use_inductor_graph_partition,
|
||||
),
|
||||
)
|
||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||
@ -197,17 +226,26 @@ def test_conditional_compile_enable_if():
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
|
||||
|
||||
if use_inductor_graph_partition:
|
||||
expected_num_piecewise_graphs_seen = 2
|
||||
expected_num_piecewise_capturable_graphs_seen = 2
|
||||
expected_num_backend_compilations = 2
|
||||
else:
|
||||
expected_num_piecewise_graphs_seen = 6
|
||||
expected_num_piecewise_capturable_graphs_seen = 4
|
||||
expected_num_backend_compilations = 4
|
||||
|
||||
# A has support_torch_compile but enable_if fn returns False
|
||||
# enalbe_if will be True for B, so we expect mod1 and mod2
|
||||
# to be compiled
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=2,
|
||||
num_piecewise_graphs_seen=6,
|
||||
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
|
||||
# 3 piecewise graphs per instance of B()
|
||||
num_piecewise_capturable_graphs_seen=4,
|
||||
num_backend_compilations=4,
|
||||
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
|
||||
num_backend_compilations=expected_num_backend_compilations,
|
||||
num_cudagraph_captured=8,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
# num_cudagraph_sizes * num cudagraphable graphs to capture
|
||||
):
|
||||
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
|
||||
|
||||
@ -222,20 +260,30 @@ def test_conditional_compile_enable_if():
|
||||
use_cudagraph=True,
|
||||
splitting_ops=["silly::attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
use_inductor_graph_partition=False, # TODO test both?
|
||||
use_inductor_graph_partition=use_inductor_graph_partition,
|
||||
),
|
||||
)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
|
||||
|
||||
if use_inductor_graph_partition:
|
||||
expected_num_piecewise_graphs_seen = 1
|
||||
expected_num_piecewise_capturable_graphs_seen = 1
|
||||
expected_num_backend_compilations = 1
|
||||
else:
|
||||
# 3 attn ops and 4 non-attn ops
|
||||
expected_num_piecewise_graphs_seen = 7
|
||||
expected_num_piecewise_capturable_graphs_seen = 4
|
||||
expected_num_backend_compilations = 4
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1,
|
||||
num_piecewise_graphs_seen=7,
|
||||
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
|
||||
# 3 attn ops and 4 non-attn ops
|
||||
num_piecewise_capturable_graphs_seen=4,
|
||||
num_backend_compilations=4,
|
||||
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
|
||||
num_backend_compilations=expected_num_backend_compilations,
|
||||
num_cudagraph_captured=8,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
# num_cudagraph_sizes * num cudagraphable graphs to capture
|
||||
):
|
||||
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user