# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
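"""Tests for range-based compilation: each Inductor compilation should run
under the expected compile range, compile ranges should interact correctly
with explicit compile sizes, and an identically configured model should be
able to reuse the Inductor cache across compile ranges."""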
from typing import Any

import torch
from torch import fx as fx
from torch import nn

# This import automatically registers `torch.ops.silly.attention`
import tests.compile.silly_attention  # noqa
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.inductor_pass import (
    InductorPass,
    get_pass_context,
)
from vllm.config import (
    VllmConfig,
    set_current_vllm_config,
)
from vllm.config.compilation import CompilationConfig, CompilationMode
from vllm.config.scheduler import SchedulerConfig
from vllm.config.utils import Range
from vllm.forward_context import set_forward_context

BATCH_SIZE = 64
MLP_SIZE = 128


@support_torch_compile
class TestModel(nn.Module):
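    """Minimal compiled model: doubles the input, routes it through the
    test-only `torch.ops.silly.attention` op, and scales the result by 3."""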
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + x
        attn_output = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, attn_output)
        x = attn_output
        x = x * 3
        return x


@torch.inference_mode
def run_model(vllm_config: VllmConfig, model: nn.Module, batch_sizes: list[int]):
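    """Run the model once at the default BATCH_SIZE and then once per
    requested batch size, inside an (empty) forward context."""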
    with set_forward_context({}, vllm_config=vllm_config):
        model(torch.randn(BATCH_SIZE, MLP_SIZE))
        for batch_size in batch_sizes:
            model(torch.randn(batch_size, MLP_SIZE))


class PostGradRangeChecker(InductorPass):
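    """Post-grad Inductor pass that asserts the compile range it runs under
    is one of the expected ranges and counts how many times it is called."""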
    def __init__(self, ranges: list[Range]):
        self.ranges = ranges
        self.num_calls = 0

    def __call__(self, graph: fx.Graph):
        compile_range = get_pass_context().compile_range
        assert compile_range in self.ranges, (
            f"Compile range {compile_range} not in {self.ranges}"
        )
        self.num_calls += 1

    def uuid(self) -> str:
        # Constant uuid (hash of empty state), so the checker itself
        # does not perturb compilation cache keys
        state: dict[str, Any] = {}
        return InductorPass.hash_dict(state)


def test_compile_ranges(use_fresh_inductor_cache):
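    """With compile_ranges_split_points=[8, 32] and compile_sizes=[16, 64, 128],
    running batch sizes from every range should trigger exactly 5 backend
    compilations: one per compile range ([1, 8], [9, 32], [33, 8192]) plus
    one per compile size that is actually run (16 and 64; 128 never is),
    each seen by the post-grad pass under the expected compile range."""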
    post_grad_range_checker = PostGradRangeChecker(
        [
            Range(start=1, end=8),
            Range(start=16, end=16),
            Range(start=9, end=32),
            Range(start=64, end=64),
            Range(start=33, end=8192),
        ]
    )
    torch.set_default_device("cuda")
    vllm_config = VllmConfig(
        scheduler_config=SchedulerConfig(
            max_num_batched_tokens=8192,
        ),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            compile_ranges_split_points=[8, 32],
            compile_sizes=[16, 64, 128],
            inductor_compile_config={
                "post_grad_custom_post_pass": post_grad_range_checker,
            },
        ),
    )

    with set_current_vllm_config(vllm_config):
        model = TestModel(vllm_config=vllm_config, prefix="").eval()
        # Number of compilations: 3 (one per compile range)
        # + 2 (compile sizes 16 and 64; 128 is never run)
        batch_sizes = [1, 4, 16, 24, 48, 64, 8192]

        with compilation_counter.expect(
            num_graphs_seen=1,
            num_piecewise_graphs_seen=1,
            num_backend_compilations=5,
        ):
            run_model(vllm_config, model, batch_sizes)
        assert post_grad_range_checker.num_calls == 5


def test_compile_config_get_compile_ranges():
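    """compile_ranges_split_points=[8, 32] with max_num_batched_tokens=8192
    should produce the ranges [1, 8], [9, 32] and [33, 8192]."""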
    compilation_config = CompilationConfig(
        compile_ranges_split_points=[8, 32],
    )
    # Constructing the VllmConfig finalizes compilation_config in place;
    # the last range's upper bound comes from max_num_batched_tokens.
    VllmConfig(
        scheduler_config=SchedulerConfig(
            max_num_batched_tokens=8192,
        ),
        compilation_config=compilation_config,
    )
    assert compilation_config.get_compile_ranges() == [
        Range(start=1, end=8),
        Range(start=9, end=32),
        Range(start=33, end=8192),
    ]


def test_inductor_cache_compile_ranges(monkeypatch, use_fresh_inductor_cache):
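    """With vLLM's compile cache disabled, compile a model for the ranges
    [1, 8] and [9, 8192], then compile a second, identically configured
    model: the second run should be served from the Inductor cache, so the
    post-grad pass must not run again."""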
    # To force multiple compilations, we disable vLLM's compile cache
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

    post_grad_range_checker = PostGradRangeChecker(
        ranges=[
            Range(start=1, end=8),
            Range(start=9, end=8192),
        ]
    )
    scheduler_config = SchedulerConfig(
        max_num_batched_tokens=8192,
    )
    torch.set_default_device("cuda")

    def create_vllm_config():
        return VllmConfig(
            scheduler_config=scheduler_config,
            compilation_config=CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
                compile_ranges_split_points=[8],
                inductor_compile_config={
                    "post_grad_custom_post_pass": post_grad_range_checker,
                },
            ),
        )

    vllm_config_1 = create_vllm_config()
    with set_current_vllm_config(vllm_config_1):
        model1 = TestModel(vllm_config=vllm_config_1, prefix="").eval()
        batch_sizes = [1, 16]
        run_model(vllm_config_1, model1, batch_sizes)
        # One compilation per compile range: [1, 8] and [9, 8192]
        assert post_grad_range_checker.num_calls == 2

    post_grad_range_checker.num_calls = 0
    # Create a new vllm config with a new pass context
    vllm_config_2 = create_vllm_config()
    with set_current_vllm_config(vllm_config_2):
        model2 = TestModel(vllm_config=vllm_config_2, prefix="").eval()
        batch_sizes = [4, 32]
        run_model(vllm_config_2, model2, batch_sizes)
        # The Inductor cache should be hit this time, so the pass
        # should not have been called at all
        assert post_grad_range_checker.num_calls == 0