vllm/tests/compile/test_compile_ranges.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

import torch
from torch import fx as fx
from torch import nn

# This import automatically registers `torch.ops.silly.attention`
import tests.compile.silly_attention  # noqa
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.inductor_pass import (
    InductorPass,
    get_pass_context,
)
from vllm.config import (
    VllmConfig,
    set_current_vllm_config,
)
from vllm.config.compilation import CompilationConfig, CompilationMode
from vllm.config.scheduler import SchedulerConfig
from vllm.config.utils import Range
from vllm.forward_context import set_forward_context

BATCH_SIZE = 64
MLP_SIZE = 128


@support_torch_compile
class TestModel(nn.Module):
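    """Toy model that exercises the test-only torch.ops.silly.attention op."""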
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + x
attn_output = torch.empty_like(x)
torch.ops.silly.attention(x, x, x, attn_output)
x = attn_output
x = x * 3
        return x


@torch.inference_mode
def run_model(vllm_config: VllmConfig, model: nn.Module, batch_sizes: list[int]):
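    """Run the model once at BATCH_SIZE, then once per requested batch size."""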
with set_forward_context({}, vllm_config=vllm_config):
model(torch.randn(BATCH_SIZE, MLP_SIZE))
for batch_size in batch_sizes:
            model(torch.randn(batch_size, MLP_SIZE))


class PostGradRangeChecker(InductorPass):
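    """Post-grad pass that counts its invocations and asserts that the compile
    range in the current pass context is one of the expected ranges."""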
def __init__(self, ranges: list[Range]):
self.ranges = ranges
        self.num_calls = 0

    def __call__(self, graph: fx.Graph):
compile_range = get_pass_context().compile_range
assert compile_range in self.ranges, (
f"Compile range {compile_range} not in {self.ranges}"
)
        self.num_calls += 1

    def uuid(self) -> str:
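        # Hash an empty state so the uuid is the same for every instance of
        # this pass, independent of the expected ranges.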
state: dict[str, Any] = {}
        return InductorPass.hash_dict(state)


def test_compile_ranges(use_fresh_inductor_cache):
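    """Compile with split points [8, 32] and compile sizes [16, 64, 128].

    The post-grad pass should only see the general ranges [1, 8], [9, 32],
    [33, 8192] plus single-size ranges for the compile sizes that are
    actually run (16 and 64).
    """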
post_grad_range_checker = PostGradRangeChecker(
[
Range(start=1, end=8),
Range(start=16, end=16),
Range(start=9, end=32),
Range(start=64, end=64),
Range(start=33, end=8192),
]
    )

    torch.set_default_device("cuda")
vllm_config = VllmConfig(
scheduler_config=SchedulerConfig(
max_num_batched_tokens=8192,
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
compile_ranges_split_points=[8, 32],
compile_sizes=[16, 64, 128],
inductor_compile_config={
"post_grad_custom_post_pass": post_grad_range_checker,
},
),
    )

    with set_current_vllm_config(vllm_config):
model = TestModel(vllm_config=vllm_config, prefix="").eval()
        # Number of compilations: one per compile range (3) plus the two
        # compile sizes that are actually run (16 and 64).
batch_sizes = [1, 4, 16, 24, 48, 64, 8192]
with compilation_counter.expect(
num_graphs_seen=1,
num_piecewise_graphs_seen=1,
num_backend_compilations=5,
):
run_model(vllm_config, model, batch_sizes)
    assert post_grad_range_checker.num_calls == 5


def test_compile_config_get_compile_ranges():
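    """Split points [8, 32] with max_num_batched_tokens=8192 expand to the
    ranges [1, 8], [9, 32] and [33, 8192]."""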
compilation_config = CompilationConfig(
compile_ranges_split_points=[8, 32],
)
VllmConfig(
scheduler_config=SchedulerConfig(
max_num_batched_tokens=8192,
),
compilation_config=compilation_config,
)
assert compilation_config.get_compile_ranges() == [
Range(start=1, end=8),
Range(start=9, end=32),
Range(start=33, end=8192),
    ]


def test_inductor_cache_compile_ranges(monkeypatch, use_fresh_inductor_cache):
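    """Two configs with identical compile ranges should share the inductor
    cache: the post-grad pass runs once per range for the first model and is
    never invoked again for the second."""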
    # To force multiple compilations, disable vLLM's own compile cache
    # (inductor's cache stays enabled and is what this test exercises).
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
post_grad_range_checker = PostGradRangeChecker(
ranges=[
Range(start=1, end=8),
Range(start=9, end=8192),
]
    )

    scheduler_config = SchedulerConfig(
max_num_batched_tokens=8192,
)
    torch.set_default_device("cuda")

    def create_vllm_config():
return VllmConfig(
scheduler_config=scheduler_config,
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
compile_ranges_split_points=[8],
inductor_compile_config={
"post_grad_custom_post_pass": post_grad_range_checker,
},
),
        )

    vllm_config_1 = create_vllm_config()
with set_current_vllm_config(vllm_config_1):
model1 = TestModel(vllm_config=vllm_config_1, prefix="").eval()
batch_sizes = [1, 16]
run_model(vllm_config_1, model1, batch_sizes)
assert post_grad_range_checker.num_calls == 2
post_grad_range_checker.num_calls = 0
    # Create a second vllm config; its compilation gets a fresh pass context.
vllm_config_2 = create_vllm_config()
with set_current_vllm_config(vllm_config_2):
model2 = TestModel(vllm_config=vllm_config_2, prefix="").eval()
batch_sizes = [4, 32]
run_model(vllm_config_2, model2, batch_sizes)
    # The second round of compilations should hit the inductor cache,
    # so the post-grad pass should not run again.
assert post_grad_range_checker.num_calls == 0