mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-22 02:17:02 +08:00
[CI] execute all piecewise compilation tests together (#24502)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
parent
c3f9773b2c
commit
b8a93076d3
@ -379,11 +379,7 @@ steps:
|
|||||||
- tests/compile
|
- tests/compile
|
||||||
commands:
|
commands:
|
||||||
- pytest -v -s compile/test_basic_correctness.py
|
- pytest -v -s compile/test_basic_correctness.py
|
||||||
# these tests need to be separated, cannot combine
|
- pytest -v -s compile/piecewise/
|
||||||
- pytest -v -s compile/piecewise/test_simple.py
|
|
||||||
- pytest -v -s compile/piecewise/test_toy_llama.py
|
|
||||||
- pytest -v -s compile/piecewise/test_full_cudagraph.py
|
|
||||||
- pytest -v -s compile/piecewise/test_multiple_graphs.py
|
|
||||||
|
|
||||||
- label: PyTorch Fullgraph Test # 20min
|
- label: PyTorch Fullgraph Test # 20min
|
||||||
timeout_in_minutes: 30
|
timeout_in_minutes: 30
|
||||||
|
|||||||
@ -4,9 +4,9 @@
|
|||||||
Test (piecewise) compilation with a simple model where multiple submodules
|
Test (piecewise) compilation with a simple model where multiple submodules
|
||||||
are compiled and graph captured separately.
|
are compiled and graph captured separately.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.library import Library
|
|
||||||
|
|
||||||
from vllm.compilation.backends import set_model_tag
|
from vllm.compilation.backends import set_model_tag
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
@ -15,10 +15,9 @@ from vllm.compilation.decorators import (ignore_torch_compile,
|
|||||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||||
VllmConfig, set_current_vllm_config)
|
VllmConfig, set_current_vllm_config)
|
||||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||||
from vllm.utils import direct_register_custom_op
|
|
||||||
|
|
||||||
# create a library to hold the custom op
|
# This import automatically registers `torch.ops.silly.attention`
|
||||||
silly_lib = Library("silly", "FRAGMENT") # noqa
|
from .. import silly_attention # noqa: F401
|
||||||
|
|
||||||
BATCH_SIZE = 32
|
BATCH_SIZE = 32
|
||||||
MLP_SIZE = 128
|
MLP_SIZE = 128
|
||||||
@ -26,27 +25,6 @@ HIDDEN_SIZE = 1024
|
|||||||
RANDOM_SEED = 0
|
RANDOM_SEED = 0
|
||||||
|
|
||||||
|
|
||||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
|
||||||
out: torch.Tensor) -> None:
|
|
||||||
out.copy_(q)
|
|
||||||
out += k
|
|
||||||
out += v
|
|
||||||
|
|
||||||
|
|
||||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
|
||||||
out: torch.Tensor) -> None:
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
|
||||||
op_name="attention",
|
|
||||||
op_func=silly_attention,
|
|
||||||
mutates_args=["out"],
|
|
||||||
fake_impl=silly_attention_fake,
|
|
||||||
target_lib=silly_lib,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@support_torch_compile
|
@support_torch_compile
|
||||||
class ParentModel(nn.Module):
|
class ParentModel(nn.Module):
|
||||||
|
|
||||||
|
|||||||
@ -4,10 +4,10 @@
|
|||||||
Test the piecewise compilation with a simple model so that we
|
Test the piecewise compilation with a simple model so that we
|
||||||
can exactly calculate the expected output and side effects.
|
can exactly calculate the expected output and side effects.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.library import Library
|
|
||||||
|
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
@ -15,35 +15,9 @@ from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
|||||||
VllmConfig, set_current_vllm_config)
|
VllmConfig, set_current_vllm_config)
|
||||||
from vllm.envs import VLLM_USE_V1
|
from vllm.envs import VLLM_USE_V1
|
||||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||||
from vllm.utils import direct_register_custom_op
|
|
||||||
|
|
||||||
global_counter = 0
|
# This import automatically registers `torch.ops.silly.attention`
|
||||||
|
from ..silly_attention import get_global_counter, reset_global_counter
|
||||||
# create a library to hold the custom op
|
|
||||||
silly_lib = Library("silly", "FRAGMENT") # noqa
|
|
||||||
|
|
||||||
|
|
||||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
|
||||||
out: torch.Tensor) -> None:
|
|
||||||
global global_counter
|
|
||||||
global_counter += 1
|
|
||||||
print(f"{global_counter=}")
|
|
||||||
out.copy_(q)
|
|
||||||
out[0] += 1
|
|
||||||
|
|
||||||
|
|
||||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
|
||||||
out: torch.Tensor) -> None:
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
|
||||||
op_name="attention",
|
|
||||||
op_func=silly_attention,
|
|
||||||
mutates_args=["out"],
|
|
||||||
fake_impl=silly_attention_fake,
|
|
||||||
target_lib=silly_lib,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@support_torch_compile
|
@support_torch_compile
|
||||||
@ -59,8 +33,7 @@ class SillyModel(nn.Module):
|
|||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Overall effect:
|
Overall effect:
|
||||||
x += 1
|
x = 3 * x + 19
|
||||||
x[0] += 2
|
|
||||||
global_counter += 2
|
global_counter += 2
|
||||||
"""
|
"""
|
||||||
x = x + 1
|
x = x + 1
|
||||||
@ -78,6 +51,7 @@ class SillyModel(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("use_inductor", [True, False])
|
@pytest.mark.parametrize("use_inductor", [True, False])
|
||||||
|
@torch.inference_mode()
|
||||||
def test_simple_piecewise_compile(use_inductor):
|
def test_simple_piecewise_compile(use_inductor):
|
||||||
assert VLLM_USE_V1
|
assert VLLM_USE_V1
|
||||||
|
|
||||||
@ -121,13 +95,12 @@ def test_simple_piecewise_compile(use_inductor):
|
|||||||
model(torch.randn(1).cuda())
|
model(torch.randn(1).cuda())
|
||||||
|
|
||||||
input = torch.zeros(2).cuda()
|
input = torch.zeros(2).cuda()
|
||||||
global global_counter
|
reset_global_counter()
|
||||||
global_counter = 0
|
|
||||||
with set_forward_context(
|
with set_forward_context(
|
||||||
None,
|
None,
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||||
batch_descriptor=BatchDescriptor(num_tokens=2, )):
|
batch_descriptor=BatchDescriptor(num_tokens=2, )):
|
||||||
output = model(input)
|
output = model(input)
|
||||||
assert global_counter == 2
|
assert get_global_counter() == 2
|
||||||
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
|
assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))
|
||||||
|
|||||||
@ -14,38 +14,15 @@ from typing import Any, Optional
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.library import Library
|
|
||||||
|
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||||
VllmConfig, set_current_vllm_config)
|
VllmConfig, set_current_vllm_config)
|
||||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||||
from vllm.utils import direct_register_custom_op
|
|
||||||
|
|
||||||
# create a library to hold the custom op
|
# This import automatically registers `torch.ops.silly.attention`
|
||||||
silly_lib = Library("silly", "FRAGMENT") # noqa
|
from .. import silly_attention # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
|
||||||
out: torch.Tensor) -> None:
|
|
||||||
out.copy_(q)
|
|
||||||
out += k
|
|
||||||
out += v
|
|
||||||
|
|
||||||
|
|
||||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
|
||||||
out: torch.Tensor) -> None:
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
|
||||||
op_name="attention",
|
|
||||||
op_func=silly_attention,
|
|
||||||
mutates_args=["out"],
|
|
||||||
fake_impl=silly_attention_fake,
|
|
||||||
target_lib=silly_lib,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
63
tests/compile/silly_attention.py
Normal file
63
tests/compile/silly_attention.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Shared PyTorch custom silly attention for compilation tests.
|
||||||
|
Centralizes custom operation definitions to avoid duplicate registrations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.library import Library
|
||||||
|
|
||||||
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
|
# Shared library for all compilation test operations
|
||||||
|
# Using "silly" namespace to match existing test expectations
|
||||||
|
# import this file will automatically register
|
||||||
|
# torch ops for testing (like silly.attention)
|
||||||
|
silly_lib = Library("silly", "FRAGMENT")
|
||||||
|
|
||||||
|
# Global counter that counts the number of times attention is invoked
|
||||||
|
_global_counter = 0
|
||||||
|
|
||||||
|
|
||||||
|
def get_global_counter():
|
||||||
|
"""Get the current global counter value"""
|
||||||
|
return _global_counter
|
||||||
|
|
||||||
|
|
||||||
|
def reset_global_counter():
|
||||||
|
"""Reset the global counter to 0"""
|
||||||
|
global _global_counter
|
||||||
|
_global_counter = 0
|
||||||
|
|
||||||
|
|
||||||
|
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||||
|
out: torch.Tensor) -> None:
|
||||||
|
"""
|
||||||
|
Unified attention implementation that depends on
|
||||||
|
all inputs and affects the output.
|
||||||
|
Always increments a global counter that tests can use or ignore.
|
||||||
|
"""
|
||||||
|
global _global_counter
|
||||||
|
|
||||||
|
# Always increment the global counter
|
||||||
|
_global_counter += 1
|
||||||
|
|
||||||
|
# Unified implementation that depends on all inputs
|
||||||
|
out.copy_(q + k + v)
|
||||||
|
|
||||||
|
|
||||||
|
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||||
|
out: torch.Tensor) -> None:
|
||||||
|
"""Fake implementation for testing"""
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
# Register the unified attention operation
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="attention",
|
||||||
|
op_func=silly_attention,
|
||||||
|
mutates_args=["out"],
|
||||||
|
fake_impl=silly_attention_fake,
|
||||||
|
target_lib=silly_lib,
|
||||||
|
)
|
||||||
@ -2,7 +2,6 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.library import Library
|
|
||||||
|
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
from vllm.compilation.decorators import (ignore_torch_compile,
|
from vllm.compilation.decorators import (ignore_torch_compile,
|
||||||
@ -10,36 +9,14 @@ from vllm.compilation.decorators import (ignore_torch_compile,
|
|||||||
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
|
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
|
||||||
CUDAGraphMode, VllmConfig, set_current_vllm_config)
|
CUDAGraphMode, VllmConfig, set_current_vllm_config)
|
||||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||||
from vllm.utils import direct_register_custom_op
|
|
||||||
|
|
||||||
# create a library to hold the custom op
|
# This import automatically registers `torch.ops.silly.attention`
|
||||||
silly_lib = Library("silly", "FRAGMENT") # noqa
|
from . import silly_attention # noqa: F401
|
||||||
|
|
||||||
BATCH_SIZE = 32
|
BATCH_SIZE = 32
|
||||||
MLP_SIZE = 128
|
MLP_SIZE = 128
|
||||||
|
|
||||||
|
|
||||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
|
||||||
out: torch.Tensor) -> None:
|
|
||||||
out.copy_(q)
|
|
||||||
out += k
|
|
||||||
out += v
|
|
||||||
|
|
||||||
|
|
||||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
|
||||||
out: torch.Tensor) -> None:
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
|
||||||
op_name="attention",
|
|
||||||
op_func=silly_attention,
|
|
||||||
mutates_args=["out"],
|
|
||||||
fake_impl=silly_attention_fake,
|
|
||||||
target_lib=silly_lib,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode
|
@torch.inference_mode
|
||||||
def run_model(vllm_config: VllmConfig, model: nn.Module,
|
def run_model(vllm_config: VllmConfig, model: nn.Module,
|
||||||
cudagraph_runtime_mode: CUDAGraphMode):
|
cudagraph_runtime_mode: CUDAGraphMode):
|
||||||
@ -151,7 +128,7 @@ def test_ignore_torch_compile_decorator():
|
|||||||
run_model(vllm_config, mod_C, cudagraph_runtime_mode)
|
run_model(vllm_config, mod_C, cudagraph_runtime_mode)
|
||||||
|
|
||||||
|
|
||||||
# Only enable torch.compile if
|
# Only enable torch.compile if
|
||||||
# vllm_config.cache_config.kv_sharing_fast_prefill=True
|
# vllm_config.cache_config.kv_sharing_fast_prefill=True
|
||||||
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
|
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
|
||||||
kv_sharing_fast_prefill)
|
kv_sharing_fast_prefill)
|
||||||
@ -173,7 +150,7 @@ class B(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
# Only enable torch.compile if
|
# Only enable torch.compile if
|
||||||
# vllm_config.cache_config.kv_sharing_fast_prefill=False
|
# vllm_config.cache_config.kv_sharing_fast_prefill=False
|
||||||
@support_torch_compile(enable_if=lambda vllm_config: not vllm_config.
|
@support_torch_compile(enable_if=lambda vllm_config: not vllm_config.
|
||||||
cache_config.kv_sharing_fast_prefill)
|
cache_config.kv_sharing_fast_prefill)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user