[2/N][torch.compile] make compilation cfg part of vllm cfg (#10383)
Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 661a34fd4f
commit 4fd9375028
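In short, this change moves CompilationConfig and CompilationLevel out of vllm.compilation.config / vllm.compilation.levels into vllm.config, attaches a compilation_config field to VllmConfig, and threads the config through a new set_current_vllm_config context manager instead of the VLLM_CUSTOM_OPS environment variable. A minimal sketch of the resulting usage pattern, adapted from the test updates below (the model class name is a placeholder):

    from vllm.config import CompilationConfig, VllmConfig
    from vllm.plugins import set_current_vllm_config

    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        custom_ops=["none"]))
    with set_current_vllm_config(vllm_config):
        # Model construction happens inside the context, so custom ops and
        # the compile backend can read the config via
        # vllm.plugins.get_current_vllm_config().
        model = MyModel(vllm_config=vllm_config, prefix="")  # placeholder model class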
@@ -11,8 +11,8 @@ from torch.library import Library
 from vllm.compilation.compile_context import set_compile_context
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
-from vllm.compilation.levels import CompilationLevel
-from vllm.config import VllmConfig
+from vllm.config import CompilationLevel, VllmConfig
+from vllm.plugins import set_current_vllm_config
 from vllm.utils import direct_register_custom_op
 
 global_counter = 0
@@ -82,7 +82,9 @@ def test_simple_piecewise_compile():
     os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
     os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
 
-    model = SillyModel(vllm_config=VllmConfig(), prefix='')
+    vllm_config = VllmConfig()
+    with set_current_vllm_config(vllm_config):
+        model = SillyModel(vllm_config=vllm_config, prefix='')
 
     inputs = torch.randn(100).cuda()
 
@@ -15,12 +15,10 @@ from torch import nn
 from torch.library import Library
 
 from vllm.compilation.compile_context import set_compile_context
-from vllm.compilation.config import CompilationConfig
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
-from vllm.compilation.levels import CompilationLevel
-from vllm.config import VllmConfig
-from vllm.plugins import set_compilation_config
+from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
+from vllm.plugins import set_compilation_config, set_current_vllm_config
 from vllm.utils import direct_register_custom_op
 
 # create a library to hold the custom op
@@ -272,9 +270,11 @@ def run_model(llama_config,
                                   CompilationLevel.NO_COMPILATION)
         set_compilation_config(None)
 
-    model = LlamaModel(config=llama_config,
-                       vllm_config=VllmConfig(),
-                       prefix="").eval().cuda()
+    vllm_config = VllmConfig()
+    with set_current_vllm_config(vllm_config):
+        model = LlamaModel(config=llama_config,
+                           vllm_config=vllm_config,
+                           prefix="").eval().cuda()
 
     B = 16  # max batch size
     input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
@@ -395,9 +395,11 @@ def benchmark():
         else:
             set_compilation_config(None)
 
-        model = LlamaModel(config=llama_config,
-                           vllm_config=VllmConfig(),
-                           prefix="").eval().cuda().to(torch.bfloat16)
+        vllm_config = VllmConfig()
+        with set_current_vllm_config(vllm_config):
+            model = LlamaModel(config=llama_config,
+                               vllm_config=vllm_config,
+                               prefix="").eval().cuda().to(torch.bfloat16)
 
         B = 256  # max batch size
         input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
@@ -3,7 +3,7 @@ from typing import Dict, List, Optional
 
 import pytest
 
-from vllm.compilation.levels import CompilationLevel
+from vllm.config import CompilationLevel
 from vllm.utils import cuda_device_count_stateless
 
 from ..utils import compare_all_settings
@@ -1,6 +1,6 @@
 import pytest
 
-from vllm.compilation.levels import CompilationLevel
+from vllm.config import CompilationLevel
 
 from ..utils import fork_new_process_for_each_test
 from .utils import TEST_MODELS, check_full_graph_support
@@ -3,10 +3,10 @@ import torch
 from compressed_tensors.quantization import FP8_DTYPE
 
 import vllm.envs as envs
-from vllm.compilation.config import CompilationConfig
 from vllm.compilation.fusion import (FusionPass, find_auto_fn,
                                      find_auto_fn_maybe)
 from vllm.compilation.reshapes import RedundantReshapesPass
+from vllm.config import CompilationConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_fp8_linear)
@@ -3,6 +3,7 @@ from typing import Optional
 import torch
 
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
+from vllm.config import CompilationLevel
 
 
 class MyMod(torch.nn.Module):
@@ -18,7 +19,8 @@ class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
     def __init__(self, model):
         self.model = model
         compiled_callable = torch.compile(self.forward, backend="eager")
-        super().__init__(compiled_callable)
+        super().__init__(compiled_callable,
+                         compilation_level=CompilationLevel.DYNAMO_ONCE)
 
     def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
         # this is the function to be compiled
@@ -4,7 +4,7 @@ import torch
 
 from tests.quantization.utils import is_quant_method_supported
 from vllm import LLM, SamplingParams
-from vllm.compilation.levels import CompilationLevel
+from vllm.config import CompilationLevel
 from vllm.platforms import current_platform
 
 TEST_MODELS = [
@@ -3,11 +3,13 @@ from typing import List
 
 import pytest
 
+from vllm.config import CompilationConfig, VllmConfig
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.activation import (GeluAndMul,
                                                    ReLUSquaredActivation,
                                                    SiluAndMul)
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.plugins import set_current_vllm_config
 
 
 # Registered subclass for test
@@ -51,42 +53,40 @@ class Relu3(ReLUSquaredActivation):
 ])
 def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int],
                      default_on: bool):
-    os.environ["VLLM_CUSTOM_OPS"] = env
     os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level)
+    vllm_config = VllmConfig(compilation_config=CompilationConfig(
+        custom_ops=env.split(",")))
+    with set_current_vllm_config(vllm_config):
+        assert CustomOp.default_on() == default_on
 
-    # Reset default_on (computed once):
-    CustomOp.default_on.cache_clear()
+        ops_enabled = [bool(x) for x in ops_enabled]
 
-    assert CustomOp.default_on() == default_on
+        assert RMSNorm(1024).enabled() == ops_enabled[0]
+        assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]
 
-    ops_enabled = [bool(x) for x in ops_enabled]
+        assert SiluAndMul().enabled() == ops_enabled[1]
+        assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]
 
-    assert RMSNorm(1024).enabled() == ops_enabled[0]
-    assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]
+        assert GeluAndMul().enabled() == ops_enabled[2]
+        assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
 
-    assert SiluAndMul().enabled() == ops_enabled[1]
-    assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]
+        # If registered, subclasses should follow their own name
+        assert Relu3().enabled() == ops_enabled[3]
+        assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]
 
-    assert GeluAndMul().enabled() == ops_enabled[2]
-    assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
+        # Unregistered subclass
+        class SiluAndMul2(SiluAndMul):
+            pass
 
-    # If registered, subclasses should follow their own name
-    assert Relu3().enabled() == ops_enabled[3]
-    assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]
-
-    # Unregistered subclass
-    class SiluAndMul2(SiluAndMul):
-        pass
-
-    # Subclasses should not require registration
-    assert SiluAndMul2().enabled() == SiluAndMul().enabled()
+        # Subclasses should not require registration
+        assert SiluAndMul2().enabled() == SiluAndMul().enabled()
 
 
 @pytest.mark.parametrize(
     "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"])
 def test_enabled_ops_invalid(env: str):
-    os.environ["VLLM_CUSTOM_OPS"] = env
-    CustomOp.default_on.cache_clear()
-
-    with pytest.raises(AssertionError):
-        RMSNorm(1024).enabled()
+    with pytest.raises(Exception):  # noqa
+        vllm_config = VllmConfig(compilation_config=CompilationConfig(
+            custom_ops=env.split(",")))
+        with set_current_vllm_config(vllm_config):
+            RMSNorm(1024).enabled()
@@ -5,7 +5,7 @@ import tempfile
 
 import depyf
 
-from vllm.compilation.levels import CompilationLevel
+from vllm.config import CompilationLevel
 
 # disable custom dispatcher, let Dynamo takes over
 # all the control
@@ -1,6 +1,6 @@
 import os
 
-from vllm.compilation.levels import CompilationLevel
+from vllm.config import CompilationLevel
 
 from ..utils import compare_two_settings
 
@@ -10,13 +10,12 @@ import torch
 import torch.fx as fx
 
 import vllm.envs as envs
+from vllm.config import CompilationConfig, CompilationLevel
 from vllm.logger import init_logger
 from vllm.utils import combine_fx_passes, weak_ref_tensors
 
-from .config import CompilationConfig
 from .counter import compilation_counter
 from .fusion import FusionPass
-from .levels import CompilationLevel
 from .reshapes import RedundantReshapesPass
 
 logger = init_logger(__name__)
@@ -392,7 +391,10 @@ class VllmBackend:
     sym_tensor_indices: List[int]
     input_buffers: List[torch.Tensor]
 
-    def __init__(self, post_grad_passes: Sequence[Callable] = ()):
+    def __init__(
+        self,
+        compilation_configs: CompilationConfig,
+    ):
         global global_graph_pool
         if global_graph_pool is None:
             global_graph_pool = torch.cuda.graph_pool_handle()
@@ -401,11 +403,13 @@ class VllmBackend:
         # streams, it might not be safe to share a global pool.
         # only investigate this when we use multiple streams
         self.graph_pool = global_graph_pool
-        self.post_grad_passes = post_grad_passes
+        self.post_grad_passes = []
 
         self.sym_tensor_indices = []
         self.input_buffers = []
 
+        self.compilation_configs = compilation_configs
+
         # `torch.compile` is JIT compiled, so we don't need to
         # do anything here
 
@@ -437,10 +441,10 @@ class VllmBackend:
         assert not self._called, "VllmBackend can only be called once"
 
         self.graph = graph
-        # config is read now, because only here can
+        # config is updated now, because only here can
         # we get the sizes to capture for cudagraph
         # from compilation context
-        self.compilation_configs = CompilationConfig.select_and_init_config()
+        self.compilation_configs.init_during_runtime()
         self.add_passes_to_config()
 
         self.split_gm, self.piecewise_graphs = split_graph(
@@ -688,4 +692,6 @@ def select_default_backend(level: int) -> Union[str, Callable]:
         return backend_str
     assert level == CompilationLevel.PIECEWISE
 
-    return VllmBackend()
+    from vllm.plugins import get_current_vllm_config
+    compilation_config = get_current_vllm_config().compilation_config
+    return VllmBackend(compilation_config)
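Note on this backend change: VllmBackend no longer builds its own config via CompilationConfig.select_and_init_config(); the config is selected once when VllmConfig is assembled and handed to the backend, which only finalizes capture/compile sizes on its first call. A minimal sketch of what select_default_backend now does for the PIECEWISE level:

    from vllm.plugins import get_current_vllm_config

    compilation_config = get_current_vllm_config().compilation_config
    backend = VllmBackend(compilation_config)
    # compilation_config.init_during_runtime() still runs inside
    # backend.__call__, once the compile context (capture sizes) is known.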
@@ -1,159 +0,0 @@
-import copy
-from pathlib import Path
-from typing import Any, Dict, List, Optional
-
-from pydantic import BaseModel, Field, PrivateAttr
-
-import vllm.envs as envs
-from vllm.logger import init_logger
-
-from .compile_context import get_compile_context
-
-logger = init_logger(__name__)
-
-
-class CompilationConfig(BaseModel):
-    """
-    Configuration for compilation.
-    It has two parts:
-    - CudaGraph capture:
-        - use_cudagraph: whether to use cudagraph inside compilation.
-            - False: cudagraph inside compilation is not used.
-            - True: cudagraph inside compilation is used. It requires
-                that all input buffers have fixed addresses.
-                Note that this is orthogonal to the cudagraph capture out
-                side of compilation.
-            TODO: move outside cudagraph logic into compilation.
-            torch.compile will handle cudagraph capture logic in the future.
-        - cudagraph_capture_sizes: sizes to capture cudagraph.
-            - None: capture sizes are inferred from compilation context.
-            - List[int]: capture sizes are specified.
-        - cudagraph_num_of_warmups: number of warmup runs for cudagraph.
-            It means the first several runs will be treated as warmup runs.
-            Only after that, the execution will be recorded, and the recorded
-            cudagraph will be used for subsequent runs.
-        - cudagraph_copy_inputs: whether to copy input tensors for
-            cudagraph. If the caller can guarantee that the same input buffers
-            are always used, it can set this to False. Otherwise, it should
-            set this to True, and the compiler will copy the input to an
-            internally managed buffer. Default is False.
-    - Inductor compilation:
-        - use_inductor: whether to use inductor compilation.
-            - False: inductor compilation is not used. graph runs in eager.
-            - True: inductor compilation is used. one graph for symbolic shape
-                is compiled. In addition, compile for different sizes specified
-                in inductor_compile_sizes, using configurations
-                in inductor_compile_config.
-        - inductor_compile_sizes: sizes to compile for inductor.
-        - inductor_specialize_for_cudagraph_no_more_than: an optional integer
-            to specialize inductor for cudagraph sizes no more than the
-            specified size. It is useful when we want to specialize inductor
-            with a subset of cudagraph sizes.
-        - inductor_compile_config: additional configurations for inductor.
-            - None: use default configurations.
-        - inductor_passes: additional passes for inductor. It is a dictionary
-            from pass name to pass function qualified name. We use function
-            name because the config uses json format. If we pass the config
-            from Python, functions can also be passed directly via Python object
-            constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`
-    - Custom inductor passes:
-        - dump_graph_stages: list of stages for which we want to dump the graph.
-            Each pass defines its own stages (before, after, maybe in-between).
-        - dump_graph_dir: directory to dump the graph. Default is .
-        - enable_fusion: whether to enable the custom fusion pass.
-            TODO better pass enabling system.
-
-    Why we have different sizes for cudagraph and inductor:
-    - cudagraph: a cudagraph captured for a specific size can only be used
-        for the same size. We need to capture all the sizes we want to use.
-    - inductor: a graph compiled by inductor for a general shape can be used
-        for different sizes. Inductor can also compile for specific sizes,
-        where it can have more information to optimize the graph with fully
-        static shapes. However, we find the general shape compilation is
-        sufficient for most cases. It might be beneficial to compile for
-        certain small batchsizes, where inductor is good at optimizing.
-    """
-    use_inductor: bool = True
-    inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None
-    inductor_compile_sizes: Optional[List[int]] = Field(default_factory=dict)
-    inductor_compile_config: Dict = Field(default_factory=dict)
-    inductor_passes: Dict[str, str] = Field(default_factory=dict)
-
-    use_cudagraph: bool = False
-    non_cudagraph_ops: List[str] = Field(default_factory=list)
-    cudagraph_num_of_warmups: int = 0
-    cudagraph_capture_sizes: Optional[List[int]] = None
-    cudagraph_copy_inputs: bool = False
-
-    dump_graph_stages: List[str] = Field(default_factory=list)
-    dump_graph_dir: Path = Field(default=Path("."))
-    enable_fusion: bool = True
-
-    # not configurable, computed after init
-    compile_sizes: List[int] = PrivateAttr
-    capture_sizes: List[int] = PrivateAttr
-
-    def model_post_init(self, __context: Any) -> None:
-        for k, v in self.inductor_passes.items():
-            if not isinstance(v, str):
-                assert callable(v), (
-                    f"pass {k} should be a function or a qualified name")
-                self.inductor_compile_config[k] = v
-                continue
-
-            # resolve function from qualified name
-            names = v.split(".")
-            module = ".".join(names[:-1])
-            func_name = names[-1]
-            func = __import__(module).__dict__[func_name]
-            self.inductor_compile_config[k] = func
-
-    def init_during_runtime(self):
-        """To complete the initialization of config,
-        we need to know the compile context, which is only available
-        during the first run of the model.
-        """
-        context = get_compile_context()
-        context = copy.deepcopy(context) if context is not None else []
-        sizes_to_specialize: List[int] = context
-        if self.cudagraph_capture_sizes is None:
-            self.capture_sizes = sizes_to_specialize
-        else:
-            self.capture_sizes = self.cudagraph_capture_sizes
-            logger.info(("cudagraph sizes specified by model runner"
-                         " %s is overridden by config %s"),
-                        sizes_to_specialize, self.cudagraph_capture_sizes)
-        if self.inductor_specialize_for_cudagraph_no_more_than is not None:
-            assert self.inductor_compile_sizes is None, (
-                "inductor_compile_sizes should be None when "
-                "inductor_specialize_for_cudagraph_no_more_than is not None")
-            self.compile_sizes = [
-                x for x in self.capture_sizes
-                if x <= self.inductor_specialize_for_cudagraph_no_more_than
-            ]
-        else:
-            assert self.inductor_compile_sizes is not None, (
-                "inductor_compile_sizes should not be None when "
-                "inductor_specialize_for_cudagraph_no_more_than is None")
-            self.compile_sizes = self.inductor_compile_sizes
-
-    @staticmethod
-    def select_and_init_config() -> "CompilationConfig":
-        """The order of selecting config is:
-        1. Use the config specified in environment variable.
-        2. Use the config specified in plugins.
-        3. Use the default config.
-        """
-        config_path = envs.VLLM_TORCH_COMPILE_CONFIG
-        if config_path is not None:
-            with open(config_path) as json_file:
-                config = CompilationConfig.model_validate_json(
-                    json_file.read())
-        else:
-            from vllm.plugins import get_compilation_config
-            predefined_config = get_compilation_config()
-            config = predefined_config if predefined_config is not None else (
-                CompilationConfig())
-
-        config.init_during_runtime()
-        return config
@@ -3,10 +3,8 @@ from typing import Dict, List, Optional, Union
 
 import torch
 
-import vllm.envs as envs
-from vllm.compilation.levels import CompilationLevel
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
-from vllm.config import VllmConfig
+from vllm.config import CompilationLevel, VllmConfig
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
 from vllm.utils import supports_dynamo
@@ -126,12 +124,14 @@ def _support_torch_compile(cls: type,
         old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
         # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
         # will handle the compilation, so we don't need to do anything here.
-        self.do_not_compile = envs.VLLM_TORCH_COMPILE_LEVEL in [
+        self.do_not_compile = \
+            vllm_config.compilation_config.level in [
             CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
         ] or not supports_dynamo()
         if self.do_not_compile:
             return
-        TorchCompileWrapperWithCustomDispatcher.__init__(self)
+        TorchCompileWrapperWithCustomDispatcher.__init__(
+            self, compilation_level=vllm_config.compilation_config.level)
 
     cls.__init__ = __init__  # type: ignore
 
@@ -6,8 +6,8 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
 from torch._inductor.pattern_matcher import (Match, PatternMatcherPass,
                                              fwd_only, register_replacement)
 
-from vllm.compilation.config import CompilationConfig
 from vllm.compilation.inductor_pass import InductorPass
+from vllm.config import CompilationConfig
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
 
 import torch
 
-from vllm.compilation.config import CompilationConfig
+from vllm.config import CompilationConfig
 # yapf: disable
 from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
 from vllm.distributed import (
@@ -1,8 +0,0 @@
-# constants for the levels of the compilation process
-
-
-class CompilationLevel:
-    NO_COMPILATION = 0
-    DYNAMO_AS_IS = 1
-    DYNAMO_ONCE = 2
-    PIECEWISE = 3
@@ -8,8 +8,7 @@ from typing import Callable, List, Optional
 import torch
 
 import vllm.envs as envs
-
-from .levels import CompilationLevel
+from vllm.config import CompilationLevel
 
 
 class TorchCompileWrapperWithCustomDispatcher:
@@ -25,7 +24,9 @@ class TorchCompileWrapperWithCustomDispatcher:
     `torch.compile` over the forward method.
     """
 
-    def __init__(self, compiled_callable: Optional[Callable] = None):
+    def __init__(self,
+                 compiled_callable: Optional[Callable] = None,
+                 compilation_level: int = 0):
 
         if compiled_callable is None:
             # default compilation settings
@@ -38,7 +39,7 @@ class TorchCompileWrapperWithCustomDispatcher:
             backend = get_torch_compile_backend()
             if backend is None:
                 from vllm.compilation.backends import select_default_backend
-                backend = select_default_backend(envs.VLLM_TORCH_COMPILE_LEVEL)
+                backend = select_default_backend(compilation_level)
 
             compiled_callable = torch.compile(
                 self.forward,
@@ -54,7 +55,7 @@ class TorchCompileWrapperWithCustomDispatcher:
         # subclasses can use this to switch between the custom dispatcher
         # and the default Dynamo guard mechanism.
         self.use_custom_dispatcher: bool = \
-            envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.DYNAMO_ONCE
+            compilation_level >= CompilationLevel.DYNAMO_ONCE
 
     def __call__(self, *args, **kwargs):
         """Implement the dispatch logic here, beyond the torch.compile level.
vllm/config.py
@@ -3,10 +3,12 @@ import enum
 import json
 import warnings
 from dataclasses import dataclass, field, replace
+from pathlib import Path
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Dict, Final, List,
                     Literal, Mapping, Optional, Set, Tuple, Type, Union)
 
 import torch
+from pydantic import BaseModel, Field, PrivateAttr
 from transformers import PretrainedConfig
 
 import vllm.envs as envs
@@ -2052,6 +2054,185 @@ class ObservabilityConfig:
                 f"installed. Original error:\n{otel_import_error_traceback}")
 
 
+class CompilationLevel:
+    # constants for the levels of the compilation process
+    NO_COMPILATION = 0
+    DYNAMO_AS_IS = 1
+    DYNAMO_ONCE = 2
+    PIECEWISE = 3
+
+
+class CompilationConfig(BaseModel):
+    """
+    Configuration for compilation.
+    It has three parts:
+    - Top-level Compilation control:
+        - level: the level of compilation.
+            - 0: no compilation.
+            - 1: dynamo as is.
+            - 2: dynamo once.
+            - 3: piecewise compilation.
+        - custom_ops: fine-grained control over which custom ops to enable/disable.
+            Use 'all' to enable all, 'none' to disable all.
+            Also specify a list of custom op names to enable (prefixed with a '+'),
+            or disable (prefixed with a '-').
+            Examples:
+            - 'all,-op1' to enable all except op1
+            - 'none,+op1,+op2' to enable only op1 and op2
+            By default, all custom ops are enabled when running without Inductor
+            and disabled when running with Inductor (compile_level >= Inductor).
+    - CudaGraph capture:
+        - use_cudagraph: whether to use cudagraph inside compilation.
+            - False: cudagraph inside compilation is not used.
+            - True: cudagraph inside compilation is used. It requires
+                that all input buffers have fixed addresses.
+                Note that this is orthogonal to the cudagraph capture out
+                side of compilation.
+            TODO: move outside cudagraph logic into compilation.
+            torch.compile will handle cudagraph capture logic in the future.
+        - cudagraph_capture_sizes: sizes to capture cudagraph.
+            - None: capture sizes are inferred from compilation context.
+            - List[int]: capture sizes are specified.
+        - cudagraph_num_of_warmups: number of warmup runs for cudagraph.
+            It means the first several runs will be treated as warmup runs.
+            Only after that, the execution will be recorded, and the recorded
+            cudagraph will be used for subsequent runs.
+        - cudagraph_copy_inputs: whether to copy input tensors for
+            cudagraph. If the caller can guarantee that the same input buffers
+            are always used, it can set this to False. Otherwise, it should
+            set this to True, and the compiler will copy the input to an
+            internally managed buffer. Default is False.
+    - Inductor compilation:
+        - use_inductor: whether to use inductor compilation.
+            - False: inductor compilation is not used. graph runs in eager.
+            - True: inductor compilation is used. one graph for symbolic shape
+                is compiled. In addition, compile for different sizes specified
+                in inductor_compile_sizes, using configurations
+                in inductor_compile_config.
+        - inductor_compile_sizes: sizes to compile for inductor.
+        - inductor_specialize_for_cudagraph_no_more_than: an optional integer
+            to specialize inductor for cudagraph sizes no more than the
+            specified size. It is useful when we want to specialize inductor
+            with a subset of cudagraph sizes.
+        - inductor_compile_config: additional configurations for inductor.
+            - None: use default configurations.
+        - inductor_passes: additional passes for inductor. It is a dictionary
+            from pass name to pass function qualified name. We use function
+            name because the config uses json format. If we pass the config
+            from Python, functions can also be passed directly via Python object
+            constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`
+    - custom inductor passes:
+        - dump_graph_stages: list of stages for which we want to dump the graph.
+            Each pass defines its own stages (before, after, maybe in-between).
+        - dump_graph_dir: directory to dump the graph. Default is .
+        - enable_fusion: whether to enable the custom fusion pass.
+            TODO better pass enabling system.
+
+    Why we have different sizes for cudagraph and inductor:
+    - cudagraph: a cudagraph captured for a specific size can only be used
+        for the same size. We need to capture all the sizes we want to use.
+    - inductor: a graph compiled by inductor for a general shape can be used
+        for different sizes. Inductor can also compile for specific sizes,
+        where it can have more information to optimize the graph with fully
+        static shapes. However, we find the general shape compilation is
+        sufficient for most cases. It might be beneficial to compile for
+        certain small batchsizes, where inductor is good at optimizing.
+    """  # noqa
+    level: int = 0
+    custom_ops: List[str] = Field(default_factory=list)
+
+    use_inductor: bool = True
+    inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None
+    inductor_compile_sizes: Optional[List[int]] = Field(default_factory=dict)
+    inductor_compile_config: Dict = Field(default_factory=dict)
+    inductor_passes: Dict[str, str] = Field(default_factory=dict)
+
+    use_cudagraph: bool = False
+    non_cudagraph_ops: List[str] = Field(default_factory=list)
+    cudagraph_num_of_warmups: int = 0
+    cudagraph_capture_sizes: Optional[List[int]] = None
+    cudagraph_copy_inputs: bool = False
+
+    dump_graph_stages: List[str] = Field(default_factory=list)
+    dump_graph_dir: Path = Field(default=Path("."))
+    enable_fusion: bool = True
+
+    # not configurable, computed after init
+    compile_sizes: List[int] = PrivateAttr
+    capture_sizes: List[int] = PrivateAttr
+
+    def model_post_init(self, __context: Any) -> None:
+        self.level = envs.VLLM_TORCH_COMPILE_LEVEL
+
+        count_none = self.custom_ops.count("none")
+        count_all = self.custom_ops.count("all")
+        assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
+
+        for k, v in self.inductor_passes.items():
+            if not isinstance(v, str):
+                assert callable(v), (
+                    f"pass {k} should be a function or a qualified name")
+                self.inductor_compile_config[k] = v
+                continue
+
+            # resolve function from qualified name
+            names = v.split(".")
+            module = ".".join(names[:-1])
+            func_name = names[-1]
+            func = __import__(module).__dict__[func_name]
+            self.inductor_compile_config[k] = func
+
+    def init_during_runtime(self):
+        """To complete the initialization of config,
+        we need to know the compile context, which is only available
+        during the first run of the model.
+        """
+        from vllm.compilation.compile_context import get_compile_context
+        context = get_compile_context()
+        context = copy.deepcopy(context) if context is not None else []
+        sizes_to_specialize: List[int] = context
+        if self.cudagraph_capture_sizes is None:
+            self.capture_sizes = sizes_to_specialize
+        else:
+            self.capture_sizes = self.cudagraph_capture_sizes
+            logger.info(("cudagraph sizes specified by model runner"
+                         " %s is overridden by config %s"),
+                        sizes_to_specialize, self.cudagraph_capture_sizes)
+        if self.inductor_specialize_for_cudagraph_no_more_than is not None:
+            assert self.inductor_compile_sizes is None, (
+                "inductor_compile_sizes should be None when "
+                "inductor_specialize_for_cudagraph_no_more_than is not None")
+            self.compile_sizes = [
+                x for x in self.capture_sizes
+                if x <= self.inductor_specialize_for_cudagraph_no_more_than
+            ]
+        else:
+            assert self.inductor_compile_sizes is not None, (
+                "inductor_compile_sizes should not be None when "
+                "inductor_specialize_for_cudagraph_no_more_than is None")
+            self.compile_sizes = self.inductor_compile_sizes
+
+    @staticmethod
+    def select_and_init_config() -> "CompilationConfig":
+        """The order of selecting config is:
+        1. Use the config specified in environment variable.
+        2. Use the config specified in plugins.
+        3. Use the default config.
+        """
+        config_path = envs.VLLM_TORCH_COMPILE_CONFIG
+        if config_path is not None:
+            with open(config_path) as json_file:
+                config = CompilationConfig.model_validate_json(
+                    json_file.read())
+        else:
+            from vllm.plugins import get_compilation_config
+            predefined_config = get_compilation_config()
+            config = predefined_config if predefined_config is not None else (
+                CompilationConfig())
+
+        return config
+
+
 @dataclass
 class VllmConfig:
     """Dataclass which contains all vllm-related configuration. This
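For reference, select_and_init_config() (now defined in vllm/config.py, and no longer calling init_during_runtime() eagerly) picks the config in this order: the JSON file named by VLLM_TORCH_COMPILE_CONFIG, then a config registered through vllm.plugins.set_compilation_config, then the defaults. A hedged sketch of driving it from a file (the path is only an example):

    import json
    import os

    from vllm.config import CompilationConfig

    with open("/tmp/compile_config.json", "w") as f:
        json.dump({"use_cudagraph": True, "cudagraph_num_of_warmups": 1}, f)
    os.environ["VLLM_TORCH_COMPILE_CONFIG"] = "/tmp/compile_config.json"

    # Falls back to set_compilation_config(...) registrations, then defaults.
    config = CompilationConfig.select_and_init_config()
    assert config.use_cudagraph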
@@ -2073,6 +2254,8 @@ class VllmConfig:
     observability_config: Optional[ObservabilityConfig] = None
     prompt_adapter_config: Optional[PromptAdapterConfig] = None
     quant_config: Optional[QuantizationConfig] = None
+    compilation_config: CompilationConfig = field(default=None,
+                                                  init=True)  # type: ignore
 
     @staticmethod
     def _get_quantization_config(
@@ -2133,6 +2316,12 @@ class VllmConfig:
             self.quant_config = VllmConfig._get_quantization_config(
                 self.model_config, self.load_config)
 
+        if self.compilation_config is None:
+            self.compilation_config = CompilationConfig.select_and_init_config(
+            )
+
+        current_platform.check_and_update_config(self)
+
     def __str__(self):
         return ("model=%r, speculative_config=%r, tokenizer=%r, "
                 "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
vllm/envs.py
@@ -69,7 +69,6 @@ if TYPE_CHECKING:
     VLLM_SKIP_P2P_CHECK: bool = False
     VLLM_TORCH_COMPILE_LEVEL: int = 0
     VLLM_TORCH_COMPILE_CONFIG: Optional[str] = None
-    VLLM_CUSTOM_OPS: List[str] = []
     VLLM_DISABLED_KERNELS: List[str] = []
     VLLM_USE_V1: bool = False
     VLLM_ENABLE_V1_MULTIPROCESSING: bool = False
@@ -217,18 +216,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     "VLLM_TORCH_COMPILE_CONFIG":
     lambda: os.environ.get("VLLM_TORCH_COMPILE_CONFIG", None),
 
-    # Fine-grained control over which custom ops to enable/disable.
-    # Use 'all' to enable all, 'none' to disable all.
-    # Also specify a list of custom op names to enable (prefixed with a '+'),
-    # or disable (prefixed with a '-').
-    # Examples:
-    # - 'all,-op1' to enable all except op1
-    # - 'none,+op1,+op2' to enable only op1 and op2
-    # By default, all custom ops are enabled when running without Inductor
-    # and disabled when running with Inductor (compile_level >= Inductor).
-    "VLLM_CUSTOM_OPS":
-    lambda: os.environ.get("VLLM_CUSTOM_OPS", "").replace(" ", "").split(","),
-
     # local rank of the process in the distributed setting, used to determine
     # the GPU device id
     "LOCAL_RANK":
@@ -1,12 +1,10 @@
-from functools import lru_cache
 from typing import Dict, Type
 
 import torch.nn as nn
 
-import vllm.envs as envs
-from vllm.compilation.levels import CompilationLevel
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.plugins import get_current_vllm_config
 from vllm.utils import print_warning_once
 
 logger = init_logger(__name__)
@@ -87,6 +85,8 @@ class CustomOp(nn.Module):
     @classmethod
     def enabled(cls) -> bool:
         # if no name, then it was not registered
+        compilation_config = get_current_vllm_config().compilation_config
+        custom_ops = compilation_config.custom_ops
         if not hasattr(cls, "name"):
             print_warning_once(
                 f"Custom op {cls.__name__} was not registered, "
@@ -94,22 +94,25 @@ class CustomOp(nn.Module):
                 f"It will be enabled/disabled based on the global settings.")
             return CustomOp.default_on()
 
-        enabled = f"+{cls.name}" in envs.VLLM_CUSTOM_OPS
-        disabled = f"-{cls.name}" in envs.VLLM_CUSTOM_OPS
+        enabled = f"+{cls.name}" in custom_ops
+        disabled = f"-{cls.name}" in custom_ops
         assert not (enabled
                     and disabled), f"Cannot enable and disable {cls.name}"
 
         return (CustomOp.default_on() or enabled) and not disabled
 
-    # On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE
-    # Specifying 'all' or 'none' in VLLM_CUSTOM_OPS takes precedence.
     @staticmethod
-    @lru_cache
     def default_on() -> bool:
-        count_none = envs.VLLM_CUSTOM_OPS.count("none")
-        count_all = envs.VLLM_CUSTOM_OPS.count("all")
-        assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
-        return envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE and \
+        """
+        On by default if level < CompilationLevel.PIECEWISE
+        Specifying 'all' or 'none' in custom_op takes precedence.
+        """
+        from vllm.config import CompilationLevel
+        compilation_config = get_current_vllm_config().compilation_config
+        custom_ops = compilation_config.custom_ops
+        count_none = custom_ops.count("none")
+        count_all = custom_ops.count("all")
+        return compilation_config.level < CompilationLevel.PIECEWISE and \
             not count_none > 0 or count_all > 0
 
     # Dictionary of all custom ops (classes, indexed by registered name).
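The enabled()/default_on() logic above now resolves each op from the active config rather than VLLM_CUSTOM_OPS: ops default to on while compilation_config.level < CompilationLevel.PIECEWISE, 'all'/'none' override that default, and '+name'/'-name' override it per op. A hedged sketch mirroring the updated test file earlier in this diff:

    from vllm.config import CompilationConfig, VllmConfig
    from vllm.model_executor.layers.layernorm import RMSNorm
    from vllm.plugins import set_current_vllm_config

    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        custom_ops=["none", "+rms_norm"]))
    with set_current_vllm_config(vllm_config):
        # 'none' turns everything off by default; '+rms_norm' re-enables
        # just that op for the duration of the context.
        assert RMSNorm(1024).enabled()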
@@ -42,6 +42,7 @@ from vllm.model_executor.model_loader.weight_utils import (
     safetensors_weights_iterator)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
+from vllm.plugins import set_current_vllm_config
 from vllm.utils import is_pin_memory_available
 
 
@@ -97,7 +98,8 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
     all_params = [param.name for param in signatures.parameters.values()]
     if "vllm_config" in all_params and "prefix" in all_params:
         # new-style model class
-        return model_class(vllm_config=vllm_config, prefix=prefix)
+        with set_current_vllm_config(vllm_config):
+            return model_class(vllm_config=vllm_config, prefix=prefix)
     msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
            "input arguments. Possibly you have an old-style model class"
            " registered from out of tree and it is used for new vLLM version. "
@@ -121,7 +123,8 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
         kwargs["lora_config"] = vllm_config.lora_config
     if "scheduler_config" in all_params:
         kwargs["scheduler_config"] = vllm_config.scheduler_config
-    return model_class(**kwargs)
+    with set_current_vllm_config(vllm_config):
+        return model_class(**kwargs)
 
 
 class BaseModelLoader(ABC):
@@ -1,10 +1,15 @@
 import enum
 import random
-from typing import NamedTuple, Optional, Tuple, Union
+from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
 
 import numpy as np
 import torch
 
+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+else:
+    VllmConfig = None
+
 
 class PlatformEnum(enum.Enum):
     CUDA = enum.auto()
@@ -129,6 +134,19 @@ class Platform:
         np.random.seed(seed)
         torch.manual_seed(seed)
 
+    @classmethod
+    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
+        """
+        Check and update the configuration for the current platform.
+
+        It can raise an exception if the configuration is not compatible with
+        the current platform, or it can update the configuration to make it
+        compatible with the current platform.
+
+        The config is passed by reference, so it can be modified in place.
+        """
+        pass
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
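The new Platform.check_and_update_config hook lets each platform veto or adjust the assembled VllmConfig; it is called from VllmConfig.__post_init__ (see the vllm/config.py hunk above), and the TPU hunk below uses it to cap the compilation level. A hedged sketch of a platform subclass using it (the class is hypothetical):

    from vllm.platforms.interface import Platform, PlatformEnum

    class MyAcceleratorPlatform(Platform):  # hypothetical example
        _enum = PlatformEnum.UNSPECIFIED

        @classmethod
        def check_and_update_config(cls, vllm_config) -> None:
            from vllm.config import CompilationLevel
            compilation_config = vllm_config.compilation_config
            # e.g. fall back from piecewise/Inductor compilation if the
            # platform cannot support it.
            if compilation_config.level == CompilationLevel.PIECEWISE:
                compilation_config.level = CompilationLevel.DYNAMO_ONCE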
@@ -1,18 +1,16 @@
 import os
+from typing import TYPE_CHECKING
 
 import torch
 
-import vllm.envs as envs
-from vllm.compilation.levels import CompilationLevel
 from vllm.plugins import set_torch_compile_backend
 
 from .interface import Platform, PlatformEnum
 
-if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ:
-    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.DYNAMO_ONCE)
-
-assert envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE,\
-    "TPU does not support Inductor."
+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+else:
+    VllmConfig = None
 
 set_torch_compile_backend("openxla")
 
@@ -31,3 +29,12 @@ class TpuPlatform(Platform):
     @classmethod
     def inference_mode(cls):
         return torch.no_grad()
+
+    @classmethod
+    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
+        from vllm.config import CompilationLevel
+        compilation_config = vllm_config.compilation_config
+        if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ:
+            compilation_config.level = CompilationLevel.DYNAMO_ONCE
+        assert compilation_config.level < CompilationLevel.PIECEWISE,\
+            "TPU does not support Inductor."
@@ -1,11 +1,11 @@
 import logging
+from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable, Optional, Union
 
 import vllm.envs as envs
 
 if TYPE_CHECKING:
-    from vllm.compilation.config import CompilationConfig
-    from vllm.config import VllmConfig
+    from vllm.config import CompilationConfig, VllmConfig
 else:
     CompilationConfig = None
     VllmConfig = None
@@ -72,3 +72,29 @@ def set_compilation_config(config: Optional[CompilationConfig]):
 
 def get_compilation_config() -> Optional[CompilationConfig]:
     return _compilation_config
+
+
+_current_vllm_config: Optional[VllmConfig] = None
+
+
+@contextmanager
+def set_current_vllm_config(vllm_config: VllmConfig):
+    """
+    Temporarily set the current VLLM config.
+    Used during model initialization.
+    We save the current VLLM config in a global variable,
+    so that all modules can access it, e.g. custom ops
+    can access the VLLM config to determine how to dispatch.
+    """
+    global _current_vllm_config
+    old_vllm_config = _current_vllm_config
+    try:
+        _current_vllm_config = vllm_config
+        yield
+    finally:
+        _current_vllm_config = old_vllm_config
+
+
+def get_current_vllm_config() -> VllmConfig:
+    assert _current_vllm_config is not None, "Current VLLM config is not set."
+    return _current_vllm_config
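The context manager added here is what the model loader and the updated tests rely on: it stashes the config in a module-level global for the duration of model construction, and get_current_vllm_config() asserts that it was set. A hedged usage sketch:

    from vllm.config import VllmConfig
    from vllm.plugins import get_current_vllm_config, set_current_vllm_config

    vllm_config = VllmConfig()
    with set_current_vllm_config(vllm_config):
        # Anything constructed here (custom ops, the compile backend) can
        # look the config up without it being threaded through explicitly.
        assert get_current_vllm_config() is vllm_config
    # On exit the previous global value (or None) is restored.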
@@ -1,4 +1,3 @@
-import os
 import time
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
@@ -8,11 +7,8 @@ import torch
 import torch.distributed
 import torch.nn as nn
 
-from vllm import envs
 from vllm.compilation.compile_context import set_compile_context
-from vllm.compilation.config import CompilationConfig
-from vllm.compilation.levels import CompilationLevel
-from vllm.config import VllmConfig
+from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
@@ -99,7 +95,7 @@ class GPUModelRunner:
             pin_memory=self.pin_memory,
         )
 
-        self.use_cuda_graph = (envs.VLLM_TORCH_COMPILE_LEVEL
+        self.use_cuda_graph = (self.vllm_config.compilation_config.level
                                == CompilationLevel.PIECEWISE
                                and not self.model_config.enforce_eager)
         # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
@@ -517,9 +513,9 @@ class GPUModelRunner:
         # CUDA graphs do not work properly with the custom CUDA kernels.
         # FIXME(woosuk): Disable inductor to reduce the compilation time
         # and avoid any potential issues with the inductor.
-        os.environ["VLLM_CUSTOM_OPS"] = "none"
         set_compilation_config(
             CompilationConfig(
+                custom_ops=["none"],
                 use_cudagraph=True,
                 non_cudagraph_ops=["vllm.unified_v1_flash_attention"],
                 use_inductor=True,
@@ -19,8 +19,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.attention.backends.abstract import AttentionState
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.compilation.compile_context import set_compile_context
-from vllm.compilation.levels import CompilationLevel
-from vllm.config import VllmConfig
+from vllm.config import CompilationLevel, VllmConfig
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.distributed import get_pp_group
 from vllm.distributed.parallel_state import graph_capture
@@ -1142,8 +1141,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                     "provided. Defaulting to scaling factors of 1.0. "
                     "This may lead to less accurate results!")
 
-        if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.DYNAMO_AS_IS \
-           and supports_dynamo():
+        if self.vllm_config.compilation_config.level ==\
+            CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
             from vllm.plugins import get_torch_compile_backend
             backend = get_torch_compile_backend() or "eager"
             self.model = torch.compile(
@@ -140,7 +140,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         model = get_model(vllm_config=self.vllm_config)
         model = model.eval()
         xm.wait_device_ops()
-        self.model = ModelWrapper(model)
+        self.model = ModelWrapper(model, self.vllm_config)
 
     def _dummy_run(
         self,
@@ -669,13 +669,15 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
 
 class ModelWrapper(TorchCompileWrapperWithCustomDispatcher):
 
-    def __init__(self, model: nn.Module):
+    def __init__(self, model: nn.Module, vllm_config: VllmConfig):
         self.model = model
         compiled_callable = torch.compile(self.forward,
                                           backend="openxla",
                                           fullgraph=True,
                                           dynamic=False)
-        super().__init__(compiled_callable)
+        super().__init__(
+            compiled_callable,
+            compilation_level=vllm_config.compilation_config.level)
 
     def __call__(self, *args, is_prompt: bool, **kwargs):
         if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher: