Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 15:25:48 +08:00)
Avoid bytecode hook and simplify TorchCompileWrapperWithCustomDispatcher (#25110)
Signed-off-by: Laith Sakka <lsakka@meta.com>
This commit is contained in:
parent 5a84b76b86
commit 2e0ad629b0
@@ -22,6 +22,8 @@ from vllm.config import (
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils.torch_utils import is_torch_equal_or_newer

from ...utils import create_new_process_for_each_test

# This import automatically registers `torch.ops.silly.attention`
from .. import silly_attention  # noqa: F401

@@ -193,7 +195,14 @@ def run_model(


@pytest.mark.parametrize("use_inductor_graph_partition", [False, True])
def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
@pytest.mark.parametrize("use_bytecode_hook", [True, False])
@create_new_process_for_each_test("spawn")
def test_multi_graph_piecewise_compile(
    use_inductor_graph_partition: bool, use_bytecode_hook: bool, monkeypatch
):
    # Set the environment variable for this test
    monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")

    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("inductor graph partition is only available in PyTorch 2.9+")


@@ -21,6 +21,8 @@ from vllm.config import (
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils.torch_utils import is_torch_equal_or_newer

from ...utils import create_new_process_for_each_test

# This import automatically registers `torch.ops.silly.attention`
from ..silly_attention import get_global_counter, reset_global_counter

@@ -124,6 +126,7 @@ def _run_simple_model(

@pytest.mark.parametrize("use_inductor", [True, False])
@torch.inference_mode()
@create_new_process_for_each_test("spawn")
def test_simple_piecewise_compile(use_inductor):
    _run_simple_model(
        splitting_ops=["silly::attention"],

@@ -29,6 +29,8 @@ from vllm.config import (
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils.torch_utils import is_torch_equal_or_newer

from ...utils import create_new_process_for_each_test

# This import automatically registers `torch.ops.silly.attention`
from .. import silly_attention  # noqa: F401

@@ -334,6 +336,7 @@ def run_model(llama_config, compile_config: CompilationConfig) -> torch.Tensor:
        ("inductor", True),  # Inductor, Inductor partition
    ],
)
@create_new_process_for_each_test("spawn")
def test_toy_llama(
    backend: str, use_inductor_graph_partition: bool, monkeypatch, tmp_path
):
@@ -513,4 +516,8 @@ def benchmark():


if __name__ == "__main__":
    benchmark()
    # Protect against subprocess reimport when using create_new_process_for_each_test
    import os

    if os.environ.get("RUNNING_IN_SUBPROCESS") != "1":
        benchmark()

@@ -2,59 +2,134 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import os

import pytest
import torch

from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import CompilationMode
from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    VllmConfig,
    set_current_vllm_config,
)


class MyMod(torch.nn.Module):
    def forward(self, x: torch.Tensor, cache: torch.Tensor | None = None):
        if cache is not None:
            return x + cache
        return x * 2
        if x.size()[0] >= 4:
            return x * 2
        else:
            return x * 100


class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
class MyWrapper(TorchCompileWithNoGuardsWrapper):
    def __init__(self, model):
        self.model = model
        compiled_callable = torch.compile(self.forward, backend="eager")
        super().__init__(
            compiled_callable, compilation_mode=CompilationMode.DYNAMO_TRACE_ONCE
        super().__init__()

    def forward(self, x: torch.Tensor):  # type: ignore[override]
        # this is the function to be compiled
        return self.model(x)


@pytest.mark.parametrize("use_bytecode_hook", [True, False])
def test_torch_compile_wrapper(use_bytecode_hook, monkeypatch):
    """Test basic functionality of TorchCompileWithNoGuardsWrapper."""
    # Set the environment variable for this test
    monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")

    # Create a proper vLLM config instead of mocking
    vllm_config = VllmConfig()
    vllm_config.compilation_config = CompilationConfig()
    vllm_config.compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
    vllm_config.compilation_config.backend = "inductor"

    # Test DYNAMO_TRACE_ONCE
    with set_current_vllm_config(vllm_config):
        torch._dynamo.reset()
        mod = MyMod()
        wrapper = MyWrapper(mod)

        # First call should trigger compilation
        x = torch.tensor([1, 2, 3, 4])
        torch._dynamo.mark_dynamic(x, 0)

        result1 = wrapper(x)
        expected1 = torch.tensor([2, 4, 6, 8])
        assert torch.allclose(result1, expected1), (
            f"Expected {expected1}, got {result1}"
        )

    def forward(self, x: torch.Tensor, cache: torch.Tensor | None = None):
        # this is the function to be compiled
        return self.model(x, cache)
        # Second call should use compiled code
        x2 = torch.tensor([1, 2, 3])
        result2 = wrapper(x2)
        expected2 = torch.tensor([2, 4, 6])
        assert torch.allclose(result2, expected2), (
            f"Expected {expected2}, got {result2}"
        )

    def __call__(self, x: torch.Tensor, cache: torch.Tensor | None = None):
        # let torch.compile compile twice
        if len(self.compiled_codes) == 2:
            dispatch_id = 0 if cache is None else 1
            with self.dispatch_to_code(dispatch_id):
                return self.forward(x, cache)
        else:
            return self.compiled_callable(x, cache)
        # without the wrapper, the result would be different.
        result3 = mod(x2)
        expected3 = torch.tensor([100, 200, 300])

        assert torch.allclose(result3, expected3), (
            f"Expected {expected3}, got {result3}"
        )

def test_torch_compile_wrapper():
    mod = MyMod()
    wrappers = []
    for i in range(3):
        torch._dynamo.reset()
    # with STOCK_TORCH_COMPILE we do not remove guards.
    vllm_config.compilation_config.mode = CompilationMode.STOCK_TORCH_COMPILE
    torch._dynamo.reset()
    with set_current_vllm_config(vllm_config):
        mod = MyMod()
        wrapper = MyWrapper(mod)
        wrappers.append(wrapper)
        x = torch.tensor([1])
        wrapper(x, None)  # profile run, compile
        # create a cache tensor
        cache = torch.tensor([2])
        wrapper(x, cache)  # warm up with cache, recompile

        # for new input, dispatch to the compiled code directly
        new_x = torch.tensor([3])
        assert wrapper(new_x, None).item() == 6  # dispatch to the first compiled code
        assert wrapper(new_x, cache).item() == 5  # dispatch to the second compiled code
        # First call should trigger compilation
        x = torch.tensor([1, 2, 3, 4])
        torch._dynamo.mark_dynamic(x, 0)

    for wrapper in wrappers:
        # make sure they have independent compiled codes
        assert len(wrapper.compiled_codes) == 2
        result1 = wrapper(x)
        expected1 = torch.tensor([2, 4, 6, 8])
        assert torch.allclose(result1, expected1), (
            f"Expected {expected1}, got {result1}"
        )

        # Second call should trigger another compilation
        x2 = torch.tensor([1, 2, 3])
        result2 = wrapper(x2)
        expected2 = torch.tensor([100, 200, 300])
        assert torch.allclose(result2, expected2), (
            f"Expected {expected2}, got {result2}"
        )

    # NO_COMPILATION level not supported.
    vllm_config.compilation_config.mode = None
    torch._dynamo.reset()
    with set_current_vllm_config(vllm_config):
        torch._dynamo.reset()
        mod = MyMod()

        try:
            wrapper = MyWrapper(mod)
        except Exception:
            return
        raise AssertionError("expected an exception to be raised")


if __name__ == "__main__":
    # Run with both parameter values

    class MockMonkeypatch:
        def setenv(self, name, value):
            os.environ[name] = value

    mp = MockMonkeypatch()

    print("Testing with VLLM_USE_BYTECODE_HOOK=False")
    test_torch_compile_wrapper(False, mp)

    print("Testing with VLLM_USE_BYTECODE_HOOK=True")
    test_torch_compile_wrapper(True, mp)

    print("All tests passed!")

@@ -34,6 +34,7 @@ VIDEO_PROMPTS = VIDEO_ASSETS.prompts(
@pytest.mark.parametrize("num_frames", [16])
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("use_bytecode_hook", [True, False])
def test_qwen2_5_vl_evs_functionality(
    vllm_runner,
    video_assets,
@@ -42,10 +43,14 @@ def test_qwen2_5_vl_evs_functionality(
    num_frames: int,
    dtype: str,
    max_tokens: int,
    use_bytecode_hook: bool,
    monkeypatch,
) -> None:
    """Test EVS (Efficient Video Sampling) functionality with different
    pruning rates.
    """
    # Set the environment variable for this test
    monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")

    # Sample frames from video assets
    sampled_vids = [
@@ -86,6 +91,7 @@ def test_qwen2_5_vl_evs_functionality(
@pytest.mark.parametrize("num_frames", [16])
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("use_bytecode_hook", [True, False])
def test_qwen2_5_vl_evs_batched_videos(
    vllm_runner,
    video_assets,
@@ -94,6 +100,8 @@ def test_qwen2_5_vl_evs_batched_videos(
    num_frames: int,
    dtype: str,
    max_tokens: int,
    use_bytecode_hook: bool,
    monkeypatch,
) -> None:
    """Test EVS functionality with batched videos.

@@ -102,6 +110,8 @@ def test_qwen2_5_vl_evs_batched_videos(
    2. Both pruning configurations work with multiple videos
    3. The model doesn't crash when processing multiple videos simultaneously
    """
    # Set the environment variable for this test
    monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")
    # Sample frames from video assets
    sampled_vids = [
        sample_frames_from_video(asset.np_ndarrays, num_frames)

@@ -75,6 +75,14 @@ def model_name():
    return "meta-llama/Llama-3.1-8B-Instruct"


@pytest.fixture(autouse=True)
def reset_torch_dynamo():
    """Reset torch dynamo cache before each test"""
    yield
    # Cleanup after test
    torch._dynamo.reset()


@pytest.mark.parametrize(
    "speculative_config",
    [

@@ -17,7 +17,7 @@ from torch._dynamo.symbolic_convert import InliningInstructionTranslator

import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
from vllm.config import (
    CompilationMode,
    VllmConfig,
@@ -246,14 +246,14 @@ def _support_torch_compile(
    """
    A decorator to add support for compiling the forward method of a class.
    """
    if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
    if TorchCompileWithNoGuardsWrapper in cls.__bases__:
        # support decorating multiple times
        return cls

    # take care of method resolution order
    # make sure super().__init__ is called on the base class
    # other than TorchCompileWrapperWithCustomDispatcher
    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher,)
    # other than TorchCompileWithNoGuardsWrapper
    cls.__bases__ = cls.__bases__ + (TorchCompileWithNoGuardsWrapper,)

    old_init = cls.__init__

@@ -290,12 +290,43 @@ def _support_torch_compile(
            return

        compilation_counter.num_models_seen += 1
        TorchCompileWrapperWithCustomDispatcher.__init__(
            self, compilation_mode=vllm_config.compilation_config.mode
        )
        self.compiled = False
        TorchCompileWithNoGuardsWrapper.__init__(self)

    cls.__init__ = __init__

    def _mark_dynamic_inputs(mod, *args, **kwargs):
        sig = inspect.signature(mod.__class__.forward)
        bound_args = sig.bind(mod, *args, **kwargs)
        bound_args.apply_defaults()
        for k, dims in dynamic_arg_dims.items():
            arg = bound_args.arguments.get(k)
            if arg is not None:
                dims = [dims] if isinstance(dims, int) else dims
                if isinstance(arg, torch.Tensor):
                    # In case dims is specified with negative indexing
                    dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
                    torch._dynamo.mark_dynamic(arg, dims)
                elif isinstance(arg, IntermediateTensors):
                    for tensor in arg.tensors.values():
                        # In case dims is specified with negative indexing
                        dims = [tensor.ndim + dim if dim < 0 else dim for dim in dims]
                        torch._dynamo.mark_dynamic(tensor, dims)
                else:
                    raise ValueError(
                        "Unsupported dynamic dimensions"
                        f" {dims} for argument {k} with type {type(arg)}."
                    )
        if mark_unbacked_dims:
            for k, dims in mark_unbacked_dims.items():
                arg = bound_args.arguments.get(k)
                if arg is not None:
                    dims = [dims] if isinstance(dims, int) else dims
                    if isinstance(arg, torch.Tensor):
                        # In case dims is specified with negative indexing
                        dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
                        torch._dynamo.decorators.mark_unbacked(arg, dims)

    def __call__(self, *args, **kwargs):
        # torch.compiler.is_compiling() means we are inside the compilation
        # e.g. TPU has the compilation logic in model runner, so we don't
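
For reference, a minimal standalone sketch (hypothetical toy function, not part of this diff) of what the `torch._dynamo.mark_dynamic` calls in `_mark_dynamic_inputs` buy: marking the batch dimension dynamic before the first trace keeps Dynamo from specializing the compiled graph on the first batch size it sees.

import torch

def f(x):
    return x * 2

compiled = torch.compile(f, backend="eager")

x = torch.randn(8, 16)
torch._dynamo.mark_dynamic(x, 0)  # treat dim 0 (the batch dim) as dynamic
compiled(x)                       # first call: traced with a symbolic batch size
compiled(torch.randn(3, 16))      # different batch size, normally reuses the same graph
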
@@ -303,6 +334,7 @@ def _support_torch_compile(
        if self.do_not_compile or torch.compiler.is_compiling():
            return self.forward(*args, **kwargs)

        # if aot_compiled_fn is set, just call it.
        if getattr(self, "aot_compiled_fn", None) is not None:
            return self.aot_compiled_fn(self, *args, **kwargs)

@@ -362,120 +394,84 @@ def _support_torch_compile(
                )
                return self.aot_compiled_fn(self, *args, **kwargs)

        if self.compiled:
            assert not envs.VLLM_USE_AOT_COMPILE
            return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)

        # This is the path for the first compilation.

        # the first compilation needs to have dynamic shapes marked
        if len(self.compiled_codes) < 1:
            sig = inspect.signature(self.__class__.forward)
            bound_args = sig.bind(self, *args, **kwargs)
            bound_args.apply_defaults()
            for k, dims in dynamic_arg_dims.items():
                arg = bound_args.arguments.get(k)
                if arg is not None:
                    dims = [dims] if isinstance(dims, int) else dims
                    if isinstance(arg, torch.Tensor):
                        # In case dims is specified with negative indexing
                        dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
                        torch._dynamo.mark_dynamic(arg, dims)
                    elif isinstance(arg, IntermediateTensors):
                        for tensor in arg.tensors.values():
                            # In case dims is specified with negative indexing
                            dims = [
                                tensor.ndim + dim if dim < 0 else dim for dim in dims
                            ]
                            torch._dynamo.mark_dynamic(tensor, dims)
                    else:
                        raise ValueError(
                            "Unsupported dynamic dimensions"
                            f" {dims} for argument {k} with type {type(arg)}."
                        )
            if mark_unbacked_dims:
                for k, dims in mark_unbacked_dims.items():
                    arg = bound_args.arguments.get(k)
                    if arg is not None:
                        dims = [dims] if isinstance(dims, int) else dims
                        if isinstance(arg, torch.Tensor):
                            # In case dims is specified with negative indexing
                            dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
                            torch._dynamo.decorators.mark_unbacked(arg, dims)
        # here, it is the starting point of the `torch.compile` process
        start_monitoring_torch_compile(self.vllm_config)
            logger.debug("Start compiling function %s", self.original_code_object)
        _mark_dynamic_inputs(self, *args, **kwargs)

        # if we don't use custom dispatcher, we can directly call the
        # compiled function and let torch.compile handle the dispatching,
        # with the overhead of guard evaluation and recompilation.
        if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
            # it seems Dynamo reuses the compilation across instances,
            # while we need to make sure the compiled code is not reused.
            # we need to control all the compilation of the model.
            torch._dynamo.eval_frame.remove_from_cache(self.original_code_object)
            # here, it is the starting point of the `torch.compile` process
            start_monitoring_torch_compile(self.vllm_config)
        original_code_object = self.original_code_object()
        logger.debug("Start compiling function %s", original_code_object)

        # collect all relevant files traced by Dynamo,
        # so that the compilation cache can trigger re-compilation
        # properly when any of these files change.
        # we do not want to delete the original code object entries since
        # we depend on them now to look up cached compiled functions.
        # torch._dynamo.eval_frame.remove_from_cache(original_code_object)

            # 1. the file containing the top-level forward function
            self.vllm_config.compilation_config.traced_files.add(
                self.original_code_object.co_filename
            )
            # collect all relevant files traced by Dynamo,
            # so that the compilation cache can trigger re-compilation
            # properly when any of these files change.

            # 2. every time Dynamo sees a function call, it will inline
            # the function by calling InliningInstructionTranslator.inline_call_
            # we hijack this function to know all the functions called
            # during Dynamo tracing, and their corresponding files
            inline_call = InliningInstructionTranslator.inline_call_
        # 1. the file containing the top-level forward function
        self.vllm_config.compilation_config.traced_files.add(
            original_code_object.co_filename
        )

            def patched_inline_call(self_):
                code = self_.f_code
                self.vllm_config.compilation_config.traced_files.add(code.co_filename)
                return inline_call(self_)
        # 2. every time Dynamo sees a function call, it will inline
        # the function by calling InliningInstructionTranslator.inline_call_
        # we hijack this function to know all the functions called
        # during Dynamo tracing, and their corresponding files
        inline_call = InliningInstructionTranslator.inline_call_

            # Disable the C++ compilation of symbolic shape guards. C++-fication
            # of symbolic shape guards can improve guard overhead. But, since
            # vllm skips guards anyway, setting this flag to False can improve
            # compile time.
            dynamo_config_patches = {}
            try:
                _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
                dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False
            except AttributeError:
                # Note: this config is not available in torch 2.6, we can skip
                # if the config doesn't exist
                logger.debug("enable_cpp_symbolic_shape_guards config not available")
        def patched_inline_call(self_):
            code = self_.f_code
            self.vllm_config.compilation_config.traced_files.add(code.co_filename)
            return inline_call(self_)

            with (
                patch.object(
                    InliningInstructionTranslator, "inline_call_", patched_inline_call
                ),
                torch._dynamo.config.patch(**dynamo_config_patches),
                maybe_use_cudagraph_partition_wrapper(self.vllm_config),
                _torch27_patch_tensor_subclasses(),
            ):
                if envs.VLLM_USE_AOT_COMPILE:
                    self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
                    output = self.aot_compiled_fn(self, *args, **kwargs)
                    assert aot_compilation_path is not None
                    assert cache_dir is not None
                    try:
                        os.makedirs(cache_dir, exist_ok=True)
                        self.aot_compiled_fn.save_compiled_function(
                            aot_compilation_path
                        )
                    except Exception as e:
                        logger.warning(
                            "Cannot save aot compilation to path %s, error: %s",
                            aot_compilation_path,
                            str(e),
                        )
                else:
                    output = self.compiled_callable(*args, **kwargs)
            return output
        # Disable the C++ compilation of symbolic shape guards. C++-fication
        # of symbolic shape guards can improve guard overhead. But, since
        # vllm skips guards anyway, setting this flag to False can improve
        # compile time.
        dynamo_config_patches = {}
        try:
            _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
            dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False
        except AttributeError:
            # Note: this config is not available in torch 2.6, we can skip
            # if the config doesn't exist
            logger.debug("enable_cpp_symbolic_shape_guards config not available")

            # usually, capturing the model once is enough, and then we can
            # dispatch to the compiled code directly, without going through
            # the Dynamo guard mechanism.
            with self.dispatch_to_code(0):
                model_output = self.forward(*args, **kwargs)
                return model_output
        with (
            patch.object(
                InliningInstructionTranslator, "inline_call_", patched_inline_call
            ),
            torch._dynamo.config.patch(**dynamo_config_patches),
            maybe_use_cudagraph_partition_wrapper(self.vllm_config),
            _torch27_patch_tensor_subclasses(),
        ):
            if envs.VLLM_USE_AOT_COMPILE:
                self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
                output = self.aot_compiled_fn(self, *args, **kwargs)
                assert aot_compilation_path is not None
                assert cache_dir is not None
                try:
                    os.makedirs(cache_dir, exist_ok=True)
                    self.aot_compiled_fn.save_compiled_function(aot_compilation_path)
                except Exception as e:
                    logger.warning(
                        "Cannot save aot compilation to path %s, error: %s",
                        aot_compilation_path,
                        str(e),
                    )
            else:
                output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)

        self.compiled = True
        return output

    cls.__call__ = __call__
    return cls

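For reference, a small self-contained sketch (hypothetical class names, not from this diff) of the base-injection pattern `_support_torch_compile` uses above: a wrapper mixin is appended to `cls.__bases__`, and `__init__` is wrapped so the mixin is initialized explicitly, since the decorated class never calls `super().__init__` on it.

class Wrapper:                      # stands in for TorchCompileWithNoGuardsWrapper
    def __init__(self):
        self.wrapped = True

class Base:
    def __init__(self, n):
        self.n = n

def support_wrapper(cls):
    if Wrapper in cls.__bases__:    # support decorating multiple times
        return cls
    cls.__bases__ = cls.__bases__ + (Wrapper,)
    old_init = cls.__init__

    def __init__(self, *args, **kwargs):
        old_init(self, *args, **kwargs)
        Wrapper.__init__(self)      # explicit init of the injected base
    cls.__init__ = __init__
    return cls

@support_wrapper
class Model(Base):
    def __init__(self, n):
        super().__init__(n)

m = Model(3)
assert m.n == 3 and m.wrapped
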
@@ -4,11 +4,11 @@
import os
import sys
from abc import abstractmethod
from collections.abc import Callable
from contextlib import contextmanager
from types import CodeType

import torch
import torch._C._dynamo.guards

import vllm.envs as envs
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
@@ -17,88 +17,153 @@ from vllm.logger import init_logger
logger = init_logger(__name__)


class TorchCompileWrapperWithCustomDispatcher:
def _noop_add_global_state_guard(self, *args, **kwargs):
    """No-op to skip the GLOBAL_STATE guard entirely"""
    pass


def _noop_add_torch_function_mode_stack_guard(self, *args, **kwargs):
    """No-op to skip the TORCH_FUNCTION_MODE_STACK guard entirely"""
    pass


@contextmanager
def _compilation_context():
    """Context manager for compilation settings and patches.

    This manager:
    1. Sets higher dynamo cache limits for compilation. (Needed for
       qwen2_5_vl; see test_qwen2_5_vl_evs_functionality.)
       Generally a recompilation can happen whenever we use a new
       backend instance in torch.compile.
    2. Patches out add_global_state_guard to skip GLOBAL_STATE guards
    3. Patches out add_torch_function_mode_stack_guard to skip
       TORCH_FUNCTION_MODE_STACK guards.
    4. Restores everything when compilation completes
    """
    A wrapper class for torch.compile, with a custom dispatch logic.
    Subclasses should:
    1. Implement the forward method
    2. Implement the dispatch logic in the __call__ method
        It can use `self.compiled_codes` to access the compiled bytecode,
        and `with self.dispatch_to_code(index):` to dispatch to
        the compiled code.
    3. Implement the `__init__` method to determine how to call
        `torch.compile` over the forward method.
    # Save original values
    original_global_state_guard = (
        torch._C._dynamo.guards.GuardManager.add_global_state_guard
    )
    original_torch_function_mode_stack_guard = (
        torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard
    )
    original_cache_size = torch._dynamo.config.cache_size_limit
    original_accumulated_cache = torch._dynamo.config.accumulated_cache_size_limit

    try:
        # Set higher cache limits for compilation
        torch._dynamo.config.cache_size_limit = 2048
        torch._dynamo.config.accumulated_cache_size_limit = 8192

        # Patch guard manager
        torch._C._dynamo.guards.GuardManager.add_global_state_guard = (
            _noop_add_global_state_guard
        )
        torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard = (
            _noop_add_torch_function_mode_stack_guard
        )
        yield
    finally:
        # Restore original values
        torch._C._dynamo.guards.GuardManager.add_global_state_guard = (
            original_global_state_guard
        )
        torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard = (
            original_torch_function_mode_stack_guard
        )
        torch._dynamo.config.cache_size_limit = original_cache_size
        torch._dynamo.config.accumulated_cache_size_limit = original_accumulated_cache

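A hypothetical usage sketch (callable and argument names assumed, not from this diff): the raised cache limits and the guard no-ops above only need to be in effect while the compiled callable is actually tracing, so just that call is wrapped.

with _compilation_context():
    # during this call, Dynamo traces with the relaxed cache limits and without
    # the GLOBAL_STATE / TORCH_FUNCTION_MODE_STACK guards
    output = compiled_forward(*args, **kwargs)
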
class TorchCompileWithNoGuardsWrapper:
    """
    A wrapper class for torch.compile that ensures all guards are dropped
    when CompilationMode is not CompilationMode.STOCK_TORCH_COMPILE.
    When guards are dropped, the first time __call__ is invoked, a single
    compilation is triggered. Dynamo should never trace again after that
    since we drop all guards.
    """

    def __init__(
        self,
        compiled_callable: Callable | None = None,
        compilation_mode: CompilationMode = CompilationMode.NONE,
    ):
    def __init__(self):
        self.compiled = False

        vllm_config = get_current_vllm_config()
        self.vllm_config = vllm_config
        if compiled_callable is None:
            # default compilation settings
            # compiling the forward method
        mode = vllm_config.compilation_config.mode
        if mode is None:
            raise RuntimeError("Compilation mode cannot be NO_COMPILATION")

            backend = vllm_config.compilation_config.init_backend(vllm_config)
            options = None
            if isinstance(backend, str) and backend == "inductor":
                options = (
                    get_current_vllm_config().compilation_config.inductor_compile_config
                )
            if envs.VLLM_USE_AOT_COMPILE:
                options = options or {}
                # This effectively drops all the guards.
                # We need this because bytecode hook is not used any more to
                # drop guards in the AOT compile mode.
                options["guard_filter_fn"] = lambda guards: [False for _ in guards]
                if hasattr(torch._dynamo.config, "enable_aot_compile"):
                    torch._dynamo.config.enable_aot_compile = True
                else:
                    msg = "torch._dynamo.config.enable_aot_compile is not "
                    msg += "available. AOT compile is disabled; please "
                    msg += "upgrade PyTorch version to use AOT compile."
                    logger.warning(msg)
        backend = vllm_config.compilation_config.init_backend(vllm_config)
        options = {}

            compiled_callable = torch.compile(
                self.forward, fullgraph=True, backend=backend, options=options
            )
        if isinstance(backend, str) and backend == "inductor":
            options = vllm_config.compilation_config.inductor_compile_config

        self.compiled_callable = compiled_callable
        self.original_code_object = self.__class__.forward.__code__
        self.compiled_codes: list[CodeType] = []
        torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
        if mode != CompilationMode.STOCK_TORCH_COMPILE:
            # Drop all the guards.
            options["guard_filter_fn"] = lambda x: [False for _ in x]

        # read the env var to determine whether to use the custom dispatcher
        # subclasses can use this to switch between the custom dispatcher
        # and the default Dynamo guard mechanism.
        self.use_custom_dispatcher: bool = (
            compilation_mode >= CompilationMode.DYNAMO_TRACE_ONCE
        if envs.VLLM_USE_AOT_COMPILE:
            if hasattr(torch._dynamo.config, "enable_aot_compile"):
                torch._dynamo.config.enable_aot_compile = True
            else:
                msg = "torch._dynamo.config.enable_aot_compile is not "
                msg += "available. AOT compile is disabled; please "
                msg += "upgrade PyTorch version to use AOT compile."
                logger.warning(msg)

        self._compiled_callable = torch.compile(
            self.forward,
            fullgraph=True,
            dynamic=False,
            backend=backend,
            options=options,
        )

        if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE:
            torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
            self._compiled_bytecode = None

    def aot_compile(self, *args, **kwargs):
        if not hasattr(self.compiled_callable, "aot_compile"):
        if not hasattr(self._compiled_callable, "aot_compile"):
            raise RuntimeError(
                "aot_compile is not supported by the current configuration. "
                + "Please make sure torch.compile is enabled with the latest "
                + f"version of PyTorch (currently using torch: {torch.__version__})"
            )
        return self.compiled_callable.aot_compile((args, kwargs))
        return self._compiled_callable.aot_compile((args, kwargs))

    def __call__(self, *args, **kwargs):
        """Implement the dispatch logic here, beyond the torch.compile mode.
        NOTE: this function can have additional arguments beyond the forward
        method, for directly dispatching to the compiled code.
        """
        return self.compiled_callable(*args, **kwargs)
        if envs.VLLM_USE_BYTECODE_HOOK:
            if (
                self.vllm_config.compilation_config.mode
                == CompilationMode.STOCK_TORCH_COMPILE
            ):
                return self._compiled_callable(*args, **kwargs)

            if not self._compiled_bytecode:
                # Make sure a compilation is triggered by clearing the dynamo
                # cache.
                torch._dynamo.eval_frame.remove_from_cache(self.original_code_object())
                return self._compiled_callable(*args, **kwargs)
            else:
                with self._dispatch_to_compiled_code():
                    return self.forward(*args, **kwargs)
        else:
            with _compilation_context():
                return self._compiled_callable(*args, **kwargs)

    @abstractmethod
    def forward(self, *args, **kwargs): ...

    def original_code_object(self) -> CodeType:
        """Return the original code object of the forward method."""
        return self.__class__.forward.__code__

    def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
        """Hook to save the compiled bytecode for direct execution."""
        if old_code is not self.original_code_object:
        if old_code is not self.original_code_object():
            return
        # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
        frame = sys._getframe()
@@ -114,7 +179,7 @@ class TorchCompileWrapperWithCustomDispatcher:
        if frame.f_locals["self"] is not self:
            return

        self.compiled_codes.append(new_code)
        self._compiled_bytecode = new_code

        path = self.vllm_config.compile_debug_dump_path()
        if path:
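
For reference, a hedged standalone sketch (hypothetical toy model, not from this diff) of the bytecode-hook mechanism used when VLLM_USE_BYTECODE_HOOK is enabled: Dynamo calls every registered hook with the original and the transformed code object of each frame it compiles, so a hook can capture the rewritten bytecode of a specific forward() for later direct dispatch.

import torch

class ToyModel(torch.nn.Module):   # hypothetical toy model
    def forward(self, x):
        return x + 1

captured = {}

def demo_hook(old_code, new_code):
    # keep only the bytecode Dynamo produced for ToyModel.forward
    if old_code is ToyModel.forward.__code__:
        captured["forward"] = new_code

torch._dynamo.convert_frame.register_bytecode_hook(demo_hook)
compiled = torch.compile(ToyModel().forward, backend="eager")
compiled(torch.ones(2))
# captured["forward"] is now expected to hold the transformed bytecode
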
@@ -153,16 +218,21 @@ class TorchCompileWrapperWithCustomDispatcher:
            raise RuntimeError(msg)

    @contextmanager
    def dispatch_to_code(self, index: int):
        """Context manager to dispatch to the compiled code.
    def _dispatch_to_compiled_code(self):
        # noqa: E501
        """
        Context manager to dispatch to internally compiled code for torch<2.8.
        Why does this work? Because Dynamo guarantees that the compiled
        bytecode has exactly the same arguments, cell variables, and free
        variables as the original code. Therefore we can directly switch
        the code object in the function and call it.

        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7
        for more details.
        """
        self.__class__.forward.__code__ = self.compiled_codes[index]
        yield
        self.__class__.forward.__code__ = self.original_code_object
        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
        """  # noqa: E501 line too long
        original = self.original_code_object()
        assert self._compiled_bytecode is not None
        self.__class__.forward.__code__ = self._compiled_bytecode
        try:
            yield
        finally:
            self.__class__.forward.__code__ = original

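A minimal pure-Python sketch (hypothetical functions, no Dynamo involved) of the dispatch trick above: because the transformed bytecode keeps the same arguments, cell variables, and free variables, swapping the function's __code__ object redirects calls, and the try/finally restores the original.

class Demo:
    def forward(self, x):
        return x + 1

def fast_forward(self, x):   # stands in for the Dynamo-transformed bytecode
    return x + 100

original = Demo.forward.__code__
Demo.forward.__code__ = fast_forward.__code__
try:
    assert Demo().forward(1) == 101   # dispatched to the swapped-in code
finally:
    Demo.forward.__code__ = original
assert Demo().forward(1) == 2         # original behavior restored
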
@@ -92,6 +92,7 @@ if TYPE_CHECKING:
    VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False
    VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False
    VLLM_USE_AOT_COMPILE: bool = False
    VLLM_USE_BYTECODE_HOOK: bool = False
    VLLM_FORCE_AOT_LOAD: bool = False
    VLLM_TORCH_PROFILER_WITH_STACK: bool = True
    VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False
@@ -556,6 +557,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
    # compilation is done in warmup phase and the compilation will be
    # reused in subsequent calls.
    "VLLM_USE_AOT_COMPILE": use_aot_compile,
    # Feature flag to enable/disable the bytecode hook in
    # TorchCompileWithNoGuardsWrapper.
    "VLLM_USE_BYTECODE_HOOK": lambda: bool(
        int(os.environ.get("VLLM_USE_BYTECODE_HOOK", "1"))
    ),
    # Force vllm to always load AOT compiled models from disk. Failure
    # to load will result in a hard error when this is enabled.
    # Will be ignored when VLLM_USE_AOT_COMPILE is disabled.

@@ -21,7 +21,7 @@ from vllm.attention import Attention
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import MLAAttention
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
from vllm.config import (
    ParallelConfig,
    VllmConfig,
@@ -1895,12 +1895,14 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            compiled_model = self.model.get_language_model().model
        else:
            compiled_model = self.model.model
        if isinstance(compiled_model, TorchCompileWrapperWithCustomDispatcher):
        if isinstance(compiled_model, TorchCompileWithNoGuardsWrapper):
            logger.info("Clear dynamo cache and cached dynamo bytecode.")
            torch._dynamo.eval_frame.remove_from_cache(
                compiled_model.original_code_object
                compiled_model.original_code_object()
            )
            compiled_model.compiled_codes.clear()
            # Reset the wrapper to re-initialize.
            compiled_model.compiled = False
            TorchCompileWithNoGuardsWrapper.__init__(compiled_model)

    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    def select_hidden_states(self, hidden_states, indices_do_sample):
