From b8a93076d36eff5cff8a89f99a7370d0cc6f0e98 Mon Sep 17 00:00:00 2001 From: Jiangyun Zhu Date: Wed, 10 Sep 2025 02:05:25 +0800 Subject: [PATCH] [CI] execute all piecewise compilation tests together (#24502) Signed-off-by: zjy0516 --- .buildkite/test-pipeline.yaml | 6 +- .../compile/piecewise/test_multiple_graphs.py | 28 +-------- tests/compile/piecewise/test_simple.py | 43 +++---------- tests/compile/piecewise/test_toy_llama.py | 27 +------- tests/compile/silly_attention.py | 63 +++++++++++++++++++ tests/compile/test_decorator.py | 31 ++------- 6 files changed, 81 insertions(+), 117 deletions(-) create mode 100644 tests/compile/silly_attention.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index b0d4c4456d339..8f2f6083b0305 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -379,11 +379,7 @@ steps: - tests/compile commands: - pytest -v -s compile/test_basic_correctness.py - # these tests need to be separated, cannot combine - - pytest -v -s compile/piecewise/test_simple.py - - pytest -v -s compile/piecewise/test_toy_llama.py - - pytest -v -s compile/piecewise/test_full_cudagraph.py - - pytest -v -s compile/piecewise/test_multiple_graphs.py + - pytest -v -s compile/piecewise/ - label: PyTorch Fullgraph Test # 20min timeout_in_minutes: 30 diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py index aee2acbd490ee..5cfebfce9ea2a 100644 --- a/tests/compile/piecewise/test_multiple_graphs.py +++ b/tests/compile/piecewise/test_multiple_graphs.py @@ -4,9 +4,9 @@ Test (piecewise) compilation with a simple model where multiple submodules are compiled and graph captured separately. 
""" + import torch from torch import nn -from torch.library import Library from vllm.compilation.backends import set_model_tag from vllm.compilation.counter import compilation_counter @@ -15,10 +15,9 @@ from vllm.compilation.decorators import (ignore_torch_compile, from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, VllmConfig, set_current_vllm_config) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import direct_register_custom_op -# create a library to hold the custom op -silly_lib = Library("silly", "FRAGMENT") # noqa +# This import automatically registers `torch.ops.silly.attention` +from .. import silly_attention # noqa: F401 BATCH_SIZE = 32 MLP_SIZE = 128 @@ -26,27 +25,6 @@ HIDDEN_SIZE = 1024 RANDOM_SEED = 0 -def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - out.copy_(q) - out += k - out += v - - -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - return - - -direct_register_custom_op( - op_name="attention", - op_func=silly_attention, - mutates_args=["out"], - fake_impl=silly_attention_fake, - target_lib=silly_lib, -) - - @support_torch_compile class ParentModel(nn.Module): diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 2d1a72d44ec70..84f4945c82725 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -4,10 +4,10 @@ Test the piecewise compilation with a simple model so that we can exactly calculate the expected output and side effects. 
""" + import pytest import torch from torch import nn -from torch.library import Library from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile @@ -15,35 +15,9 @@ from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, VllmConfig, set_current_vllm_config) from vllm.envs import VLLM_USE_V1 from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import direct_register_custom_op -global_counter = 0 - -# create a library to hold the custom op -silly_lib = Library("silly", "FRAGMENT") # noqa - - -def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - global global_counter - global_counter += 1 - print(f"{global_counter=}") - out.copy_(q) - out[0] += 1 - - -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - return - - -direct_register_custom_op( - op_name="attention", - op_func=silly_attention, - mutates_args=["out"], - fake_impl=silly_attention_fake, - target_lib=silly_lib, -) +# This import automatically registers `torch.ops.silly.attention` +from ..silly_attention import get_global_counter, reset_global_counter @support_torch_compile @@ -59,8 +33,7 @@ class SillyModel(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: """ Overall effect: - x += 1 - x[0] += 2 + x = 3 * x + 19 global_counter += 2 """ x = x + 1 @@ -78,6 +51,7 @@ class SillyModel(nn.Module): @pytest.mark.parametrize("use_inductor", [True, False]) +@torch.inference_mode() def test_simple_piecewise_compile(use_inductor): assert VLLM_USE_V1 @@ -121,13 +95,12 @@ def test_simple_piecewise_compile(use_inductor): model(torch.randn(1).cuda()) input = torch.zeros(2).cuda() - global global_counter - global_counter = 0 + reset_global_counter() with set_forward_context( None, vllm_config=vllm_config, cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, 
batch_descriptor=BatchDescriptor(num_tokens=2, )): output = model(input) - assert global_counter == 2 - assert torch.allclose(output.cpu(), torch.tensor([3., 1.])) + assert get_global_counter() == 2 + assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0])) diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index bcfd0d834c5db..cba7517647e51 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -14,38 +14,15 @@ from typing import Any, Optional import pytest import torch from torch import nn -from torch.library import Library from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, VllmConfig, set_current_vllm_config) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import direct_register_custom_op -# create a library to hold the custom op -silly_lib = Library("silly", "FRAGMENT") # noqa - - -def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - out.copy_(q) - out += k - out += v - - -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - return - - -direct_register_custom_op( - op_name="attention", - op_func=silly_attention, - mutates_args=["out"], - fake_impl=silly_attention_fake, - target_lib=silly_lib, -) +# This import automatically registers `torch.ops.silly.attention` +from .. 
import silly_attention # noqa: F401 @dataclass diff --git a/tests/compile/silly_attention.py b/tests/compile/silly_attention.py new file mode 100644 index 0000000000000..13eb0bf4b1fa1 --- /dev/null +++ b/tests/compile/silly_attention.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Shared PyTorch custom silly attention for compilation tests. +Centralizes custom operation definitions to avoid duplicate registrations. +""" + +import torch +from torch.library import Library + +from vllm.utils import direct_register_custom_op + +# Shared library for all compilation test operations +# Using "silly" namespace to match existing test expectations +# Importing this file automatically registers +# torch ops for testing (like silly.attention) +silly_lib = Library("silly", "FRAGMENT") + +# Global counter that counts the number of times attention is invoked +_global_counter = 0 + + +def get_global_counter(): + """Get the current global counter value""" + return _global_counter + + +def reset_global_counter(): + """Reset the global counter to 0""" + global _global_counter + _global_counter = 0 + + +def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + """ + Unified attention implementation that depends on + all inputs and affects the output. + Always increments a global counter that tests can use or ignore. 
+ """ + global _global_counter + + # Always increment the global counter + _global_counter += 1 + + # Unified implementation that depends on all inputs + out.copy_(q + k + v) + + +def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + """Fake implementation for testing""" + return + + +# Register the unified attention operation +direct_register_custom_op( + op_name="attention", + op_func=silly_attention, + mutates_args=["out"], + fake_impl=silly_attention_fake, + target_lib=silly_lib, +) diff --git a/tests/compile/test_decorator.py b/tests/compile/test_decorator.py index 51f8ddd566d56..d73586d53ff3e 100644 --- a/tests/compile/test_decorator.py +++ b/tests/compile/test_decorator.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch from torch import nn -from torch.library import Library from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import (ignore_torch_compile, @@ -10,36 +9,14 @@ from vllm.compilation.decorators import (ignore_torch_compile, from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, CUDAGraphMode, VllmConfig, set_current_vllm_config) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import direct_register_custom_op -# create a library to hold the custom op -silly_lib = Library("silly", "FRAGMENT") # noqa +# This import automatically registers `torch.ops.silly.attention` +from . 
import silly_attention # noqa: F401 BATCH_SIZE = 32 MLP_SIZE = 128 -def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - out.copy_(q) - out += k - out += v - - -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - return - - -direct_register_custom_op( - op_name="attention", - op_func=silly_attention, - mutates_args=["out"], - fake_impl=silly_attention_fake, - target_lib=silly_lib, -) - - @torch.inference_mode def run_model(vllm_config: VllmConfig, model: nn.Module, cudagraph_runtime_mode: CUDAGraphMode): @@ -151,7 +128,7 @@ def test_ignore_torch_compile_decorator(): run_model(vllm_config, mod_C, cudagraph_runtime_mode) -# Only enable torch.compile if +# Only enable torch.compile if # vllm_config.cache_config.kv_sharing_fast_prefill=True @support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config. kv_sharing_fast_prefill) @@ -173,7 +150,7 @@ class B(nn.Module): return x -# Only enable torch.compile if +# Only enable torch.compile if # vllm_config.cache_config.kv_sharing_fast_prefill=False @support_torch_compile(enable_if=lambda vllm_config: not vllm_config. cache_config.kv_sharing_fast_prefill)