mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-07 13:15:42 +08:00
[torch.compile] integration with compilation control (#9058)
This commit is contained in:
parent
78c0b4166c
commit
e4d652ea3e
@ -121,7 +121,9 @@ steps:
|
|||||||
- vllm/core/
|
- vllm/core/
|
||||||
- tests/distributed
|
- tests/distributed
|
||||||
- tests/spec_decode/e2e/test_integration_dist_tp4
|
- tests/spec_decode/e2e/test_integration_dist_tp4
|
||||||
|
- tests/compile
|
||||||
commands:
|
commands:
|
||||||
|
- pytest -v -s compile/test_basic_correctness.py
|
||||||
- pytest -v -s distributed/test_pynccl.py
|
- pytest -v -s distributed/test_pynccl.py
|
||||||
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
|
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
|
||||||
|
|
||||||
@ -231,14 +233,16 @@ steps:
|
|||||||
- vllm/
|
- vllm/
|
||||||
- tests/compile
|
- tests/compile
|
||||||
commands:
|
commands:
|
||||||
- pytest -v -s compile/test_full_graph_smoke.py
|
- pytest -v -s compile/test_basic_correctness.py
|
||||||
|
|
||||||
- label: "PyTorch Fullgraph Test" # 18min
|
# TODO: re-write in comparison tests, and fix symbolic shape
|
||||||
source_file_dependencies:
|
# for quantization ops.
|
||||||
- vllm/
|
# - label: "PyTorch Fullgraph Test" # 18min
|
||||||
- tests/compile
|
# source_file_dependencies:
|
||||||
commands:
|
# - vllm/
|
||||||
- pytest -v -s compile/test_full_graph.py
|
# - tests/compile
|
||||||
|
# commands:
|
||||||
|
# - pytest -v -s compile/test_full_graph.py
|
||||||
|
|
||||||
- label: Kernels Test %N # 1h each
|
- label: Kernels Test %N # 1h each
|
||||||
mirror_hardwares: [amd]
|
mirror_hardwares: [amd]
|
||||||
@ -394,7 +398,7 @@ steps:
|
|||||||
- tests/distributed/
|
- tests/distributed/
|
||||||
- vllm/compilation
|
- vllm/compilation
|
||||||
commands:
|
commands:
|
||||||
- pytest -v -s ./compile/test_full_graph_multi_gpu.py
|
- pytest -v -s ./compile/test_basic_correctness.py
|
||||||
- pytest -v -s ./compile/test_wrapper.py
|
- pytest -v -s ./compile/test_wrapper.py
|
||||||
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
|
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
|
||||||
- TARGET_TEST_SUITE=L4 VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest basic_correctness/ -v -s -m distributed_2_gpus
|
- TARGET_TEST_SUITE=L4 VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest basic_correctness/ -v -s -m distributed_2_gpus
|
||||||
|
|||||||
48
tests/compile/test_basic_correctness.py
Normal file
48
tests/compile/test_basic_correctness.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.compilation.levels import CompilationLevel
|
||||||
|
from vllm.utils import cuda_device_count_stateless
|
||||||
|
|
||||||
|
from ..utils import compare_all_settings
|
||||||
|
|
||||||
|
|
||||||
|
# we cannot afford testing the full Cartesian product
|
||||||
|
# of all models and all levels
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model, model_args, pp_size, tp_size, attn_backend, method, fullgraph",
|
||||||
|
[
|
||||||
|
("meta-llama/Meta-Llama-3-8B", [], 2, 2, "FLASH_ATTN", "generate",
|
||||||
|
True),
|
||||||
|
("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples",
|
||||||
|
["--quantization", "compressed-tensors"
|
||||||
|
], 1, 1, "FLASH_ATTN", "generate", True),
|
||||||
|
("google/gemma-2-2b-it", [], 1, 2, "FLASHINFER", "generate", True),
|
||||||
|
# TODO: add multi-modality test for llava
|
||||||
|
("llava-hf/llava-1.5-7b-hf", [], 2, 1, "FLASHINFER", "generate", False)
|
||||||
|
])
|
||||||
|
def test_compile_correctness(model, model_args, pp_size, tp_size, attn_backend,
|
||||||
|
method, fullgraph):
|
||||||
|
# this test is run under multiple suites, with different GPUs.
|
||||||
|
# make sure we only run the test with correct CUDA devices.
|
||||||
|
# don't use "<", as it will duplicate the tests.
|
||||||
|
if cuda_device_count_stateless() != pp_size * tp_size:
|
||||||
|
pytest.skip("Not correct CUDA devices for the test.")
|
||||||
|
import os
|
||||||
|
os.environ["VLLM_ATTENTION_BACKEND"] = attn_backend
|
||||||
|
if not fullgraph:
|
||||||
|
os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0"
|
||||||
|
all_args = [["--enforce-eager"] + model_args + ["--max_model_len", "1024"]
|
||||||
|
+ ["-pp", str(pp_size)] + ["-tp", str(tp_size)]] * 3
|
||||||
|
# don't test VLLM_TORCH_COMPILE_LEVEL == 3 case
|
||||||
|
# inductor will change the output, so we cannot compare them.
|
||||||
|
all_envs: List[Optional[Dict[str, str]]] = [{
|
||||||
|
"VLLM_TORCH_COMPILE_LEVEL":
|
||||||
|
str(level)
|
||||||
|
} for level in [
|
||||||
|
CompilationLevel.NO_COMPILATION,
|
||||||
|
CompilationLevel.DYNAMO_AS_IS,
|
||||||
|
CompilationLevel.DYNAMO_ONCE,
|
||||||
|
]]
|
||||||
|
compare_all_settings(model, all_args, all_envs, method=method)
|
||||||
@ -1,13 +1,20 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.compilation.backends import vllm_backend
|
from vllm.compilation.levels import CompilationLevel
|
||||||
|
|
||||||
|
from ..utils import fork_new_process_for_each_test
|
||||||
from .utils import TEST_MODELS, check_full_graph_support
|
from .utils import TEST_MODELS, check_full_graph_support
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_info", TEST_MODELS)
|
@pytest.mark.parametrize("model_info", TEST_MODELS)
|
||||||
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
|
@pytest.mark.parametrize(
|
||||||
def test_full_graph(model_info, backend):
|
"optimization_level",
|
||||||
|
[CompilationLevel.DYNAMO_ONCE, CompilationLevel.INDUCTOR])
|
||||||
|
@fork_new_process_for_each_test
|
||||||
|
def test_full_graph(model_info, optimization_level):
|
||||||
model = model_info[0]
|
model = model_info[0]
|
||||||
model_kwargs = model_info[1]
|
model_kwargs = model_info[1]
|
||||||
check_full_graph_support(model, model_kwargs, backend, tp_size=1)
|
check_full_graph_support(model,
|
||||||
|
model_kwargs,
|
||||||
|
optimization_level,
|
||||||
|
tp_size=1)
|
||||||
|
|||||||
@ -1,22 +0,0 @@
|
|||||||
import pytest
|
|
||||||
|
|
||||||
from vllm.compilation.backends import vllm_backend
|
|
||||||
from vllm.utils import cuda_device_count_stateless
|
|
||||||
|
|
||||||
from ..utils import fork_new_process_for_each_test
|
|
||||||
from .utils import TEST_MODELS_SMOKE, check_full_graph_support
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE)
|
|
||||||
@pytest.mark.parametrize("tp_size", [2])
|
|
||||||
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
|
|
||||||
@fork_new_process_for_each_test
|
|
||||||
def test_full_graph_multi_gpu(model_info, tp_size, backend):
|
|
||||||
model = model_info[0]
|
|
||||||
model_kwargs = model_info[1]
|
|
||||||
|
|
||||||
# Skip the test if there are not enough CUDA devices.
|
|
||||||
if cuda_device_count_stateless() < tp_size:
|
|
||||||
pytest.skip("Not enough CUDA devices for the test.")
|
|
||||||
|
|
||||||
check_full_graph_support(model, model_kwargs, backend, tp_size=tp_size)
|
|
||||||
@ -1,13 +0,0 @@
|
|||||||
import pytest
|
|
||||||
|
|
||||||
from vllm.compilation.backends import vllm_backend
|
|
||||||
|
|
||||||
from .utils import TEST_MODELS_SMOKE, check_full_graph_support
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE)
|
|
||||||
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
|
|
||||||
def test_full_graph(model_info, backend):
|
|
||||||
model = model_info[0]
|
|
||||||
model_kwargs = model_info[1]
|
|
||||||
check_full_graph_support(model, model_kwargs, backend, tp_size=1)
|
|
||||||
@ -4,16 +4,9 @@ import torch
|
|||||||
|
|
||||||
from tests.quantization.utils import is_quant_method_supported
|
from tests.quantization.utils import is_quant_method_supported
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.plugins import set_torch_compile_backend
|
from vllm.compilation.levels import CompilationLevel
|
||||||
from vllm.utils import is_hip
|
from vllm.utils import is_hip
|
||||||
|
|
||||||
TEST_MODELS_SMOKE = [
|
|
||||||
("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
|
|
||||||
"quantization": "compressed-tensors"
|
|
||||||
}),
|
|
||||||
("meta-llama/Meta-Llama-3-8B", {}),
|
|
||||||
]
|
|
||||||
|
|
||||||
TEST_MODELS = [
|
TEST_MODELS = [
|
||||||
("facebook/opt-125m", {}),
|
("facebook/opt-125m", {}),
|
||||||
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
|
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
|
||||||
@ -68,20 +61,21 @@ if not is_hip() and is_quant_method_supported("awq"):
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
def check_full_graph_support(model, model_kwargs, backend, tp_size=1):
|
def check_full_graph_support(model,
|
||||||
|
model_kwargs,
|
||||||
|
optimization_level,
|
||||||
|
tp_size=1):
|
||||||
# make sure these models can be captured in full graph mode
|
# make sure these models can be captured in full graph mode
|
||||||
if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
|
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level)
|
||||||
os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
|
|
||||||
os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"
|
os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"
|
||||||
|
|
||||||
# Inductor doesn't support fp8/gptq_marlin_24 yet.
|
# Inductor doesn't support fp8/gptq_marlin_24 yet.
|
||||||
quantization = model_kwargs.get("quantization")
|
quantization = model_kwargs.get("quantization")
|
||||||
if (quantization == "fp8" or quantization == "gptq_marlin"
|
if (quantization == "fp8" or quantization == "gptq_marlin"
|
||||||
or quantization == "gptq_marlin_24") and backend != "eager":
|
or quantization == "gptq_marlin_24"
|
||||||
|
) and optimization_level >= CompilationLevel.INDUCTOR:
|
||||||
return
|
return
|
||||||
|
|
||||||
set_torch_compile_backend(backend)
|
|
||||||
|
|
||||||
prompts = [
|
prompts = [
|
||||||
"Hello, my name is",
|
"Hello, my name is",
|
||||||
"The president of the United States is",
|
"The president of the United States is",
|
||||||
|
|||||||
@ -5,9 +5,11 @@ import tempfile
|
|||||||
|
|
||||||
import depyf
|
import depyf
|
||||||
|
|
||||||
|
from vllm.compilation.levels import CompilationLevel
|
||||||
|
|
||||||
# disable custom dispatcher, let Dynamo take over
|
# disable custom dispatcher, let Dynamo take over
|
||||||
# all the control
|
# all the control
|
||||||
os.environ['VLLM_DYNAMO_USE_CUSTOM_DISPATCHER'] = "0"
|
os.environ['VLLM_TORCH_COMPILE_LEVEL'] = str(CompilationLevel.DYNAMO_AS_IS)
|
||||||
|
|
||||||
temp_dir = tempfile.mkdtemp()
|
temp_dir = tempfile.mkdtemp()
|
||||||
with depyf.prepare_debug(temp_dir):
|
with depyf.prepare_debug(temp_dir):
|
||||||
|
|||||||
@ -1,5 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
from vllm.compilation.levels import CompilationLevel
|
||||||
|
|
||||||
from ..utils import compare_two_settings
|
from ..utils import compare_two_settings
|
||||||
|
|
||||||
# --enforce-eager on TPU causes graph compilation
|
# --enforce-eager on TPU causes graph compilation
|
||||||
@ -9,8 +11,9 @@ os.environ["VLLM_RPC_TIMEOUT"] = "30000"
|
|||||||
|
|
||||||
|
|
||||||
def test_custom_dispatcher():
|
def test_custom_dispatcher():
|
||||||
compare_two_settings("google/gemma-2b",
|
compare_two_settings(
|
||||||
|
"google/gemma-2b",
|
||||||
arg1=["--enforce-eager"],
|
arg1=["--enforce-eager"],
|
||||||
arg2=["--enforce-eager"],
|
arg2=["--enforce-eager"],
|
||||||
env1={"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": "0"},
|
env1={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_ONCE)},
|
||||||
env2={})
|
env2={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_AS_IS)})
|
||||||
|
|||||||
@ -1,8 +1,17 @@
|
|||||||
|
import copy
|
||||||
import operator
|
import operator
|
||||||
|
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.fx as fx
|
import torch.fx as fx
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
from .compile_context import get_compile_context
|
||||||
|
from .levels import CompilationLevel
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def fix_functionalization(graph: fx.Graph):
|
def fix_functionalization(graph: fx.Graph):
|
||||||
"""
|
"""
|
||||||
@ -148,9 +157,113 @@ def fix_functionalization(graph: fx.Graph):
|
|||||||
# print(graph.python_code(root_module="self", verbose=True).src, file=f)
|
# print(graph.python_code(root_module="self", verbose=True).src, file=f)
|
||||||
|
|
||||||
|
|
||||||
def vllm_backend(graph, example_inputs):
|
def wrap_inductor(graph, example_inputs, additional_inductor_config):
|
||||||
from torch._inductor import config
|
from torch._inductor import config
|
||||||
current_config = config.shallow_copy_dict()
|
current_config = config.shallow_copy_dict()
|
||||||
from torch._inductor.compile_fx import compile_fx
|
from torch._inductor.compile_fx import compile_fx
|
||||||
|
|
||||||
|
if additional_inductor_config is not None:
|
||||||
|
current_config.update(additional_inductor_config)
|
||||||
|
if current_config['post_grad_custom_post_pass'] is not None:
|
||||||
|
logger.warning(
|
||||||
|
"post_grad_custom_post_pass is already set in the config. "
|
||||||
|
"Overwriting it with the fix_functionalization")
|
||||||
current_config['post_grad_custom_post_pass'] = fix_functionalization
|
current_config['post_grad_custom_post_pass'] = fix_functionalization
|
||||||
return compile_fx(graph, example_inputs, config_patches=current_config)
|
return compile_fx(graph, example_inputs, config_patches=current_config)
|
||||||
|
|
||||||
|
|
||||||
|
def vllm_backend(
|
||||||
|
graph,
|
||||||
|
example_inputs,
|
||||||
|
additional_inductor_config: Optional[Dict] = None) -> Callable:
|
||||||
|
|
||||||
|
context = get_compile_context()
|
||||||
|
context = copy.deepcopy(context) if context is not None else []
|
||||||
|
sizes_to_specialize: List[int] = context
|
||||||
|
|
||||||
|
# flags for all the seen shapes, whether we need to specialize
|
||||||
|
runtime_shapes_to_compile_flags: Dict[Tuple[int, ...], bool] = {}
|
||||||
|
|
||||||
|
# if we need to specialize, the compiled graph for that shape
|
||||||
|
runtime_shapes_to_compiled_graph: Dict[Tuple[int, ...], Callable] = {}
|
||||||
|
|
||||||
|
# this is the first compilation, we will compile a graph with
|
||||||
|
# dynamic shape, as the caller will mark first dimension as dynamic
|
||||||
|
logger.info("Compiling a graph for general shapes")
|
||||||
|
graph_for_symbolic_shape = wrap_inductor(graph, example_inputs,
|
||||||
|
additional_inductor_config)
|
||||||
|
|
||||||
|
# TODO: Dynamo does not pass all dynamic shapes.
|
||||||
|
# Need to investigate why. It works now because all the dynamic
|
||||||
|
# shapes have the same value, and either of them can be used.
|
||||||
|
sym_shape_indices = [
|
||||||
|
i for i, x in enumerate(example_inputs) if isinstance(x, torch.SymInt)
|
||||||
|
]
|
||||||
|
|
||||||
|
first_run = True
|
||||||
|
|
||||||
|
# this is the function we return to Dynamo to run finally
|
||||||
|
def compiled_graph_wrapper(*args):
|
||||||
|
|
||||||
|
runtime_shapes: Tuple[int,
|
||||||
|
...] = tuple(args[i] for i in sym_shape_indices)
|
||||||
|
|
||||||
|
nonlocal first_run
|
||||||
|
nonlocal runtime_shapes_to_compile_flags
|
||||||
|
nonlocal runtime_shapes_to_compiled_graph
|
||||||
|
|
||||||
|
if first_run:
|
||||||
|
# the first compilation is for profiling, we directly run it
|
||||||
|
first_run = False
|
||||||
|
return graph_for_symbolic_shape(*args)
|
||||||
|
|
||||||
|
if runtime_shapes not in runtime_shapes_to_compile_flags:
|
||||||
|
# we haven't seen this shape before
|
||||||
|
# query if we need to specialize for this shape
|
||||||
|
# we only specialize for the first dimension.
|
||||||
|
# TODO: investigate if any model needs to specialize
|
||||||
|
# beyond the first dimension
|
||||||
|
runtime_shapes_to_compile_flags[runtime_shapes] = runtime_shapes[
|
||||||
|
0] in sizes_to_specialize
|
||||||
|
|
||||||
|
if not runtime_shapes_to_compile_flags[runtime_shapes]:
|
||||||
|
# we don't need to specialize for this shape
|
||||||
|
return graph_for_symbolic_shape(*args)
|
||||||
|
|
||||||
|
if runtime_shapes not in runtime_shapes_to_compiled_graph:
|
||||||
|
# we need to specialize for this shape, and we haven't compiled it yet:
|
||||||
|
# compile the graph for this shape
|
||||||
|
logger.info("Compiling a graph for shapes %s", runtime_shapes)
|
||||||
|
runtime_shapes_to_compiled_graph[runtime_shapes] = wrap_inductor(
|
||||||
|
graph, args, additional_inductor_config)
|
||||||
|
|
||||||
|
return runtime_shapes_to_compiled_graph[runtime_shapes](*args)
|
||||||
|
|
||||||
|
return compiled_graph_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def select_default_backend(level: int) -> Union[str, Callable]:
|
||||||
|
if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
|
||||||
|
backend = "eager"
|
||||||
|
return backend
|
||||||
|
assert level in [
|
||||||
|
CompilationLevel.INDUCTOR, CompilationLevel.INDUCTOR_MAX_AUTOTUNE
|
||||||
|
], f"Invalid level {level}"
|
||||||
|
|
||||||
|
from vllm.compilation.backends import vllm_backend
|
||||||
|
from vllm.plugins import get_inductor_additional_configs
|
||||||
|
additional_configs = get_inductor_additional_configs()
|
||||||
|
|
||||||
|
if level == CompilationLevel.INDUCTOR_MAX_AUTOTUNE:
|
||||||
|
if "max_autotune" in additional_configs and not additional_configs[
|
||||||
|
"max_autotune"]:
|
||||||
|
logger.warning(
|
||||||
|
"max_autotune is disabled, but is overridden by level %s",
|
||||||
|
CompilationLevel.INDUCTOR_MAX_AUTOTUNE)
|
||||||
|
additional_configs['max_autotune'] = True
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
backend = partial(vllm_backend,
|
||||||
|
additional_inductor_config=additional_configs)
|
||||||
|
|
||||||
|
return backend
|
||||||
|
|||||||
23
vllm/compilation/compile_context.py
Normal file
23
vllm/compilation/compile_context.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
_compile_context: Any = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_compile_context() -> Any:
|
||||||
|
"""Get the current compile context."""
|
||||||
|
return _compile_context
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def set_compile_context(context: Any):
|
||||||
|
"""A context manager that stores the current compile context,
|
||||||
|
usually it is a list of sizes to specialize.
|
||||||
|
"""
|
||||||
|
global _compile_context
|
||||||
|
prev_context = _compile_context
|
||||||
|
_compile_context = context
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
_compile_context = prev_context
|
||||||
85
vllm/compilation/decorators.py
Normal file
85
vllm/compilation/decorators.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.attention import AttentionMetadata
|
||||||
|
from vllm.compilation.levels import CompilationLevel
|
||||||
|
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
|
||||||
|
from vllm.sequence import IntermediateTensors
|
||||||
|
from vllm.utils import supports_dynamo
|
||||||
|
|
||||||
|
|
||||||
|
def support_compile_llama_style(cls: type):
|
||||||
|
"""
|
||||||
|
A decorator to add support for compiling the forward method of a class.
|
||||||
|
If a module's **forward signature** is compatible with llama, this
|
||||||
|
decorator can be used to enable the compilation of the forward method.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
|
||||||
|
# will handle the compilation, so we don't need to do anything here.
|
||||||
|
if envs.VLLM_TORCH_COMPILE_LEVEL in [
|
||||||
|
CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
|
||||||
|
] or not supports_dynamo():
|
||||||
|
return cls
|
||||||
|
|
||||||
|
# take care of method resolution order
|
||||||
|
# make sure super().__init__ is called on the base class
|
||||||
|
# other than TorchCompileWrapperWithCustomDispatcher
|
||||||
|
cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
|
||||||
|
|
||||||
|
old_init = cls.__init__
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
old_init(self, *args, **kwargs)
|
||||||
|
TorchCompileWrapperWithCustomDispatcher.__init__(self)
|
||||||
|
|
||||||
|
cls.__init__ = __init__
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.Tensor],
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors],
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
|
# torch.compiler.is_compiling() means we are inside the compilation
|
||||||
|
# e.g. TPU has the compilation logic in model runner, so we don't
|
||||||
|
# need to compile the model inside.
|
||||||
|
if torch.compiler.is_compiling():
|
||||||
|
return self.forward(input_ids, positions, kv_caches, attn_metadata,
|
||||||
|
intermediate_tensors, inputs_embeds)
|
||||||
|
|
||||||
|
# the first compilation needs to have dynamic shapes marked
|
||||||
|
if len(self.compiled_codes) < 1:
|
||||||
|
if input_ids is not None:
|
||||||
|
torch._dynamo.mark_dynamic(input_ids, 0)
|
||||||
|
torch._dynamo.mark_dynamic(positions, 0)
|
||||||
|
if inputs_embeds is not None:
|
||||||
|
torch._dynamo.mark_dynamic(inputs_embeds, 0)
|
||||||
|
if intermediate_tensors is not None:
|
||||||
|
for tensors in intermediate_tensors.tensors.values():
|
||||||
|
torch._dynamo.mark_dynamic(tensors, 0)
|
||||||
|
|
||||||
|
# if we don't use custom dispatcher, we can directly call the
|
||||||
|
# compiled function and let torch.compile handle the dispatching,
|
||||||
|
# with the overhead of guard evaluation and recompilation.
|
||||||
|
if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
|
||||||
|
return self.compiled_callable(input_ids, positions, kv_caches,
|
||||||
|
attn_metadata, intermediate_tensors,
|
||||||
|
inputs_embeds)
|
||||||
|
|
||||||
|
# usually, capturing the model once is enough, and then we can
|
||||||
|
# dispatch to the compiled code directly, without going through
|
||||||
|
# the Dynamo guard mechanism.
|
||||||
|
with self.dispatch_to_code(0):
|
||||||
|
model_output = self.forward(input_ids, positions, kv_caches,
|
||||||
|
attn_metadata, intermediate_tensors,
|
||||||
|
inputs_embeds)
|
||||||
|
return model_output
|
||||||
|
|
||||||
|
cls.__call__ = __call__
|
||||||
|
return cls
|
||||||
9
vllm/compilation/levels.py
Normal file
9
vllm/compilation/levels.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
# constants for the levels of the compilation process
|
||||||
|
|
||||||
|
|
||||||
|
class CompilationLevel:
|
||||||
|
NO_COMPILATION = 0
|
||||||
|
DYNAMO_AS_IS = 1
|
||||||
|
DYNAMO_ONCE = 2
|
||||||
|
INDUCTOR = 3
|
||||||
|
INDUCTOR_MAX_AUTOTUNE = 4
|
||||||
@ -3,12 +3,14 @@ import sys
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from types import CodeType
|
from types import CodeType
|
||||||
from typing import Callable, List
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
|
|
||||||
|
from .levels import CompilationLevel
|
||||||
|
|
||||||
|
|
||||||
class TorchCompileWrapperWithCustomDispatcher:
|
class TorchCompileWrapperWithCustomDispatcher:
|
||||||
"""
|
"""
|
||||||
@ -23,7 +25,26 @@ class TorchCompileWrapperWithCustomDispatcher:
|
|||||||
`torch.compile` over the forward method.
|
`torch.compile` over the forward method.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, compiled_callable: Callable):
|
def __init__(self, compiled_callable: Optional[Callable] = None):
|
||||||
|
|
||||||
|
if compiled_callable is None:
|
||||||
|
# default compilation settings
|
||||||
|
# compiling the forward method
|
||||||
|
|
||||||
|
# choose the compile backend
|
||||||
|
|
||||||
|
# if the user has set the backend, use it
|
||||||
|
from vllm.plugins import get_torch_compile_backend
|
||||||
|
backend = get_torch_compile_backend()
|
||||||
|
if backend is None:
|
||||||
|
from vllm.compilation.backends import select_default_backend
|
||||||
|
backend = select_default_backend(envs.VLLM_TORCH_COMPILE_LEVEL)
|
||||||
|
|
||||||
|
compiled_callable = torch.compile(
|
||||||
|
self.forward,
|
||||||
|
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||||
|
backend=backend)
|
||||||
|
|
||||||
self.compiled_callable = compiled_callable
|
self.compiled_callable = compiled_callable
|
||||||
self.original_code_object = self.__class__.forward.__code__
|
self.original_code_object = self.__class__.forward.__code__
|
||||||
self.compiled_codes: List[CodeType] = []
|
self.compiled_codes: List[CodeType] = []
|
||||||
@ -33,7 +54,7 @@ class TorchCompileWrapperWithCustomDispatcher:
|
|||||||
# subclasses can use this to switch between the custom dispatcher
|
# subclasses can use this to switch between the custom dispatcher
|
||||||
# and the default Dynamo guard mechanism.
|
# and the default Dynamo guard mechanism.
|
||||||
self.use_custom_dispatcher: bool = \
|
self.use_custom_dispatcher: bool = \
|
||||||
envs.VLLM_DYNAMO_USE_CUSTOM_DISPATCHER
|
envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.DYNAMO_ONCE
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
"""Implement the dispatch logic here, beyond the torch.compile level.
|
"""Implement the dispatch logic here, beyond the torch.compile level.
|
||||||
|
|||||||
16
vllm/envs.py
16
vllm/envs.py
@ -65,6 +65,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
|
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
|
||||||
VLLM_SKIP_P2P_CHECK: bool = False
|
VLLM_SKIP_P2P_CHECK: bool = False
|
||||||
VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1: bool = False
|
VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1: bool = False
|
||||||
|
VLLM_TORCH_COMPILE_LEVEL: int = 0
|
||||||
|
|
||||||
|
|
||||||
def get_default_cache_root():
|
def get_default_cache_root():
|
||||||
@ -198,23 +199,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
|
|||||||
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
|
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
|
||||||
("true", "1")),
|
("true", "1")),
|
||||||
|
|
||||||
# Internal flag to enable Dynamo graph capture
|
|
||||||
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE":
|
|
||||||
lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")),
|
|
||||||
"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER":
|
|
||||||
lambda:
|
|
||||||
(os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in
|
|
||||||
("true", "1")),
|
|
||||||
|
|
||||||
# Internal flag to control whether we use custom op,
|
|
||||||
# or use the native pytorch implementation
|
|
||||||
"VLLM_TEST_COMPILE_NO_CUSTOM_OPS":
|
|
||||||
lambda: int(os.environ.get("VLLM_TEST_COMPILE_NO_CUSTOM_OPS", "0")),
|
|
||||||
|
|
||||||
# Internal flag to enable Dynamo fullgraph capture
|
# Internal flag to enable Dynamo fullgraph capture
|
||||||
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE":
|
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE":
|
||||||
lambda: bool(
|
lambda: bool(
|
||||||
os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
|
os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
|
||||||
|
"VLLM_TORCH_COMPILE_LEVEL":
|
||||||
|
lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")),
|
||||||
|
|
||||||
# local rank of the process in the distributed setting, used to determine
|
# local rank of the process in the distributed setting, used to determine
|
||||||
# the GPU device id
|
# the GPU device id
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
|
from vllm.compilation.levels import CompilationLevel
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import is_cpu, is_hip, is_xpu
|
from vllm.utils import is_cpu, is_hip, is_xpu
|
||||||
|
|
||||||
@ -55,7 +56,7 @@ class CustomOp(nn.Module):
|
|||||||
# NOTE(woosuk): Here we assume that vLLM was built for only one
|
# NOTE(woosuk): Here we assume that vLLM was built for only one
|
||||||
# specific backend. Currently, we do not support dynamic dispatching.
|
# specific backend. Currently, we do not support dynamic dispatching.
|
||||||
|
|
||||||
if envs.VLLM_TEST_COMPILE_NO_CUSTOM_OPS:
|
if envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.INDUCTOR:
|
||||||
return self.forward_native
|
return self.forward_native
|
||||||
|
|
||||||
if is_hip():
|
if is_hip():
|
||||||
|
|||||||
@ -21,6 +21,7 @@ from torch import nn
|
|||||||
from transformers import Gemma2Config
|
from transformers import Gemma2Config
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.compilation.decorators import support_compile_llama_style
|
||||||
from vllm.config import CacheConfig, LoRAConfig
|
from vllm.config import CacheConfig, LoRAConfig
|
||||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -238,6 +239,7 @@ class Gemma2DecoderLayer(nn.Module):
|
|||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
|
@support_compile_llama_style
|
||||||
class Gemma2Model(nn.Module):
|
class Gemma2Model(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@ -28,6 +28,7 @@ from torch import nn
|
|||||||
from transformers import LlamaConfig
|
from transformers import LlamaConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.compilation.decorators import support_compile_llama_style
|
||||||
from vllm.config import CacheConfig, LoRAConfig
|
from vllm.config import CacheConfig, LoRAConfig
|
||||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_world_size)
|
||||||
@ -265,6 +266,7 @@ class LlamaDecoderLayer(nn.Module):
|
|||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
|
@support_compile_llama_style
|
||||||
class LlamaModel(nn.Module):
|
class LlamaModel(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@ -365,6 +365,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
input_ids = None
|
input_ids = None
|
||||||
inputs_embeds = None
|
inputs_embeds = None
|
||||||
else:
|
else:
|
||||||
|
# always pass the input via `inputs_embeds`
|
||||||
|
# to make sure the computation graph is consistent
|
||||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||||
|
|
||||||
if image_input is not None:
|
if image_input is not None:
|
||||||
@ -375,10 +377,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
inputs_embeds = merge_multimodal_embeddings(
|
inputs_embeds = merge_multimodal_embeddings(
|
||||||
input_ids, inputs_embeds, vision_embeddings,
|
input_ids, inputs_embeds, vision_embeddings,
|
||||||
self.config.image_token_index)
|
self.config.image_token_index)
|
||||||
|
|
||||||
input_ids = None
|
|
||||||
else:
|
else:
|
||||||
inputs_embeds = None
|
inputs_embeds = self.language_model.model.get_input_embeddings(
|
||||||
|
input_ids)
|
||||||
|
input_ids = None
|
||||||
|
|
||||||
hidden_states = self.language_model.model(input_ids,
|
hidden_states = self.language_model.model(input_ids,
|
||||||
positions,
|
positions,
|
||||||
|
|||||||
@ -1,7 +1,21 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.compilation.levels import CompilationLevel
|
||||||
|
from vllm.plugins import set_torch_compile_backend
|
||||||
|
|
||||||
from .interface import Platform, PlatformEnum
|
from .interface import Platform, PlatformEnum
|
||||||
|
|
||||||
|
if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ:
|
||||||
|
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.DYNAMO_ONCE)
|
||||||
|
|
||||||
|
assert envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR,\
|
||||||
|
"TPU does not support Inductor."
|
||||||
|
|
||||||
|
set_torch_compile_backend("openxla")
|
||||||
|
|
||||||
|
|
||||||
class TpuPlatform(Platform):
|
class TpuPlatform(Platform):
|
||||||
_enum = PlatformEnum.TPU
|
_enum = PlatformEnum.TPU
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Callable, Optional, Union
|
from typing import Callable, Dict, Optional, Union
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
|
|
||||||
@ -42,3 +42,15 @@ def set_torch_compile_backend(backend: Union[Callable, str]):
|
|||||||
|
|
||||||
def get_torch_compile_backend() -> Optional[Union[Callable, str]]:
|
def get_torch_compile_backend() -> Optional[Union[Callable, str]]:
|
||||||
return _torch_compile_backend
|
return _torch_compile_backend
|
||||||
|
|
||||||
|
|
||||||
|
_inductor_additional_configs: Dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
def set_inductor_additional_configs(configs: Dict):
|
||||||
|
global _inductor_additional_configs
|
||||||
|
_inductor_additional_configs = configs
|
||||||
|
|
||||||
|
|
||||||
|
def get_inductor_additional_configs() -> Dict:
|
||||||
|
return _inductor_additional_configs
|
||||||
|
|||||||
@ -1137,10 +1137,9 @@ class EmbeddingSequenceGroupOutput(
|
|||||||
return self.embeddings == other.embeddings
|
return self.embeddings == other.embeddings
|
||||||
|
|
||||||
|
|
||||||
class IntermediateTensors(
|
# cannot use msgspec.Struct here because Dynamo does not support it
|
||||||
msgspec.Struct,
|
@dataclass
|
||||||
omit_defaults=True, # type: ignore[call-arg]
|
class IntermediateTensors:
|
||||||
array_like=True): # type: ignore[call-arg]
|
|
||||||
"""For all pipeline stages except the last, we need to return the hidden
|
"""For all pipeline stages except the last, we need to return the hidden
|
||||||
states and residuals to be sent to the next stage. This data structure
|
states and residuals to be sent to the next stage. This data structure
|
||||||
contains the hidden states and residuals for a request.
|
contains the hidden states and residuals for a request.
|
||||||
|
|||||||
@ -18,6 +18,8 @@ import vllm.envs as envs
|
|||||||
from vllm.attention import AttentionMetadata, get_attn_backend
|
from vllm.attention import AttentionMetadata, get_attn_backend
|
||||||
from vllm.attention.backends.abstract import AttentionState
|
from vllm.attention.backends.abstract import AttentionState
|
||||||
from vllm.attention.backends.utils import CommonAttentionState
|
from vllm.attention.backends.utils import CommonAttentionState
|
||||||
|
from vllm.compilation.compile_context import set_compile_context
|
||||||
|
from vllm.compilation.levels import CompilationLevel
|
||||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||||
ModelConfig, ObservabilityConfig, ParallelConfig,
|
ModelConfig, ObservabilityConfig, ParallelConfig,
|
||||||
PromptAdapterConfig, SchedulerConfig)
|
PromptAdapterConfig, SchedulerConfig)
|
||||||
@ -1126,10 +1128,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
|||||||
"provided. Defaulting to scaling factors of 1.0. "
|
"provided. Defaulting to scaling factors of 1.0. "
|
||||||
"This may lead to less accurate results!")
|
"This may lead to less accurate results!")
|
||||||
|
|
||||||
if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
|
if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.DYNAMO_AS_IS \
|
||||||
from vllm.compilation.backends import vllm_backend
|
and supports_dynamo():
|
||||||
from vllm.plugins import get_torch_compile_backend
|
from vllm.plugins import get_torch_compile_backend
|
||||||
backend = get_torch_compile_backend() or vllm_backend
|
backend = get_torch_compile_backend() or "eager"
|
||||||
self.model = torch.compile(
|
self.model = torch.compile(
|
||||||
self.model,
|
self.model,
|
||||||
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||||
@ -1289,6 +1291,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
|||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
dtype=self.model_config.dtype,
|
dtype=self.model_config.dtype,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
|
|
||||||
|
graph_batch_size = self.max_batchsize_to_capture
|
||||||
|
batch_size_capture_list = [
|
||||||
|
bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
|
||||||
|
]
|
||||||
|
if self.model_config.enforce_eager:
|
||||||
|
batch_size_capture_list = []
|
||||||
|
with set_compile_context(batch_size_capture_list):
|
||||||
self.execute_model(model_input, kv_caches, intermediate_tensors)
|
self.execute_model(model_input, kv_caches, intermediate_tensors)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
return
|
return
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user