mirror of https://git.datalinker.icu/vllm-project/vllm.git
Avoid bytecode hook and simplify TorchCompileWrapperWithCustomDispatcher (#25110)
Signed-off-by: Laith Sakka <lsakka@meta.com>
This commit is contained in:
parent 5a84b76b86
commit 2e0ad629b0
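In brief, per the diff below: the commit renames `TorchCompileWrapperWithCustomDispatcher` to `TorchCompileWithNoGuardsWrapper`, drops Dynamo guards through a `guard_filter_fn` compile option whenever the compilation mode is not `STOCK_TORCH_COMPILE`, and keeps the old bytecode-hook dispatch path available behind a new `VLLM_USE_BYTECODE_HOOK` environment variable (enabled by default).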
@@ -22,6 +22,8 @@ from vllm.config import (
 from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.utils.torch_utils import is_torch_equal_or_newer
 
+from ...utils import create_new_process_for_each_test
+
 # This import automatically registers `torch.ops.silly.attention`
 from .. import silly_attention  # noqa: F401
 
@@ -193,7 +195,14 @@ def run_model(
 
 
 @pytest.mark.parametrize("use_inductor_graph_partition", [False, True])
-def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
+@pytest.mark.parametrize("use_bytecode_hook", [True, False])
+@create_new_process_for_each_test("spawn")
+def test_multi_graph_piecewise_compile(
+    use_inductor_graph_partition: bool, use_bytecode_hook: bool, monkeypatch
+):
+    # Set the environment variable for this test
+    monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")
+
     if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
         pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
 
@@ -21,6 +21,8 @@ from vllm.config import (
 from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.utils.torch_utils import is_torch_equal_or_newer
 
+from ...utils import create_new_process_for_each_test
+
 # This import automatically registers `torch.ops.silly.attention`
 from ..silly_attention import get_global_counter, reset_global_counter
 
@@ -124,6 +126,7 @@ def _run_simple_model(
 
 
 @pytest.mark.parametrize("use_inductor", [True, False])
 @torch.inference_mode()
+@create_new_process_for_each_test("spawn")
 def test_simple_piecewise_compile(use_inductor):
     _run_simple_model(
         splitting_ops=["silly::attention"],
@@ -29,6 +29,8 @@ from vllm.config import (
 from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.utils.torch_utils import is_torch_equal_or_newer
 
+from ...utils import create_new_process_for_each_test
+
 # This import automatically registers `torch.ops.silly.attention`
 from .. import silly_attention  # noqa: F401
 
@@ -334,6 +336,7 @@ def run_model(llama_config, compile_config: CompilationConfig) -> torch.Tensor:
         ("inductor", True),  # Inductor, Inductor partition
     ],
 )
+@create_new_process_for_each_test("spawn")
 def test_toy_llama(
     backend: str, use_inductor_graph_partition: bool, monkeypatch, tmp_path
 ):
@@ -513,4 +516,8 @@ def benchmark():
 
 
 if __name__ == "__main__":
-    benchmark()
+    # Protect against subprocess reimport when using spawn_new_process_for_each_test
+    import os
+
+    if os.environ.get("RUNNING_IN_SUBPROCESS") != "1":
+        benchmark()
@@ -2,59 +2,134 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+
+import os
+
+import pytest
 import torch
 
-from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
-from vllm.config import CompilationMode
+from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
+from vllm.config import (
+    CompilationConfig,
+    CompilationMode,
+    VllmConfig,
+    set_current_vllm_config,
+)
 
 
 class MyMod(torch.nn.Module):
     def forward(self, x: torch.Tensor, cache: torch.Tensor | None = None):
-        if cache is not None:
-            return x + cache
-        return x * 2
+        if x.size()[0] >= 4:
+            return x * 2
+        else:
+            return x * 100
 
 
-class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
+class MyWrapper(TorchCompileWithNoGuardsWrapper):
     def __init__(self, model):
         self.model = model
-        compiled_callable = torch.compile(self.forward, backend="eager")
-        super().__init__(
-            compiled_callable, compilation_mode=CompilationMode.DYNAMO_TRACE_ONCE
-        )
+        super().__init__()
 
-    def forward(self, x: torch.Tensor, cache: torch.Tensor | None = None):
+    def forward(self, x: torch.Tensor):  # type: ignore[override]
         # this is the function to be compiled
-        return self.model(x, cache)
-
-    def __call__(self, x: torch.Tensor, cache: torch.Tensor | None = None):
-        # let torch.compile compile twice
-        if len(self.compiled_codes) == 2:
-            dispatch_id = 0 if cache is None else 1
-            with self.dispatch_to_code(dispatch_id):
-                return self.forward(x, cache)
-        else:
-            return self.compiled_callable(x, cache)
+        return self.model(x)
 
 
-def test_torch_compile_wrapper():
-    mod = MyMod()
-    wrappers = []
-    for i in range(3):
-        torch._dynamo.reset()
-        wrapper = MyWrapper(mod)
-        wrappers.append(wrapper)
-        x = torch.tensor([1])
-        wrapper(x, None)  # profile run, compile
-        # create a cache tensor
-        cache = torch.tensor([2])
-        wrapper(x, cache)  # warm up with cache, recompile
-
-        # for new input, dispatch to the compiled code directly
-        new_x = torch.tensor([3])
-        assert wrapper(new_x, None).item() == 6  # dispatch to the first compiled code
-        assert wrapper(new_x, cache).item() == 5  # dispatch to the second compiled code
-
-    for wrapper in wrappers:
-        # make sure they have independent compiled codes
-        assert len(wrapper.compiled_codes) == 2
+@pytest.mark.parametrize("use_bytecode_hook", [True, False])
+def test_torch_compile_wrapper(use_bytecode_hook, monkeypatch):
+    """Test basic functionality of TorchCompileWithNoGuardsWrapper."""
+    # Set the environment variable for this test
+    monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")
+
+    # Create a proper vLLM config instead of mocking
+    vllm_config = VllmConfig()
+    vllm_config.compilation_config = CompilationConfig()
+    vllm_config.compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
+    vllm_config.compilation_config.backend = "inductor"
+
+    # Test DYNAMO_TRACE_ONCE
+    with set_current_vllm_config(vllm_config):
+        torch._dynamo.reset()
+        mod = MyMod()
+        wrapper = MyWrapper(mod)
+
+        # First call should trigger compilation
+        x = torch.tensor([1, 2, 3, 4])
+        torch._dynamo.mark_dynamic(x, 0)
+
+        result1 = wrapper(x)
+        expected1 = torch.tensor([2, 4, 6, 8])
+        assert torch.allclose(result1, expected1), (
+            f"Expected {expected1}, got {result1}"
+        )
+
+        # Second call should use compiled code
+        x2 = torch.tensor([1, 2, 3])
+        result2 = wrapper(x2)
+        expected2 = torch.tensor([2, 4, 6])
+        assert torch.allclose(result2, expected2), (
+            f"Expected {expected2}, got {result2}"
+        )
+
+        # without the wrapper, the result would be different
+        result3 = mod(x2)
+        expected3 = torch.tensor([100, 200, 300])
+
+        assert torch.allclose(result3, expected3), (
+            f"Expected {result3}, got {expected3}"
+        )
+
+    # with STOCK_TORCH_COMPILE we do not remove guards.
+    vllm_config.compilation_config.mode = CompilationMode.STOCK_TORCH_COMPILE
+    torch._dynamo.reset()
+    with set_current_vllm_config(vllm_config):
+        mod = MyMod()
+        wrapper = MyWrapper(mod)
+
+        # First call should trigger compilation
+        x = torch.tensor([1, 2, 3, 4])
+        torch._dynamo.mark_dynamic(x, 0)
+
+        result1 = wrapper(x)
+        expected1 = torch.tensor([2, 4, 6, 8])
+        assert torch.allclose(result1, expected1), (
+            f"Expected {expected1}, got {result1}"
+        )
+
+        # Second call should trigger another compilation
+        x2 = torch.tensor([1, 2, 3])
+        result2 = wrapper(x2)
+        expected2 = torch.tensor([100, 200, 300])
+        assert torch.allclose(result2, expected2), (
+            f"Expected {expected2}, got {result2}"
+        )
+
+    # NO_COMPILATION level not supported.
+    vllm_config.compilation_config.mode = None
+    torch._dynamo.reset()
+    with set_current_vllm_config(vllm_config):
+        torch._dynamo.reset()
+        mod = MyMod()
+
+        try:
+            wrapper = MyWrapper(mod)
+        except Exception:
+            return
+        raise AssertionError("expected an exception to be raised")
+
+
+if __name__ == "__main__":
+    # Run with both parameter values
+
+    class MockMonkeypatch:
+        def setenv(self, name, value):
+            os.environ[name] = value
+
+    mp = MockMonkeypatch()
+
+    print("Testing with VLLM_USE_BYTECODE_HOOK=False")
+    test_torch_compile_wrapper(False, mp)
+
+    print("Testing with VLLM_USE_BYTECODE_HOOK=True")
+    test_torch_compile_wrapper(True, mp)
+
+    print("All tests passed!")
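The mechanism this test exercises is easy to see in isolation. Below is a minimal, self-contained sketch (not vLLM code) of the guard-dropping idea used throughout this commit: passing a `guard_filter_fn` that rejects every guard, so the first compiled artifact is reused unconditionally. The toy `double` function is an assumption for illustration, and a recent PyTorch that accepts `guard_filter_fn` via compile options (as the wrapper below does) is assumed.

import torch


def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2


# Reject every guard Dynamo proposes; after the first compile, no guard
# evaluation (and hence no recompilation) can be triggered for this frame.
compiled = torch.compile(
    double,
    fullgraph=True,
    dynamic=False,
    options={"guard_filter_fn": lambda guards: [False for _ in guards]},
)
print(compiled(torch.tensor([1.0, 2.0])))  # tensor([2., 4.])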
@@ -34,6 +34,7 @@ VIDEO_PROMPTS = VIDEO_ASSETS.prompts(
 @pytest.mark.parametrize("num_frames", [16])
 @pytest.mark.parametrize("dtype", [target_dtype])
 @pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("use_bytecode_hook", [True, False])
 def test_qwen2_5_vl_evs_functionality(
     vllm_runner,
     video_assets,
@@ -42,10 +43,14 @@ def test_qwen2_5_vl_evs_functionality(
     num_frames: int,
     dtype: str,
     max_tokens: int,
+    use_bytecode_hook: bool,
+    monkeypatch,
 ) -> None:
     """Test EVS (Efficient Video Sampling) functionality with different
     pruning rates.
     """
+    # Set the environment variable for this test
+    monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")
+
     # Sample frames from video assets
     sampled_vids = [
@@ -86,6 +91,7 @@ def test_qwen2_5_vl_evs_functionality(
 @pytest.mark.parametrize("num_frames", [16])
 @pytest.mark.parametrize("dtype", [target_dtype])
 @pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("use_bytecode_hook", [True, False])
 def test_qwen2_5_vl_evs_batched_videos(
     vllm_runner,
     video_assets,
@@ -94,6 +100,8 @@ def test_qwen2_5_vl_evs_batched_videos(
     num_frames: int,
     dtype: str,
     max_tokens: int,
+    use_bytecode_hook: bool,
+    monkeypatch,
 ) -> None:
     """Test EVS functionality with batched videos.
 
@@ -102,6 +110,8 @@ def test_qwen2_5_vl_evs_batched_videos(
     2. Both pruning configurations work with multiple videos
     3. The model doesn't crash when processing multiple videos simultaneously
     """
+    # Set the environment variable for this test
+    monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")
     # Sample frames from video assets
     sampled_vids = [
         sample_frames_from_video(asset.np_ndarrays, num_frames)
@@ -75,6 +75,14 @@ def model_name():
     return "meta-llama/Llama-3.1-8B-Instruct"
 
 
+@pytest.fixture(autouse=True)
+def reset_torch_dynamo():
+    """Reset torch dynamo cache before each test"""
+    yield
+    # Cleanup after test
+    torch._dynamo.reset()
+
+
 @pytest.mark.parametrize(
     "speculative_config",
     [
@@ -17,7 +17,7 @@ from torch._dynamo.symbolic_convert import InliningInstructionTranslator
 
 import vllm.envs as envs
 from vllm.compilation.counter import compilation_counter
-from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
+from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
 from vllm.config import (
     CompilationMode,
     VllmConfig,
@@ -246,14 +246,14 @@ def _support_torch_compile(
     """
     A decorator to add support for compiling the forward method of a class.
     """
-    if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
+    if TorchCompileWithNoGuardsWrapper in cls.__bases__:
         # support decorating multiple times
         return cls
 
     # take care of method resolution order
     # make sure super().__init__ is called on the base class
-    # other than TorchCompileWrapperWithCustomDispatcher
-    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher,)
+    # other than TorchCompileWithNoGuardsWrapper
+    cls.__bases__ = cls.__bases__ + (TorchCompileWithNoGuardsWrapper,)
 
     old_init = cls.__init__
 
@@ -290,12 +290,43 @@ def _support_torch_compile(
             return
 
         compilation_counter.num_models_seen += 1
-        TorchCompileWrapperWithCustomDispatcher.__init__(
-            self, compilation_mode=vllm_config.compilation_config.mode
-        )
+        self.compiled = False
+        TorchCompileWithNoGuardsWrapper.__init__(self)
 
     cls.__init__ = __init__
 
+    def _mark_dynamic_inputs(mod, *args, **kwargs):
+        sig = inspect.signature(mod.__class__.forward)
+        bound_args = sig.bind(mod, *args, **kwargs)
+        bound_args.apply_defaults()
+        for k, dims in dynamic_arg_dims.items():
+            arg = bound_args.arguments.get(k)
+            if arg is not None:
+                dims = [dims] if isinstance(dims, int) else dims
+                if isinstance(arg, torch.Tensor):
+                    # In case dims is specified with negative indexing
+                    dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
+                    torch._dynamo.mark_dynamic(arg, dims)
+                elif isinstance(arg, IntermediateTensors):
+                    for tensor in arg.tensors.values():
+                        # In case dims is specified with negative indexing
+                        dims = [tensor.ndim + dim if dim < 0 else dim for dim in dims]
+                        torch._dynamo.mark_dynamic(tensor, dims)
+                else:
+                    raise ValueError(
+                        "Unsupported dynamic dimensions"
+                        f" {dims} for argument {k} with type {type(arg)}."
+                    )
+        if mark_unbacked_dims:
+            for k, dims in mark_unbacked_dims.items():
+                arg = bound_args.arguments.get(k)
+                if arg is not None:
+                    dims = [dims] if isinstance(dims, int) else dims
+                    if isinstance(arg, torch.Tensor):
+                        # In case dims is specified with negative indexing
+                        dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
+                        torch._dynamo.decorators.mark_unbacked(arg, dims)
+
     def __call__(self, *args, **kwargs):
         # torch.compiler.is_compiling() means we are inside the compilation
         # e.g. TPU has the compilation logic in model runner, so we don't
@@ -303,6 +334,7 @@ def _support_torch_compile(
         if self.do_not_compile or torch.compiler.is_compiling():
             return self.forward(*args, **kwargs)
 
+        # if aot_compiled_fn is set, just call it.
         if getattr(self, "aot_compiled_fn", None) is not None:
             return self.aot_compiled_fn(self, *args, **kwargs)
 
@@ -362,52 +394,23 @@ def _support_torch_compile(
             )
             return self.aot_compiled_fn(self, *args, **kwargs)
 
+        if self.compiled:
+            assert not envs.VLLM_USE_AOT_COMPILE
+            return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)
+
+        # This is the path for the first compilation.
+
         # the first compilation needs to have dynamic shapes marked
-        if len(self.compiled_codes) < 1:
-            sig = inspect.signature(self.__class__.forward)
-            bound_args = sig.bind(self, *args, **kwargs)
-            bound_args.apply_defaults()
-            for k, dims in dynamic_arg_dims.items():
-                arg = bound_args.arguments.get(k)
-                if arg is not None:
-                    dims = [dims] if isinstance(dims, int) else dims
-                    if isinstance(arg, torch.Tensor):
-                        # In case dims is specified with negative indexing
-                        dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
-                        torch._dynamo.mark_dynamic(arg, dims)
-                    elif isinstance(arg, IntermediateTensors):
-                        for tensor in arg.tensors.values():
-                            # In case dims is specified with negative indexing
-                            dims = [
-                                tensor.ndim + dim if dim < 0 else dim for dim in dims
-                            ]
-                            torch._dynamo.mark_dynamic(tensor, dims)
-                    else:
-                        raise ValueError(
-                            "Unsupported dynamic dimensions"
-                            f" {dims} for argument {k} with type {type(arg)}."
-                        )
-            if mark_unbacked_dims:
-                for k, dims in mark_unbacked_dims.items():
-                    arg = bound_args.arguments.get(k)
-                    if arg is not None:
-                        dims = [dims] if isinstance(dims, int) else dims
-                        if isinstance(arg, torch.Tensor):
-                            # In case dims is specified with negative indexing
-                            dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
-                            torch._dynamo.decorators.mark_unbacked(arg, dims)
+        _mark_dynamic_inputs(self, *args, **kwargs)
+
         # here, it is the starting point of the `torch.compile` process
         start_monitoring_torch_compile(self.vllm_config)
-        logger.debug("Start compiling function %s", self.original_code_object)
+        original_code_object = self.original_code_object()
+        logger.debug("Start compiling function %s", original_code_object)
 
-        # if we don't use custom dispatcher, we can directly call the
-        # compiled function and let torch.compile handle the dispatching,
-        # with the overhead of guard evaluation and recompilation.
-        if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
-            # it seems Dynamo reuse the compilation across instances,
-            # while we need to make sure the compiled code is not reused.
-            # we need to control all the compilation of the model.
-            torch._dynamo.eval_frame.remove_from_cache(self.original_code_object)
+        # we do not want to delete the original code object entries since
+        # we depend on them now to look up cached compiled functions.
+        # torch._dynamo.eval_frame.remove_from_cache(original_code_object)
 
         # collect all relevant files traced by Dynamo,
         # so that the compilation cache can trigger re-compilation
@@ -415,7 +418,7 @@ def _support_torch_compile(
 
         # 1. the file containing the top-level forward function
         self.vllm_config.compilation_config.traced_files.add(
-            self.original_code_object.co_filename
+            original_code_object.co_filename
         )
 
         # 2. every time Dynamo sees a function call, it will inline
@@ -457,9 +460,7 @@ def _support_torch_compile(
                 assert cache_dir is not None
                 try:
                     os.makedirs(cache_dir, exist_ok=True)
-                    self.aot_compiled_fn.save_compiled_function(
-                        aot_compilation_path
-                    )
+                    self.aot_compiled_fn.save_compiled_function(aot_compilation_path)
                 except Exception as e:
                     logger.warning(
                         "Cannot save aot compilation to path %s, error: %s",
@@ -467,15 +468,10 @@ def _support_torch_compile(
                         str(e),
                     )
         else:
-            output = self.compiled_callable(*args, **kwargs)
-            return output
-
-        # usually, capturing the model once is enough, and then we can
-        # dispatch to the compiled code directly, without going through
-        # the Dynamo guard mechanism.
-        with self.dispatch_to_code(0):
-            model_output = self.forward(*args, **kwargs)
-            return model_output
+            output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)
+
+        self.compiled = True
+        return output
 
     cls.__call__ = __call__
     return cls
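For context, `_mark_dynamic_inputs` above is a bound-arguments walk over the primitive `torch._dynamo.mark_dynamic`, which declares a tensor dimension symbolic before the first trace; that is why it must run on the first-compilation path. A minimal sketch of the underlying call (the toy `scale` function and shapes are illustrative):

import torch


def scale(x: torch.Tensor) -> torch.Tensor:
    return x * 3


x = torch.randn(8, 16)
# Declare dim 0 symbolic so a single trace covers any batch size; the helper
# above does the same thing after normalizing negative dim indices.
torch._dynamo.mark_dynamic(x, 0)
compiled = torch.compile(scale, fullgraph=True)
compiled(x)                    # traces once with a symbolic dim 0
compiled(torch.randn(32, 16))  # reuses the same graph, no retrace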
@@ -4,11 +4,11 @@
 import os
 import sys
 from abc import abstractmethod
-from collections.abc import Callable
 from contextlib import contextmanager
 from types import CodeType
 
 import torch
+import torch._C._dynamo.guards
 
 import vllm.envs as envs
 from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
@@ -17,42 +17,94 @@ from vllm.logger import init_logger
 logger = init_logger(__name__)
 
 
-class TorchCompileWrapperWithCustomDispatcher:
-    """
-    A wrapper class for torch.compile, with a custom dispatch logic.
-    Subclasses should:
-    1. Implement the forward method
-    2. Implement the dispatch logic in the __call__ method
-    It can use `self.compiled_codes` to access the compiled bytecode,
-    and `with self.dispatch_to_code(index):` to dispatch to
-    the compiled code.
-    3. Implement the `__init__` method to determine how to call
-    `torch.compile` over the forward method.
-    """
+def _noop_add_global_state_guard(self, *args, **kwargs):
+    """No-op to skip the GLOBAL_STATE guard entirely"""
+    pass
+
+
+def _noop_add_torch_function_mode_stack_guard(self, *args, **kwargs):
+    """No-op to skip the TORCH_FUNCTION_MODE_STACK guard entirely"""
+    pass
+
+
+@contextmanager
+def _compilation_context():
+    """Context manager for compilation settings and patches.
+
+    This manager:
+    1. Sets higher dynamo cache limits for compilation. (Needed for
+       qwen2_5_vl; see test_qwen2_5_vl_evs_functionality.)
+       Generally a recompilation can happen whenever we use a new
+       backend instance in torch.compile.
+    2. Patches out add_global_state_guard to skip GLOBAL_STATE guards
+    3. Patches out add_torch_function_mode_stack_guard to skip
+       TORCH_FUNCTION_MODE_STACK guards.
+    4. Restores everything when compilation completes
+    """
+    # Save original values
+    original_global_state_guard = (
+        torch._C._dynamo.guards.GuardManager.add_global_state_guard
+    )
+    original_torch_function_mode_stack_guard = (
+        torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard
+    )
+    original_cache_size = torch._dynamo.config.cache_size_limit
+    original_accumulated_cache = torch._dynamo.config.accumulated_cache_size_limit
+
+    try:
+        # Set higher cache limits for compilation
+        torch._dynamo.config.cache_size_limit = 2048
+        torch._dynamo.config.accumulated_cache_size_limit = 8192
+
+        # Patch guard manager
+        torch._C._dynamo.guards.GuardManager.add_global_state_guard = (
+            _noop_add_global_state_guard
+        )
+        torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard = (
+            _noop_add_torch_function_mode_stack_guard
+        )
+        yield
+    finally:
+        # Restore original values
+        torch._C._dynamo.guards.GuardManager.add_global_state_guard = (
+            original_global_state_guard
+        )
+        torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard = (
+            original_torch_function_mode_stack_guard
+        )
+        torch._dynamo.config.cache_size_limit = original_cache_size
+        torch._dynamo.config.accumulated_cache_size_limit = original_accumulated_cache
+
+
+class TorchCompileWithNoGuardsWrapper:
+    """
+    A wrapper class for torch.compile that ensures all guards are dropped
+    when CompilationMode is not CompilationMode.STOCK_TORCH_COMPILE.
+    When guards are dropped, the first time __call__ is invoked, a single
+    compilation is triggered. Dynamo should never trace again after that
+    since we drop all guards.
+    """
 
-    def __init__(
-        self,
-        compiled_callable: Callable | None = None,
-        compilation_mode: CompilationMode = CompilationMode.NONE,
-    ):
+    def __init__(self):
+        self.compiled = False
         vllm_config = get_current_vllm_config()
         self.vllm_config = vllm_config
-        if compiled_callable is None:
-            # default compilation settings
-            # compiling the forward method
+        mode = vllm_config.compilation_config.mode
+        if mode is None:
+            raise RuntimeError("Compilation mode cannot be NO_COMPILATION")
 
-            backend = vllm_config.compilation_config.init_backend(vllm_config)
-            options = None
-            if isinstance(backend, str) and backend == "inductor":
-                options = (
-                    get_current_vllm_config().compilation_config.inductor_compile_config
-                )
+        backend = vllm_config.compilation_config.init_backend(vllm_config)
+        options = {}
+
+        if isinstance(backend, str) and backend == "inductor":
+            options = vllm_config.compilation_config.inductor_compile_config
+
+        if mode != CompilationMode.STOCK_TORCH_COMPILE:
+            # Drop all the guards.
+            options["guard_filter_fn"] = lambda x: [False for _ in x]
 
-            if envs.VLLM_USE_AOT_COMPILE:
-                options = options or {}
-                # This effectively drop all the guards.
-                # We need this because bytecode hook is not used any more to
-                # drop guards in the AOT compile mode.
-                options["guard_filter_fn"] = lambda guards: [False for _ in guards]
+        if envs.VLLM_USE_AOT_COMPILE:
             if hasattr(torch._dynamo.config, "enable_aot_compile"):
                 torch._dynamo.config.enable_aot_compile = True
             else:
@@ -61,44 +113,57 @@ class TorchCompileWithNoGuardsWrapper:
             msg += "upgrade PyTorch version to use AOT compile."
             logger.warning(msg)
 
-            compiled_callable = torch.compile(
-                self.forward, fullgraph=True, backend=backend, options=options
-            )
+        self._compiled_callable = torch.compile(
+            self.forward,
+            fullgraph=True,
+            dynamic=False,
+            backend=backend,
+            options=options,
+        )
 
-        self.compiled_callable = compiled_callable
-        self.original_code_object = self.__class__.forward.__code__
-        self.compiled_codes: list[CodeType] = []
-        torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
-
-        # read the env var to determine whether to use the custom dispatcher
-        # subclasses can use this to switch between the custom dispatcher
-        # and the default Dynamo guard mechanism.
-        self.use_custom_dispatcher: bool = (
-            compilation_mode >= CompilationMode.DYNAMO_TRACE_ONCE
-        )
+        if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE:
+            torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
+            self._compiled_bytecode = None
 
     def aot_compile(self, *args, **kwargs):
-        if not hasattr(self.compiled_callable, "aot_compile"):
+        if not hasattr(self._compiled_callable, "aot_compile"):
             raise RuntimeError(
                 "aot_compile is not supported by the current configuration. "
                 + "Please make sure torch.compile is enabled with the latest "
                 + f"version of PyTorch (current using torch: {torch.__version__})"
            )
-        return self.compiled_callable.aot_compile((args, kwargs))
+        return self._compiled_callable.aot_compile((args, kwargs))
 
     def __call__(self, *args, **kwargs):
-        """Implement the dispatch logic here, beyond the torch.compile mode.
-        NOTE: this function can have additional arguments beyond the forward
-        method, for directly dispatching to the compiled code.
-        """
-        return self.compiled_callable(*args, **kwargs)
+        if envs.VLLM_USE_BYTECODE_HOOK:
+            if (
+                self.vllm_config.compilation_config.mode
+                == CompilationMode.STOCK_TORCH_COMPILE
+            ):
+                return self._compiled_callable(*args, **kwargs)
+
+            if not self._compiled_bytecode:
+                # Make sure a compilation is triggered by clearing dynamo
+                # cache.
+                torch._dynamo.eval_frame.remove_from_cache(self.original_code_object())
+                return self._compiled_callable(*args, **kwargs)
+            else:
+                with self._dispatch_to_compiled_code():
+                    return self.forward(*args, **kwargs)
+        else:
+            with _compilation_context():
+                return self._compiled_callable(*args, **kwargs)
 
     @abstractmethod
     def forward(self, *args, **kwargs): ...
 
+    def original_code_object(self) -> CodeType:
+        """Return the original code object of the forward method."""
+        return self.__class__.forward.__code__
+
     def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
         """Hook to save the compiled bytecode for direct execution."""
-        if old_code is not self.original_code_object:
+        if old_code is not self.original_code_object():
             return
         # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
         frame = sys._getframe()
@@ -114,7 +179,7 @@ class TorchCompileWithNoGuardsWrapper:
         if frame.f_locals["self"] is not self:
             return
 
-        self.compiled_codes.append(new_code)
+        self._compiled_bytecode = new_code
 
         path = self.vllm_config.compile_debug_dump_path()
         if path:
@@ -153,16 +218,21 @@ class TorchCompileWithNoGuardsWrapper:
             raise RuntimeError(msg)
 
     @contextmanager
-    def dispatch_to_code(self, index: int):
-        """Context manager to dispatch to the compiled code.
+    def _dispatch_to_compiled_code(self):
+        # noqa: E501
+        """
+        Context manager to dispatch to internally compiled code for torch<2.8.
         Why does this work? Because Dynamo guarantees that the compiled
         bytecode has exactly the same arguments, cell variables, and free
         variables as the original code. Therefore we can directly switch
         the code object in the function and call it.
 
-        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7
-        for more details.
-        """
-        self.__class__.forward.__code__ = self.compiled_codes[index]
-        yield
-        self.__class__.forward.__code__ = self.original_code_object
+        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
+        """  # noqa: E501 line too long
+        original = self.original_code_object()
+        assert self._compiled_bytecode is not None
+        self.__class__.forward.__code__ = self._compiled_bytecode
+        try:
+            yield
+        finally:
+            self.__class__.forward.__code__ = original
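To make the bytecode path above concrete, here is a minimal, self-contained sketch (not vLLM code) of the same two primitives: `register_bytecode_hook` to capture Dynamo's transformed bytecode, and a `__code__` swap to re-enter it without any guard evaluation. The `Doubler` class is an assumption for illustration.

import torch
from types import CodeType


class Doubler:
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


captured: list[CodeType] = []


def hook(old_code: CodeType, new_code: CodeType) -> None:
    # Dynamo calls this for every frame it rewrites; keep only our method.
    if old_code is Doubler.forward.__code__:
        captured.append(new_code)


torch._dynamo.convert_frame.register_bytecode_hook(hook)

d = Doubler()
compiled = torch.compile(d.forward, backend="eager")
compiled(torch.ones(3))  # triggers compilation; the hook captures the bytecode

# Swap in the captured bytecode: Dynamo guarantees it takes the same
# arguments, cells, and free variables as the original, so a direct call now
# runs compiled code with zero guard checks -- the trick
# _dispatch_to_compiled_code relies on above.
Doubler.forward.__code__ = captured[0]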
@@ -92,6 +92,7 @@ if TYPE_CHECKING:
     VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False
     VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False
     VLLM_USE_AOT_COMPILE: bool = False
+    VLLM_USE_BYTECODE_HOOK: bool = False
     VLLM_FORCE_AOT_LOAD: bool = False
     VLLM_TORCH_PROFILER_WITH_STACK: bool = True
     VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False
@@ -556,6 +557,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # compilation is done in warmup phase and the compilation will be
     # reused in subsequent calls.
     "VLLM_USE_AOT_COMPILE": use_aot_compile,
+    # Feature flag to enable/disable the bytecode hook in
+    # TorchCompileWithNoGuardsWrapper.
+    "VLLM_USE_BYTECODE_HOOK": lambda: bool(
+        int(os.environ.get("VLLM_USE_BYTECODE_HOOK", "1"))
+    ),
     # Force vllm to always load AOT compiled models from disk. Failure
     # to load will result in a hard error when this is enabled.
     # Will be ignored when VLLM_USE_AOT_COMPILE is disabled.
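The flag defaults to enabled ("1"). A minimal sketch of disabling it for one process, assuming only that the variable is set before vllm.envs reads the environment:

import os

# Must be set before vllm reads its environment variables; "0" disables the
# bytecode-hook dispatch path and routes __call__ through the guard-patching
# _compilation_context instead.
os.environ["VLLM_USE_BYTECODE_HOOK"] = "0"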
@@ -21,7 +21,7 @@ from vllm.attention import Attention
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import MLAAttention
 from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
-from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
+from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
 from vllm.config import (
     ParallelConfig,
     VllmConfig,
@@ -1895,12 +1895,14 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             compiled_model = self.model.get_language_model().model
         else:
             compiled_model = self.model.model
-        if isinstance(compiled_model, TorchCompileWrapperWithCustomDispatcher):
+        if isinstance(compiled_model, TorchCompileWithNoGuardsWrapper):
             logger.info("Clear dynamo cache and cached dynamo bytecode.")
             torch._dynamo.eval_frame.remove_from_cache(
-                compiled_model.original_code_object
+                compiled_model.original_code_object()
             )
-            compiled_model.compiled_codes.clear()
+            # Reset the wrapper to re-initialize.
+            compiled_model.compiled = False
+            TorchCompileWithNoGuardsWrapper.__init__(compiled_model)
 
     @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def select_hidden_states(self, hidden_states, indices_do_sample):