mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:44:56 +08:00
136 lines
4.0 KiB
Python
136 lines
4.0 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
|
|
import os
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
|
|
from vllm.config import (
|
|
CompilationConfig,
|
|
CompilationMode,
|
|
VllmConfig,
|
|
set_current_vllm_config,
|
|
)
|
|
|
|
|
|
class MyMod(torch.nn.Module):
|
|
def forward(self, x: torch.Tensor, cache: torch.Tensor | None = None):
|
|
if x.size()[0] >= 4:
|
|
return x * 2
|
|
else:
|
|
return x * 100
|
|
|
|
|
|
class MyWrapper(TorchCompileWithNoGuardsWrapper):
    """Thin wrapper that routes calls through ``TorchCompileWithNoGuardsWrapper``."""

    def __init__(self, model):
        """Store ``model`` and then run the base-class initializer.

        Args:
            model: The callable module that ``forward`` delegates to.
        """
        # NOTE(review): the model is attached before the base initializer runs;
        # presumably the base class needs ``forward`` to be usable at that
        # point -- keep this order unless that is confirmed otherwise.
        self.model = model
        super().__init__()

    def forward(self, x: torch.Tensor):  # type: ignore[override]
        """Delegate to the wrapped model; this is the function to be compiled."""
        output = self.model(x)
        return output
|
|
|
|
|
|
def _assert_close(result: torch.Tensor, expected: torch.Tensor) -> None:
    """Assert ``result`` matches ``expected``, reporting both values on failure."""
    assert torch.allclose(result, expected), f"Expected {expected}, got {result}"


@pytest.mark.parametrize("use_bytecode_hook", [True, False])
def test_torch_compile_wrapper(use_bytecode_hook, monkeypatch):
    """Test basic functionality of TorchCompileWithNoGuardsWrapper.

    Three compilation modes are exercised in turn on the same config object:

    * ``DYNAMO_TRACE_ONCE``: after a first call with 4 elements, a call with
      3 elements is still expected to produce ``x * 2``, although the
      unwrapped module returns ``x * 100`` for that input.
    * ``STOCK_TORCH_COMPILE``: the 3-element call is expected to produce
      ``x * 100``, i.e. the same result as the unwrapped module.
    * ``mode = None``: constructing the wrapper is expected to raise.

    Args:
        use_bytecode_hook: Value written to ``VLLM_USE_BYTECODE_HOOK``
            (``"1"`` when true, ``"0"`` otherwise).
        monkeypatch: Object providing ``setenv(name, value)``; pytest's
            fixture under pytest, or a stand-in when run as a script.
    """
    # Set the environment variable for this test
    monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")

    # Create a proper vLLM config instead of mocking
    vllm_config = VllmConfig()
    vllm_config.compilation_config = CompilationConfig()
    vllm_config.compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
    vllm_config.compilation_config.backend = "inductor"

    # Test DYNAMO_TRACE_ONCE
    with set_current_vllm_config(vllm_config):
        torch._dynamo.reset()
        mod = MyMod()
        wrapper = MyWrapper(mod)

        # First call should trigger compilation
        x = torch.tensor([1, 2, 3, 4])
        torch._dynamo.mark_dynamic(x, 0)
        _assert_close(wrapper(x), torch.tensor([2, 4, 6, 8]))

        # Second call should use compiled code
        x2 = torch.tensor([1, 2, 3])
        _assert_close(wrapper(x2), torch.tensor([2, 4, 6]))

        # without the wrapper result would be different.
        _assert_close(mod(x2), torch.tensor([100, 200, 300]))

    # with STOCK_TORCH_COMPILE we do not remove guards.
    vllm_config.compilation_config.mode = CompilationMode.STOCK_TORCH_COMPILE
    torch._dynamo.reset()
    with set_current_vllm_config(vllm_config):
        mod = MyMod()
        wrapper = MyWrapper(mod)

        # First call should trigger compilation
        x = torch.tensor([1, 2, 3, 4])
        torch._dynamo.mark_dynamic(x, 0)
        _assert_close(wrapper(x), torch.tensor([2, 4, 6, 8]))

        # Second call should trigger another compilation
        x2 = torch.tensor([1, 2, 3])
        _assert_close(wrapper(x2), torch.tensor([100, 200, 300]))

    # NO_COMPILATION level not supported.
    vllm_config.compilation_config.mode = None
    torch._dynamo.reset()
    with set_current_vllm_config(vllm_config):
        torch._dynamo.reset()
        mod = MyMod()

        try:
            MyWrapper(mod)
        except Exception:
            # Failing to construct the wrapper is the expected outcome here.
            return
        raise AssertionError("expected an exception to be raised")
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual entry point: run the test once per parameter value without pytest.

    class _EnvPatcher:
        """Minimal stand-in for the ``monkeypatch`` fixture (``setenv`` only)."""

        def setenv(self, name, value):
            # Writes straight into the process environment; nothing is restored.
            os.environ[name] = value

    patcher = _EnvPatcher()

    runs = (
        (False, "Testing with VLLM_USE_BYTECODE_HOOK=False"),
        (True, "Testing with VLLM_USE_BYTECODE_HOOK=True"),
    )
    for hook_enabled, banner in runs:
        print(banner)
        test_torch_compile_wrapper(hook_enabled, patcher)

    print("All tests passed!")
|