[torch.compile] integration with compilation control (#9058)

This commit is contained in:
youkaichao 2024-10-10 12:39:36 -07:00 committed by GitHub
parent 78c0b4166c
commit e4d652ea3e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 404 additions and 98 deletions

View File

@ -121,7 +121,9 @@ steps:
- vllm/core/ - vllm/core/
- tests/distributed - tests/distributed
- tests/spec_decode/e2e/test_integration_dist_tp4 - tests/spec_decode/e2e/test_integration_dist_tp4
- tests/compile
commands: commands:
- pytest -v -s compile/test_basic_correctness.py
- pytest -v -s distributed/test_pynccl.py - pytest -v -s distributed/test_pynccl.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
@ -231,14 +233,16 @@ steps:
- vllm/ - vllm/
- tests/compile - tests/compile
commands: commands:
- pytest -v -s compile/test_full_graph_smoke.py - pytest -v -s compile/test_basic_correctness.py
- label: "PyTorch Fullgraph Test" # 18min # TODO: re-write in comparison tests, and fix symbolic shape
source_file_dependencies: # for quantization ops.
- vllm/ # - label: "PyTorch Fullgraph Test" # 18min
- tests/compile # source_file_dependencies:
commands: # - vllm/
- pytest -v -s compile/test_full_graph.py # - tests/compile
# commands:
# - pytest -v -s compile/test_full_graph.py
- label: Kernels Test %N # 1h each - label: Kernels Test %N # 1h each
mirror_hardwares: [amd] mirror_hardwares: [amd]
@ -394,7 +398,7 @@ steps:
- tests/distributed/ - tests/distributed/
- vllm/compilation - vllm/compilation
commands: commands:
- pytest -v -s ./compile/test_full_graph_multi_gpu.py - pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py - pytest -v -s ./compile/test_wrapper.py
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed' - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
- TARGET_TEST_SUITE=L4 VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest basic_correctness/ -v -s -m distributed_2_gpus - TARGET_TEST_SUITE=L4 VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest basic_correctness/ -v -s -m distributed_2_gpus

View File

@ -0,0 +1,48 @@
from typing import Dict, List, Optional
import pytest
from vllm.compilation.levels import CompilationLevel
from vllm.utils import cuda_device_count_stateless
from ..utils import compare_all_settings
# we cannot afford testing the full Cartesian product
# of all models and all levels
@pytest.mark.parametrize(
    "model, model_args, pp_size, tp_size, attn_backend, method, fullgraph",
    [
        ("meta-llama/Meta-Llama-3-8B", [], 2, 2, "FLASH_ATTN", "generate",
         True),
        ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples",
         ["--quantization", "compressed-tensors"
          ], 1, 1, "FLASH_ATTN", "generate", True),
        ("google/gemma-2-2b-it", [], 1, 2, "FLASHINFER", "generate", True),
        # TODO: add multi-modality test for llava
        ("llava-hf/llava-1.5-7b-hf", [], 2, 1, "FLASHINFER", "generate", False)
    ])
def test_compile_correctness(model, model_args, pp_size, tp_size, attn_backend,
                             method, fullgraph):
    """Run one model under compilation levels 0-2 and compare the settings.

    Each parametrized case fixes the model, extra CLI args, the pipeline /
    tensor parallel sizes, the attention backend, the method handed to
    ``compare_all_settings`` and whether Dynamo full-graph capture is on.
    """
    # this test is run under multiple suites, with different GPUs.
    # make sure we only run the test with correct CUDA devices.
    # don't use "<", as it will duplicate the tests.
    if cuda_device_count_stateless() != pp_size * tp_size:
        pytest.skip("Not correct CUDA devices for the test.")

    import os
    os.environ["VLLM_ATTENTION_BACKEND"] = attn_backend
    # Always write the flag: only setting "0" would leak the override into
    # later parametrized cases that run in the same process.
    os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = \
        "1" if fullgraph else "0"

    # don't test VLLM_TORCH_COMPILE_LEVEL == 3 case
    # inductor will change the output, so we cannot compare them.
    levels = [
        CompilationLevel.NO_COMPILATION,
        CompilationLevel.DYNAMO_AS_IS,
        CompilationLevel.DYNAMO_ONCE,
    ]
    all_envs: List[Optional[Dict[str, str]]] = [{
        "VLLM_TORCH_COMPILE_LEVEL":
        str(level)
    } for level in levels]

    # Same command line for every level; the count follows the env list so
    # the two can never get out of sync.
    server_args = (["--enforce-eager"] + model_args +
                   ["--max_model_len", "1024"] + ["-pp", str(pp_size)] +
                   ["-tp", str(tp_size)])
    all_args = [server_args] * len(all_envs)

    compare_all_settings(model, all_args, all_envs, method=method)

View File

@ -1,13 +1,20 @@
import pytest import pytest
from vllm.compilation.backends import vllm_backend from vllm.compilation.levels import CompilationLevel
from ..utils import fork_new_process_for_each_test
from .utils import TEST_MODELS, check_full_graph_support from .utils import TEST_MODELS, check_full_graph_support
@pytest.mark.parametrize("model_info", TEST_MODELS) @pytest.mark.parametrize("model_info", TEST_MODELS)
@pytest.mark.parametrize("backend", ["eager", vllm_backend]) @pytest.mark.parametrize(
def test_full_graph(model_info, backend): "optimization_level",
[CompilationLevel.DYNAMO_ONCE, CompilationLevel.INDUCTOR])
@fork_new_process_for_each_test
def test_full_graph(model_info, optimization_level):
model = model_info[0] model = model_info[0]
model_kwargs = model_info[1] model_kwargs = model_info[1]
check_full_graph_support(model, model_kwargs, backend, tp_size=1) check_full_graph_support(model,
model_kwargs,
optimization_level,
tp_size=1)

View File

@ -1,22 +0,0 @@
import pytest
from vllm.compilation.backends import vllm_backend
from vllm.utils import cuda_device_count_stateless
from ..utils import fork_new_process_for_each_test
from .utils import TEST_MODELS_SMOKE, check_full_graph_support
@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE)
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
@fork_new_process_for_each_test
def test_full_graph_multi_gpu(model_info, tp_size, backend):
model = model_info[0]
model_kwargs = model_info[1]
# Skip the test if there are not enough CUDA devices.
if cuda_device_count_stateless() < tp_size:
pytest.skip("Not enough CUDA devices for the test.")
check_full_graph_support(model, model_kwargs, backend, tp_size=tp_size)

View File

@ -1,13 +0,0 @@
import pytest
from vllm.compilation.backends import vllm_backend
from .utils import TEST_MODELS_SMOKE, check_full_graph_support
@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE)
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
def test_full_graph(model_info, backend):
model = model_info[0]
model_kwargs = model_info[1]
check_full_graph_support(model, model_kwargs, backend, tp_size=1)

View File

@ -4,16 +4,9 @@ import torch
from tests.quantization.utils import is_quant_method_supported from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.plugins import set_torch_compile_backend from vllm.compilation.levels import CompilationLevel
from vllm.utils import is_hip from vllm.utils import is_hip
TEST_MODELS_SMOKE = [
("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
"quantization": "compressed-tensors"
}),
("meta-llama/Meta-Llama-3-8B", {}),
]
TEST_MODELS = [ TEST_MODELS = [
("facebook/opt-125m", {}), ("facebook/opt-125m", {}),
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", { ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
@ -68,20 +61,21 @@ if not is_hip() and is_quant_method_supported("awq"):
})) }))
def check_full_graph_support(model, model_kwargs, backend, tp_size=1): def check_full_graph_support(model,
model_kwargs,
optimization_level,
tp_size=1):
# make sure these models can be captured in full graph mode # make sure these models can be captured in full graph mode
if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ: os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level)
os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1" os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"
# Inductor doesn't support fp8/gptq_marlin_24 yet. # Inductor doesn't support fp8/gptq_marlin_24 yet.
quantization = model_kwargs.get("quantization") quantization = model_kwargs.get("quantization")
if (quantization == "fp8" or quantization == "gptq_marlin" if (quantization == "fp8" or quantization == "gptq_marlin"
or quantization == "gptq_marlin_24") and backend != "eager": or quantization == "gptq_marlin_24"
) and optimization_level >= CompilationLevel.INDUCTOR:
return return
set_torch_compile_backend(backend)
prompts = [ prompts = [
"Hello, my name is", "Hello, my name is",
"The president of the United States is", "The president of the United States is",

View File

@ -5,9 +5,11 @@ import tempfile
import depyf import depyf
from vllm.compilation.levels import CompilationLevel
# disable custom dispatcher, let Dynamo take over # disable custom dispatcher, let Dynamo take over
# all the control # all the control
os.environ['VLLM_DYNAMO_USE_CUSTOM_DISPATCHER'] = "0" os.environ['VLLM_TORCH_COMPILE_LEVEL'] = str(CompilationLevel.DYNAMO_AS_IS)
temp_dir = tempfile.mkdtemp() temp_dir = tempfile.mkdtemp()
with depyf.prepare_debug(temp_dir): with depyf.prepare_debug(temp_dir):

View File

@ -1,5 +1,7 @@
import os import os
from vllm.compilation.levels import CompilationLevel
from ..utils import compare_two_settings from ..utils import compare_two_settings
# --enforce-eager on TPU causes graph compilation # --enforce-eager on TPU causes graph compilation
@ -9,8 +11,9 @@ os.environ["VLLM_RPC_TIMEOUT"] = "30000"
def test_custom_dispatcher(): def test_custom_dispatcher():
compare_two_settings("google/gemma-2b", compare_two_settings(
"google/gemma-2b",
arg1=["--enforce-eager"], arg1=["--enforce-eager"],
arg2=["--enforce-eager"], arg2=["--enforce-eager"],
env1={"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": "0"}, env1={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_ONCE)},
env2={}) env2={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_AS_IS)})

View File

@ -1,8 +1,17 @@
import copy
import operator import operator
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch import torch
import torch.fx as fx import torch.fx as fx
from vllm.logger import init_logger
from .compile_context import get_compile_context
from .levels import CompilationLevel
logger = init_logger(__name__)
def fix_functionalization(graph: fx.Graph): def fix_functionalization(graph: fx.Graph):
""" """
@ -148,9 +157,113 @@ def fix_functionalization(graph: fx.Graph):
# print(graph.python_code(root_module="self", verbose=True).src, file=f) # print(graph.python_code(root_module="self", verbose=True).src, file=f)
def vllm_backend(graph, example_inputs): def wrap_inductor(graph, example_inputs, additional_inductor_config):
from torch._inductor import config from torch._inductor import config
current_config = config.shallow_copy_dict() current_config = config.shallow_copy_dict()
from torch._inductor.compile_fx import compile_fx from torch._inductor.compile_fx import compile_fx
if additional_inductor_config is not None:
current_config.update(additional_inductor_config)
if current_config['post_grad_custom_post_pass'] is not None:
logger.warning(
"post_grad_custom_post_pass is already set in the config. "
"Overwriting it with the fix_functionalization")
current_config['post_grad_custom_post_pass'] = fix_functionalization current_config['post_grad_custom_post_pass'] = fix_functionalization
return compile_fx(graph, example_inputs, config_patches=current_config) return compile_fx(graph, example_inputs, config_patches=current_config)
def vllm_backend(
        graph,
        example_inputs,
        additional_inductor_config: Optional[Dict] = None) -> Callable:
    """Compile ``graph`` with Inductor and return a shape-dispatching callable.

    A general (symbolic-shape) graph is compiled immediately. The returned
    wrapper runs that graph by default and lazily compiles an additional
    graph for each runtime shape whose first dimension appears in the
    compile context captured when this backend was invoked.

    Args:
        graph: the graph handed over by Dynamo.
        example_inputs: the example inputs handed over by Dynamo; entries
            that are ``torch.SymInt`` identify the dynamic-shape arguments.
        additional_inductor_config: optional overrides forwarded to
            ``wrap_inductor``.

    Returns:
        The callable Dynamo will invoke with the runtime arguments.
    """
    # Snapshot the context so later changes by the caller do not affect us.
    context = get_compile_context()
    context = copy.deepcopy(context) if context is not None else []
    sizes_to_specialize: List[int] = context

    # flags for all the seen shapes, whether we need to specialize
    runtime_shapes_to_compile_flags: Dict[Tuple[int, ...], bool] = {}

    # if we need to specialize, the compiled graph for that shape
    runtime_shapes_to_compiled_graph: Dict[Tuple[int, ...], Callable] = {}

    # this is the first compilation, we will compile a graph with
    # dynamic shape, as the caller will mark first dimension as dynamic
    logger.info("Compiling a graph for general shapes")
    graph_for_symbolic_shape = wrap_inductor(graph, example_inputs,
                                             additional_inductor_config)

    # TODO: Dynamo does not pass all dynamic shapes.
    # Need to investigate why. It works now because all the dynamic
    # shapes have the same value, and either of them can be used.
    sym_shape_indices = [
        i for i, x in enumerate(example_inputs) if isinstance(x, torch.SymInt)
    ]

    first_run = True

    # this is the function we return to Dynamo to run finally
    def compiled_graph_wrapper(*args):
        runtime_shapes: Tuple[int,
                              ...] = tuple(args[i] for i in sym_shape_indices)

        # Only ``first_run`` is rebound; the two dicts are mutated in place
        # and therefore need no ``nonlocal`` declaration.
        nonlocal first_run

        if first_run:
            # the first compilation is for profiling, we directly run it
            first_run = False
            return graph_for_symbolic_shape(*args)

        if runtime_shapes not in runtime_shapes_to_compile_flags:
            # we haven't seen this shape before
            # query if we need to specialize for this shape
            # we only specialize for the first dimension.
            # TODO: investigate if any model needs to specialize
            # beyond the first dimension
            # Guard against an empty tuple: when Dynamo passed no SymInt
            # inputs there is nothing to specialize on, and indexing [0]
            # would raise IndexError.
            runtime_shapes_to_compile_flags[runtime_shapes] = (
                bool(runtime_shapes)
                and runtime_shapes[0] in sizes_to_specialize)

        if not runtime_shapes_to_compile_flags[runtime_shapes]:
            # we don't need to specialize for this shape
            return graph_for_symbolic_shape(*args)

        if runtime_shapes not in runtime_shapes_to_compiled_graph:
            # we need to specialize for this shape and have not compiled
            # it yet: compile the graph for this shape now
            logger.info("Compiling a graph for shapes %s", runtime_shapes)
            runtime_shapes_to_compiled_graph[runtime_shapes] = wrap_inductor(
                graph, args, additional_inductor_config)

        return runtime_shapes_to_compiled_graph[runtime_shapes](*args)

    return compiled_graph_wrapper
def select_default_backend(level: int) -> Union[str, Callable]:
    """Pick the ``torch.compile`` backend for a compilation level.

    Args:
        level: one of the ``CompilationLevel`` constants other than
            ``NO_COMPILATION``.

    Returns:
        ``"eager"`` for the two Dynamo-only levels; otherwise a partial of
        ``vllm_backend`` bound to the additional Inductor configs.

    Raises:
        AssertionError: if ``level`` is not a level handled here.
    """
    if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
        backend = "eager"
        return backend
    assert level in [
        CompilationLevel.INDUCTOR, CompilationLevel.INDUCTOR_MAX_AUTOTUNE
    ], f"Invalid level {level}"

    from vllm.compilation.backends import vllm_backend
    from vllm.plugins import get_inductor_additional_configs

    # Work on a copy: the getter hands back the process-wide dict, and
    # forcing ``max_autotune`` below must not rewrite the user's settings.
    additional_configs = dict(get_inductor_additional_configs())

    if level == CompilationLevel.INDUCTOR_MAX_AUTOTUNE:
        if "max_autotune" in additional_configs and not additional_configs[
                "max_autotune"]:
            logger.warning(
                "max_autotune is disabled, but is overridden by level %s",
                CompilationLevel.INDUCTOR_MAX_AUTOTUNE)
        additional_configs['max_autotune'] = True

    from functools import partial
    backend = partial(vllm_backend,
                      additional_inductor_config=additional_configs)

    return backend

View File

@ -0,0 +1,23 @@
from contextlib import contextmanager
from typing import Any
# Module-level slot for the active compile context; ``None`` means unset.
_compile_context: Any = None


def get_compile_context() -> Any:
    """Return the compile context that is currently installed."""
    return _compile_context


@contextmanager
def set_compile_context(context: Any):
    """Install ``context`` as the compile context for the ``with`` body.

    The context is usually a list of sizes to specialize. Whatever was
    installed before is put back on exit, including when the body raises.
    """
    global _compile_context
    # Swap in the new value and remember the old one in a single step.
    saved, _compile_context = _compile_context, context
    try:
        yield
    finally:
        _compile_context = saved

View File

@ -0,0 +1,85 @@
from typing import List, Optional, Union
import torch
import vllm.envs as envs
from vllm.attention import AttentionMetadata
from vllm.compilation.levels import CompilationLevel
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.sequence import IntermediateTensors
from vllm.utils import supports_dynamo
def support_compile_llama_style(cls: type):
    """
    Class decorator that adds compilation support for ``forward``.

    It applies to modules whose **forward signature** is compatible with
    llama. The class is returned untouched when compilation is off, when
    the level is DYNAMO_AS_IS, or when Dynamo is not supported.
    """
    # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
    # will handle the compilation, so we don't need to do anything here.
    untouched_levels = (CompilationLevel.NO_COMPILATION,
                        CompilationLevel.DYNAMO_AS_IS)
    if (envs.VLLM_TORCH_COMPILE_LEVEL in untouched_levels
            or not supports_dynamo()):
        return cls

    # Append the wrapper as the last base so that super().__init__ inside
    # the class still resolves to its original base, not to the wrapper.
    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )

    original_init = cls.__init__

    def __init__(self, *args, **kwargs):
        # Build the module first, then set up the compile wrapper state.
        original_init(self, *args, **kwargs)
        TorchCompileWrapperWithCustomDispatcher.__init__(self)

    cls.__init__ = __init__

    def __call__(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        forward_args = (input_ids, positions, kv_caches, attn_metadata,
                        intermediate_tensors, inputs_embeds)

        # torch.compiler.is_compiling() means we are inside the compilation
        # e.g. TPU has the compilation logic in model runner, so we don't
        # need to compile the model inside.
        if torch.compiler.is_compiling():
            return self.forward(*forward_args)

        nothing_compiled_yet = len(self.compiled_codes) < 1

        # the first compilation needs to have dynamic shapes marked
        if nothing_compiled_yet:
            to_mark = []
            if input_ids is not None:
                to_mark.append(input_ids)
            to_mark.append(positions)
            if inputs_embeds is not None:
                to_mark.append(inputs_embeds)
            if intermediate_tensors is not None:
                to_mark.extend(intermediate_tensors.tensors.values())
            for tensor in to_mark:
                torch._dynamo.mark_dynamic(tensor, 0)

        # if we don't use custom dispatcher, we can directly call the
        # compiled function and let torch.compile handle the dispatching,
        # with the overhead of guard evaluation and recompilation.
        if nothing_compiled_yet or not self.use_custom_dispatcher:
            return self.compiled_callable(*forward_args)

        # usually, capturing the model once is enough, and then we can
        # dispatch to the compiled code directly, without going through
        # the Dynamo guard mechanism.
        with self.dispatch_to_code(0):
            return self.forward(*forward_args)

    cls.__call__ = __call__
    return cls

View File

@ -0,0 +1,9 @@
# constants for the levels of the compilation process
class CompilationLevel:
    """Integer levels selected through ``VLLM_TORCH_COMPILE_LEVEL``.

    The values are ordered, and callers compare them with ``>=`` / ``<``.
    """
    # compilation disabled
    NO_COMPILATION = 0
    # the model runner wraps the model with torch.compile itself
    DYNAMO_AS_IS = 1
    # first level at which the custom dispatcher is enabled
    DYNAMO_ONCE = 2
    # first level that selects the Inductor-based backend
    INDUCTOR = 3
    # as INDUCTOR, with ``max_autotune`` forced on
    INDUCTOR_MAX_AUTOTUNE = 4

View File

@ -3,12 +3,14 @@ import sys
from abc import abstractmethod from abc import abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from types import CodeType from types import CodeType
from typing import Callable, List from typing import Callable, List, Optional
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from .levels import CompilationLevel
class TorchCompileWrapperWithCustomDispatcher: class TorchCompileWrapperWithCustomDispatcher:
""" """
@ -23,7 +25,26 @@ class TorchCompileWrapperWithCustomDispatcher:
`torch.compile` over the forward method. `torch.compile` over the forward method.
""" """
def __init__(self, compiled_callable: Callable): def __init__(self, compiled_callable: Optional[Callable] = None):
if compiled_callable is None:
# default compilation settings
# compiling the forward method
# choose the compile backend
# if the user has set the backend, use it
from vllm.plugins import get_torch_compile_backend
backend = get_torch_compile_backend()
if backend is None:
from vllm.compilation.backends import select_default_backend
backend = select_default_backend(envs.VLLM_TORCH_COMPILE_LEVEL)
compiled_callable = torch.compile(
self.forward,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend=backend)
self.compiled_callable = compiled_callable self.compiled_callable = compiled_callable
self.original_code_object = self.__class__.forward.__code__ self.original_code_object = self.__class__.forward.__code__
self.compiled_codes: List[CodeType] = [] self.compiled_codes: List[CodeType] = []
@ -33,7 +54,7 @@ class TorchCompileWrapperWithCustomDispatcher:
# subclasses can use this to switch between the custom dispatcher # subclasses can use this to switch between the custom dispatcher
# and the default Dynamo guard mechanism. # and the default Dynamo guard mechanism.
self.use_custom_dispatcher: bool = \ self.use_custom_dispatcher: bool = \
envs.VLLM_DYNAMO_USE_CUSTOM_DISPATCHER envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.DYNAMO_ONCE
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
"""Implement the dispatch logic here, beyond the torch.compile level. """Implement the dispatch logic here, beyond the torch.compile level.

View File

@ -65,6 +65,7 @@ if TYPE_CHECKING:
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False VLLM_SKIP_P2P_CHECK: bool = False
VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1: bool = False VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1: bool = False
VLLM_TORCH_COMPILE_LEVEL: int = 0
def get_default_cache_root(): def get_default_cache_root():
@ -198,23 +199,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
("true", "1")), ("true", "1")),
# Internal flag to enable Dynamo graph capture
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE":
lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")),
"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER":
lambda:
(os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in
("true", "1")),
# Internal flag to control whether we use custom op,
# or use the native pytorch implementation
"VLLM_TEST_COMPILE_NO_CUSTOM_OPS":
lambda: int(os.environ.get("VLLM_TEST_COMPILE_NO_CUSTOM_OPS", "0")),
# Internal flag to enable Dynamo fullgraph capture # Internal flag to enable Dynamo fullgraph capture
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE": "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE":
lambda: bool( lambda: bool(
os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"), os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
"VLLM_TORCH_COMPILE_LEVEL":
lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")),
# local rank of the process in the distributed setting, used to determine # local rank of the process in the distributed setting, used to determine
# the GPU device id # the GPU device id

View File

@ -1,6 +1,7 @@
import torch.nn as nn import torch.nn as nn
import vllm.envs as envs import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_cpu, is_hip, is_xpu from vllm.utils import is_cpu, is_hip, is_xpu
@ -55,7 +56,7 @@ class CustomOp(nn.Module):
# NOTE(woosuk): Here we assume that vLLM was built for only one # NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching. # specific backend. Currently, we do not support dynamic dispatching.
if envs.VLLM_TEST_COMPILE_NO_CUSTOM_OPS: if envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.INDUCTOR:
return self.forward_native return self.forward_native
if is_hip(): if is_hip():

View File

@ -21,6 +21,7 @@ from torch import nn
from transformers import Gemma2Config from transformers import Gemma2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_compile_llama_style
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
@ -238,6 +239,7 @@ class Gemma2DecoderLayer(nn.Module):
return hidden_states, residual return hidden_states, residual
@support_compile_llama_style
class Gemma2Model(nn.Module): class Gemma2Model(nn.Module):
def __init__( def __init__(

View File

@ -28,6 +28,7 @@ from torch import nn
from transformers import LlamaConfig from transformers import LlamaConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_compile_llama_style
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
@ -265,6 +266,7 @@ class LlamaDecoderLayer(nn.Module):
return hidden_states, residual return hidden_states, residual
@support_compile_llama_style
class LlamaModel(nn.Module): class LlamaModel(nn.Module):
def __init__( def __init__(

View File

@ -365,6 +365,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
input_ids = None input_ids = None
inputs_embeds = None inputs_embeds = None
else: else:
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None: if image_input is not None:
@ -375,10 +377,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings, input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index) self.config.image_token_index)
input_ids = None
else: else:
inputs_embeds = None inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
input_ids = None
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,

View File

@ -1,7 +1,21 @@
import os
import torch import torch
import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.plugins import set_torch_compile_backend
from .interface import Platform, PlatformEnum from .interface import Platform, PlatformEnum
if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ:
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.DYNAMO_ONCE)
assert envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR,\
"TPU does not support Inductor."
set_torch_compile_backend("openxla")
class TpuPlatform(Platform): class TpuPlatform(Platform):
_enum = PlatformEnum.TPU _enum = PlatformEnum.TPU

View File

@ -1,5 +1,5 @@
import logging import logging
from typing import Callable, Optional, Union from typing import Callable, Dict, Optional, Union
import vllm.envs as envs import vllm.envs as envs
@ -42,3 +42,15 @@ def set_torch_compile_backend(backend: Union[Callable, str]):
def get_torch_compile_backend() -> Optional[Union[Callable, str]]: def get_torch_compile_backend() -> Optional[Union[Callable, str]]:
return _torch_compile_backend return _torch_compile_backend
# Process-wide extra Inductor options; empty until a caller sets them.
_inductor_additional_configs: Dict = {}


def set_inductor_additional_configs(configs: Dict):
    """Replace the stored additional Inductor configs with ``configs``."""
    global _inductor_additional_configs
    _inductor_additional_configs = configs


def get_inductor_additional_configs() -> Dict:
    """Return the stored additional Inductor configs (the same object)."""
    return _inductor_additional_configs

View File

@ -1137,10 +1137,9 @@ class EmbeddingSequenceGroupOutput(
return self.embeddings == other.embeddings return self.embeddings == other.embeddings
class IntermediateTensors( # cannot use msgspec.Struct here because Dynamo does not support it
msgspec.Struct, @dataclass
omit_defaults=True, # type: ignore[call-arg] class IntermediateTensors:
array_like=True): # type: ignore[call-arg]
"""For all pipeline stages except the last, we need to return the hidden """For all pipeline stages except the last, we need to return the hidden
states and residuals to be sent to the next stage. This data structure states and residuals to be sent to the next stage. This data structure
contains the hidden states and residuals for a request. contains the hidden states and residuals for a request.

View File

@ -18,6 +18,8 @@ import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention.backends.abstract import AttentionState from vllm.attention.backends.abstract import AttentionState
from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.backends.utils import CommonAttentionState
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.levels import CompilationLevel
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig, ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig) PromptAdapterConfig, SchedulerConfig)
@ -1126,10 +1128,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"provided. Defaulting to scaling factors of 1.0. " "provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!") "This may lead to less accurate results!")
if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo(): if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.DYNAMO_AS_IS \
from vllm.compilation.backends import vllm_backend and supports_dynamo():
from vllm.plugins import get_torch_compile_backend from vllm.plugins import get_torch_compile_backend
backend = get_torch_compile_backend() or vllm_backend backend = get_torch_compile_backend() or "eager"
self.model = torch.compile( self.model = torch.compile(
self.model, self.model,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
@ -1289,6 +1291,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
batch_size=batch_size, batch_size=batch_size,
dtype=self.model_config.dtype, dtype=self.model_config.dtype,
device=self.device) device=self.device)
graph_batch_size = self.max_batchsize_to_capture
batch_size_capture_list = [
bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
]
if self.model_config.enforce_eager:
batch_size_capture_list = []
with set_compile_context(batch_size_capture_list):
self.execute_model(model_input, kv_caches, intermediate_tensors) self.execute_model(model_input, kv_caches, intermediate_tensors)
torch.cuda.synchronize() torch.cuda.synchronize()
return return