diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 8c98aa36ac0f..ed847a7e3696 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -229,6 +229,9 @@ steps: - tests/compile commands: - pytest -v -s compile/test_basic_correctness.py + # these tests need to be separated, cannot combine + - pytest -v -s compile/piecewise/test_simple.py + - pytest -v -s compile/piecewise/test_toy_llama.py - label: "PyTorch Fullgraph Test" # 18min source_file_dependencies: diff --git a/tests/compile/piecewise/__init__.py b/tests/compile/piecewise/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/compile/piecewise/piecewise_compilation_config.json b/tests/compile/piecewise/piecewise_compilation_config.json new file mode 100644 index 000000000000..03d077b76f62 --- /dev/null +++ b/tests/compile/piecewise/piecewise_compilation_config.json @@ -0,0 +1,4 @@ +{ + "use_cudagraph": true, + "non_cudagraph_ops": ["silly.attention"] +} \ No newline at end of file diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py new file mode 100644 index 000000000000..a34d33efba1d --- /dev/null +++ b/tests/compile/piecewise/test_simple.py @@ -0,0 +1,96 @@ +""" +Test the piecewise compilation with a simple model so that we +can exactly calculate the expected output and side effects. +""" +import os + +import torch +from torch import nn + +from vllm.compilation.compile_context import set_compile_context +from vllm.compilation.counter import compilation_counter +from vllm.compilation.decorators import support_torch_compile +from vllm.compilation.levels import CompilationLevel + +os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE) + +global_counter = 0 + + +@torch.library.custom_op("silly::attention", mutates_args=["out"]) +def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + global global_counter + global_counter += 1 + print(f"{global_counter=}") + out.copy_(q) + out[0] += 1 + + +@silly_attention.register_fake +def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + return + + +@support_torch_compile +class SillyModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overall effect: + x += 1 + x[0] += 2 + global_counter += 2 + """ + x = x + 1 + x = x + 2 + out = torch.empty_like(x) + torch.ops.silly.attention(x, x, x, out) + x = out + x = x - 2 + x = x - 1 + out = torch.empty_like(x) + torch.ops.silly.attention(x, x, x, out) + x = out + x = x + 1 + return x + + +def test_simple_piecewise_compile(): + + model = SillyModel() + + directory = os.path.dirname(__file__) + config = os.path.join(directory, "piecewise_compilation_config.json") + os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config + + input_buffer = torch.randn(100).cuda() + + with compilation_counter.expect( + num_graphs_seen=1, # one graph for the model + num_piecewise_graphs_seen=5, # 2 * num_layers + 1 + num_piecewise_capturable_graphs_seen=3, # 1 + num_layers + num_inductor_compilations=3, # num_piecewise_capturable_graphs_seen + num_cudagraph_caputured= + 6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + + with set_compile_context([1, 2]): + model(input_buffer) + + model(input_buffer[:2]) + model(input_buffer[:1]) + + input_buffer[:2].zero_() + global global_counter + global_counter = 0 + output = model(input_buffer[:2]) + assert global_counter == 2 
+ assert torch.allclose(output.cpu(), torch.tensor([3., 1.])) + + # clean up to avoid side effects for other tests + del os.environ["VLLM_TORCH_COMPILE_CONFIG"] diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py new file mode 100644 index 000000000000..db6a983d70fe --- /dev/null +++ b/tests/compile/piecewise/test_toy_llama.py @@ -0,0 +1,334 @@ +""" +Test the piecewise compilation with a simple model, comparing the output +with and without the piecewise compilation. +""" +import os +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import nn + +from vllm.compilation.compile_context import set_compile_context +from vllm.compilation.config import CompilationConfig +from vllm.compilation.counter import compilation_counter +from vllm.compilation.decorators import support_torch_compile +from vllm.compilation.levels import CompilationLevel +from vllm.plugins import set_compilation_config + + +@torch.library.custom_op("silly::attention", mutates_args=["out"]) +def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + out.copy_(q) + out += k + out += v + + +@silly_attention.register_fake +def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + return + + +@dataclass +class LlamaConfig: + hidden_size: int = 128 + mlp_size: int = 256 + vocab_size: int = 128 + num_layers: int = 2 + + +class LlamaMLP(nn.Module): + + def __init__(self, config: LlamaConfig) -> None: + super().__init__() + self.gate_up_projection = nn.Linear( + in_features=config.hidden_size, + out_features=config.mlp_size * 2, + bias=False, + ) + self.down_projection = nn.Linear( + in_features=config.mlp_size, + out_features=config.hidden_size, + bias=False, + ) + + self.gate_up_projection.weight.data.fill_(0.0) + self.down_projection.weight.data.fill_(0.0) + + def forward(self, x): + x = self.gate_up_projection(x) + x = x[:, :x.size(1) // 2] * torch.nn.functional.relu( + x[:, x.size(1) // 2:]) + x = self.down_projection(x) + return x + + +class LlamaAttention(nn.Module): + + def __init__(self, config: LlamaConfig) -> None: + super().__init__() + self.qkv_projection = nn.Linear( + in_features=config.hidden_size, + out_features=config.hidden_size * 3, + ) + + self.output_projection = nn.Linear( + in_features=config.hidden_size, + out_features=config.hidden_size, + ) + + self.qkv_projection.weight.data.fill_(0.0) + self.output_projection.weight.data.fill_(0.0) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv = self.qkv_projection(hidden_states) + hidden_size = qkv.size(-1) // 3 + q, k, v = qkv.split([hidden_size, hidden_size, hidden_size], dim=-1) + + q = q + positions.unsqueeze(1) + k = k + positions.unsqueeze(1) + + attn_output = torch.empty_like(q) + torch.ops.silly.attention(q, k, v, attn_output) + + output = self.output_projection(attn_output) + return output + + +class LlamaDecoderLayer(nn.Module): + + def __init__(self, config: LlamaConfig) -> None: + super().__init__() + self.self_attention = LlamaAttention(config) + self.mlp = LlamaMLP(config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = hidden_states / 2 + else: + hidden_states = hidden_states + residual + residual = hidden_states + hidden_states = hidden_states / 2 + 
+ hidden_states = self.self_attention(positions=positions, + hidden_states=hidden_states) + + hidden_states = hidden_states + residual + residual = hidden_states + hidden_states = hidden_states / 2 + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +class LlamaModel(nn.Module): + + def __init__(self, config: LlamaConfig) -> None: + super().__init__() + self.embedding_tokens = nn.Embedding( + num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + ) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config) for _ in range(config.num_layers)]) + + self.embedding_tokens.weight.data.fill_(0.0) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + ) -> torch.Tensor: + hidden_states = self.embedding_tokens(input_ids) + residual = None + for layer in self.layers: + hidden_states, residual = layer(positions, hidden_states, residual) + return hidden_states + + +@torch.inference_mode +def run_model(llama_config, + use_compile: bool, + split_attn: bool = False) -> torch.Tensor: + + if use_compile: + os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str( + CompilationLevel.PIECEWISE) + + if split_attn: + set_compilation_config( + CompilationConfig( + use_cudagraph=True, + non_cudagraph_ops=["silly.attention"], + )) + else: + set_compilation_config(CompilationConfig(use_cudagraph=True, )) + else: + os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str( + CompilationLevel.NO_COMPILATION) + set_compilation_config(None) + + cls = LlamaModel + if use_compile: + cls = support_torch_compile(LlamaModel) + model = cls(llama_config).eval().cuda() + + B = 16 # max batch size + input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() + positions = torch.arange(B).cuda() + + with set_compile_context([1, 2]): + model(input_ids, positions) + model(input_ids[:2], positions[:2]) + model(input_ids[:1], positions[:1]) + + input_ids[:2].zero_() + output = model(input_ids[:2], positions[:2]) + + # manual cleanup + del os.environ["VLLM_TORCH_COMPILE_LEVEL"] + set_compilation_config(None) + + return output.cpu() + + +def test_toy_llama(): + # compare output with and without piecewise compilation + + llama_config = LlamaConfig(hidden_size=128, + mlp_size=256, + vocab_size=128, + num_layers=2) + + outputs = [] + with compilation_counter.expect( + num_graphs_seen=0, + num_piecewise_graphs_seen=0, + num_piecewise_capturable_graphs_seen=0, + num_inductor_compilations=0, + num_cudagraph_caputured=0, + ): + outputs.append(run_model(llama_config, use_compile=False)) + with compilation_counter.expect( + num_graphs_seen=1, # one graph for the model + num_piecewise_graphs_seen=1, + num_piecewise_capturable_graphs_seen=1, + num_inductor_compilations=1, # num_piecewise_capturable_graphs_seen + num_cudagraph_caputured= + 2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + outputs.append(run_model(llama_config, use_compile=True)) + + with compilation_counter.expect( + num_graphs_seen=1, # one graph for the model + num_piecewise_graphs_seen=2 * llama_config.num_layers + + 1, # 2 * num_layers + 1 + num_piecewise_capturable_graphs_seen=1 + + llama_config.num_layers, # 1 + num_layers + num_inductor_compilations=1 + + llama_config.num_layers, # num_piecewise_capturable_graphs_seen + num_cudagraph_caputured=2 * + (1 + llama_config.num_layers + ), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + outputs.append( + run_model(llama_config, use_compile=True, split_attn=True)) + + for i in range(1, len(outputs)): + assert 
torch.allclose(outputs[0], outputs[i]) + + +@torch.inference_mode +def benchmark(): + os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE) + from triton.testing import do_bench + cls = support_torch_compile(LlamaModel) + + # similar to llama 3.1-8B + llama_config = LlamaConfig(hidden_size=4096, + mlp_size=14336, + vocab_size=128 * 1024, + num_layers=32) + + # a tiny model to measure the overhead + # of piecewise cudagraph + llama_config = LlamaConfig(hidden_size=40, + mlp_size=80, + vocab_size=128, + num_layers=2) + + cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)] + + eager_time = {} + full_cudagraph_time = {} + piecewise_cudagraph_time = {} + + pool = torch.cuda.graph_pool_handle() + + for piecewise in [False, True]: + if piecewise: + set_compilation_config( + CompilationConfig( + use_cudagraph=True, + non_cudagraph_ops=["silly.attention"], + )) + else: + set_compilation_config(None) + + model = cls(llama_config).eval().cuda().to(torch.bfloat16) + + B = 256 # max batch size + input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() + positions = torch.arange(B).cuda().to(torch.bfloat16) + + graphs = {} + + with set_compile_context(cudagraph_sizes): + model(input_ids, positions) + for b in cudagraph_sizes[::-1]: + if not piecewise: + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, pool=pool): + output = model(input_ids[:b], positions[:b]) + graphs[b] = (graph, output) + else: + output = model(input_ids[:b], positions[:b]) + graphs[b] = (model, output) + for b in cudagraph_sizes: + if piecewise: + # noqa is for `Function definition does not bind loop variable` + # it will be problematic if we save the created lambda function + # and use it later, because it will look up the name `b` in the + # enclosing scope, and the value of `b` will always be 256. + # it is fine here, because we only use the lambda function once. 
+ runtime = do_bench(lambda: graphs[b][0] # noqa + (input_ids[:b], positions[:b])) # noqa + piecewise_cudagraph_time[b] = runtime + else: + runtime = do_bench(lambda: graphs[b][0].replay()) # noqa + eager_runtime = do_bench( + lambda: model(input_ids[:b], positions[:b])) # noqa + full_cudagraph_time[b] = runtime + eager_time[b] = eager_runtime + + # print in tabular format + print("batch size\teager mode\tfull cudagraph\tpiecewise cudagraph") + for b in cudagraph_sizes: + print((f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}" + f"\t{piecewise_cudagraph_time[b]:.3f}")) + + +if __name__ == "__main__": + benchmark() diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index f28f9145bb44..f00334934cb4 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -9,7 +9,7 @@ from .utils import TEST_MODELS, check_full_graph_support @pytest.mark.parametrize("model_info", TEST_MODELS) @pytest.mark.parametrize( "optimization_level", - [CompilationLevel.DYNAMO_ONCE, CompilationLevel.INDUCTOR]) + [CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE]) @fork_new_process_for_each_test def test_full_graph(model_info, optimization_level): model = model_info[0] diff --git a/tests/compile/utils.py b/tests/compile/utils.py index 64fc08e80de3..95cad19126df 100644 --- a/tests/compile/utils.py +++ b/tests/compile/utils.py @@ -9,17 +9,19 @@ from vllm.platforms import current_platform TEST_MODELS = [ ("facebook/opt-125m", {}), - ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", { - "dtype": torch.float16, - "quantization": "compressed-tensors" - }), + # TODO: add fake implementation for compressed-tensors + # ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", { + # "dtype": torch.float16, + # "quantization": "compressed-tensors" + # }), ("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", { "dtype": torch.float16, "quantization": "fp8" }), - ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", { - "quantization": "compressed-tensors" - }), + # TODO: add fake implementation for compressed-tensors + # ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", { + # "quantization": "compressed-tensors" + # }), ("meta-llama/Meta-Llama-3-8B", {}), ] @@ -73,7 +75,7 @@ def check_full_graph_support(model, # much memory. 
quantization = model_kwargs.get("quantization") if ((quantization == "fp8" or model == "meta-llama/Meta-Llama-3-8B") - and optimization_level >= CompilationLevel.INDUCTOR): + and optimization_level >= CompilationLevel.PIECEWISE): return prompts = [ diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 6d9832e2c39c..10cf49e19ecc 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -1,13 +1,16 @@ import copy +import dataclasses import operator -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import torch import torch.fx as fx from vllm.logger import init_logger +from vllm.utils import weak_ref_tensors -from .compile_context import get_compile_context +from .config import CompilationConfig +from .counter import compilation_counter from .levels import CompilationLevel logger = init_logger(__name__) @@ -157,113 +160,326 @@ def fix_functionalization(graph: fx.Graph): # print(graph.python_code(root_module="self", verbose=True).src, file=f) -def wrap_inductor(graph, example_inputs, additional_inductor_config): +def wrap_inductor(graph, + example_inputs, + additional_inductor_config, + do_logging=False, + runtime_shape: Optional[int] = None, + use_inductor: bool = True): + if not use_inductor: + return graph + + compilation_counter.num_inductor_compilations += 1 + + if do_logging: + if runtime_shape is None: + logger.info("Compiling a graph for general shape") + else: + logger.info("Compiling a graph for shape %s", runtime_shape) + from torch._inductor import config current_config = config.shallow_copy_dict() from torch._inductor.compile_fx import compile_fx if additional_inductor_config is not None: current_config.update(additional_inductor_config) - if current_config['post_grad_custom_post_pass'] is not None: - logger.warning( - "post_grad_custom_post_pass is already set in the config. " - "Overwriting it with the fix_functionalization") - current_config['post_grad_custom_post_pass'] = fix_functionalization + + # inductor can inplace modify the graph, so we need to copy it + # see https://github.com/pytorch/pytorch/issues/138980 + graph = copy.deepcopy(graph) return compile_fx(graph, example_inputs, config_patches=current_config) -def vllm_backend( +@dataclasses.dataclass +class SplitItem: + submod_name: str + is_splitting_graph: bool + graph: fx.GraphModule + + +def split_graph(graph: fx.GraphModule, + ops: List[str]) -> Tuple[fx.GraphModule, List[SplitItem]]: + # split graph by ops + subgraph_id = 0 + node_to_subgraph_id = {} + split_op_graphs = [] + for node in graph.graph.nodes: + if node.op in ("output", "placeholder"): + continue + if node.op == 'call_function' and str(node.target) in ops: + subgraph_id += 1 + node_to_subgraph_id[node] = subgraph_id + split_op_graphs.append(subgraph_id) + subgraph_id += 1 + else: + node_to_subgraph_id[node] = subgraph_id + + # `keep_original_order` is important! 
+ # otherwise pytorch might reorder the nodes and + # the semantics of the graph will change when we + # have mutations in the graph + split_gm = torch.fx.passes.split_module.split_module( graph, - example_inputs, - additional_inductor_config: Optional[Dict] = None) -> Callable: + None, + lambda node: node_to_subgraph_id[node], + keep_original_order=True) - context = get_compile_context() - context = copy.deepcopy(context) if context is not None else [] - sizes_to_specialize: List[int] = context + outputs = [] - # flags for all the seen shapes, whether we need to specialize - runtime_shapes_to_compile_flags: Dict[Tuple[int, ...], bool] = {} + # sort the names to make sure the order is deterministic + names = [name for (name, module) in split_gm.named_modules()] + names.sort() - # if we need to specialize, the compiled graph for that shape - runtime_shapes_to_compiled_graph: Dict[Tuple[int, ...], Callable] = {} + for name in names: + if "." in name or name == "": + # recursive child module or the root module + continue - # this is the first compilation, we will compile a graph with - # dynamic shape, as the caller will mark first dimension as dynamic - logger.info("Compiling a graph for general shapes") - graph_for_symbolic_shape = wrap_inductor(graph, example_inputs, - additional_inductor_config) + module = getattr(split_gm, name) - # TODO: Dynamo does not pass all dynamic shapes. - # Need to investigate why. It works now because all the dynamic - # shapes have the same value, and either of them can be used. - sym_shape_indices = [ - i for i, x in enumerate(example_inputs) if isinstance(x, torch.SymInt) - ] + graph_id = int(name.replace("submod_", "")) + outputs.append(SplitItem(name, graph_id in split_op_graphs, module)) - first_run = True + return split_gm, outputs - # this is the function we return to Dynamo to run finally - def compiled_graph_wrapper(*args): - runtime_shapes: Tuple[int, - ...] = tuple(args[i] for i in sym_shape_indices) +class VllmBackend: + """The compilation backend for `torch.compile` with VLLM. + It is used for compilation level of `CompilationLevel.PIECEWISE`, + where we customize the compilation. - nonlocal first_run - nonlocal runtime_shapes_to_compile_flags - nonlocal runtime_shapes_to_compiled_graph + The major work of this backend is to split the graph into + piecewise graphs, and pass them to the piecewise backend. + """ - if first_run: - # the first compilation is for profiling, we directly run it - first_run = False - return graph_for_symbolic_shape(*args) + compilation_configs: CompilationConfig + graph_pool: Any + _called: bool = False + # the graph we compiled + graph: fx.GraphModule + # the stiching graph module for all the piecewise graphs + split_gm: fx.GraphModule + piecewise_graphs: List[SplitItem] + returned_callable: Callable - if runtime_shapes not in runtime_shapes_to_compile_flags: - # we haven't seen this shape before - # query if we need to specialize for this shape - # we only specialize for the first dimension. 
- # TODO: investigate if any model needs to specialize - # beyond the first dimension - runtime_shapes_to_compile_flags[runtime_shapes] = runtime_shapes[ - 0] in sizes_to_specialize + def __init__(self, ): + # every instance of VllmBackend has its own graph pool + self.graph_pool = torch.cuda.graph_pool_handle() - if not runtime_shapes_to_compile_flags[runtime_shapes]: - # we don't need to specialize for this shape - return graph_for_symbolic_shape(*args) + # `torch.compile` is JIT compiled, so we don't need to + # do anything here - if runtime_shapes not in runtime_shapes_to_compiled_graph: - # we need to specialize for this shape, and we haven't compiled - # compile the graph for this shape - logger.info("Compiling a graph for shapes %s", runtime_shapes) - runtime_shapes_to_compiled_graph[runtime_shapes] = wrap_inductor( - graph, args, additional_inductor_config) + def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: - return runtime_shapes_to_compiled_graph[runtime_shapes](*args) + compilation_counter.num_graphs_seen += 1 - return compiled_graph_wrapper + # we control the compilation process, each instance can only be + # called once + assert not self._called, "VllmBackend can only be called once" + + self.graph = graph + # config is read now, because only here can + # we get the sizes to capture for cudagraph + # from compilation context + self.compilation_configs = CompilationConfig.select_and_init_config() + + self.split_gm, self.piecewise_graphs = split_graph( + graph, self.compilation_configs.non_cudagraph_ops) + + returned_callable: Callable # type: ignore + + if len(self.piecewise_graphs) == 0: + compilation_counter.num_piecewise_graphs_seen += 1 + compilation_counter.num_piecewise_capturable_graphs_seen += 1 + returned_callable = PiecewiseBackend(graph, + self.compilation_configs, + self.graph_pool, + is_first_graph=True) + else: + from torch._dynamo.utils import lazy_format_graph_code + logger.debug( + "%s", lazy_format_graph_code("stiching module", self.split_gm)) + + is_first_graph = True + + for item in self.piecewise_graphs: + compilation_counter.num_piecewise_graphs_seen += 1 + compilation_counter.num_piecewise_capturable_graphs_seen += not item.is_splitting_graph # noqa + if not item.is_splitting_graph: + # cannot setattr to a module, so we need to set + # the attribute in the __dict__ + self.split_gm.__dict__[ + item.submod_name] = PiecewiseBackend( + item.graph, self.compilation_configs, + self.graph_pool, is_first_graph) + is_first_graph = False + returned_callable = self.split_gm + + self.returned_callable = returned_callable + # trigger the first compilation + # code borrowed from https://github.com/pytorch/pytorch/blob/4e3e08b71171fa34172b2362ff668553fac75f27/torch/_dynamo/backends/distributed.py#L206 # noqa + # to turn the inputs into fake tensors + import torch._guards + from torch._guards import detect_fake_mode + fake_mode = detect_fake_mode(example_inputs) + fake_args = [] + for arg in example_inputs: + if isinstance(arg, torch.Tensor) and not isinstance( + arg, torch._subclasses.FakeTensor): + fake_args.append( + torch._dynamo.utils.to_fake_tensor(arg, fake_mode)) + else: + fake_args.append(arg) + self.returned_callable(*fake_args) + + self._called = True + + return self.returned_callable + + +@dataclasses.dataclass +class ConcreteSizeEntry: + runtime_shape: int + need_to_compile: bool # the size is in compile_sizes + use_cudagraph: bool # the size is in capture_sizes + + compiled: bool = False + runnable: Callable = None # type: ignore + 
num_finished_warmup: int = 0 + cudagraph: Optional[torch.cuda.CUDAGraph] = None + output: Optional[Any] = None + + +class PiecewiseBackend: + + def __init__(self, + graph: fx.GraphModule, + compilation_configs: CompilationConfig, + graph_pool: Any, + is_first_graph: bool = False): + """ + The backend for piecewise compilation. + It mainly handles the compilation and cudagraph capturing. + + We will compile `self.graph` once for the general shape, + and then compile for different shapes specified in + `compilation_configs.compile_sizes`. + + Independently, we will capture cudagraph for different shapes. + + If a shape needs both compilation and cudagraph, we will + compile it first, and then capture cudagraph. + """ + self.graph = graph + self.compilation_configs = compilation_configs + self.graph_pool = graph_pool + self.is_first_graph = is_first_graph + + self.compile_sizes: Set[int] = set( + self.compilation_configs.compile_sizes) + self.capture_sizes: Set[int] = set( + self.compilation_configs.capture_sizes + ) if self.compilation_configs.use_cudagraph else set() + + self.compile_finished = False + self.first_run_finished = False + + self.compiled_graph_for_general_shape: Callable = None # type: ignore + + self.sym_shape_indices: List[int] = [] + + # the entries for different shapes that we need to either + # compile or capture cudagraph + self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {} + for shape in self.compile_sizes.union(self.capture_sizes): + self.concrete_size_entries[shape] = ConcreteSizeEntry( + runtime_shape=shape, + need_to_compile=shape in self.compile_sizes, + use_cudagraph=shape in self.capture_sizes, + ) + + def __call__(self, *args) -> Any: + + if not self.compile_finished: + self.compile_finished = True + + # this is the first compilation, we will compile a graph with + # dynamic shape, as the caller will mark first dimension as dynamic + + self.sym_shape_indices = [ + i for i, x in enumerate(args) if isinstance(x, torch.SymInt) + ] + + self.compiled_graph_for_general_shape = wrap_inductor( + self.graph, + args, + self.compilation_configs.inductor_compile_config, + runtime_shape=None, + do_logging=self.is_first_graph, + use_inductor=self.compilation_configs.use_inductor) + + return self.graph(*args) + + if not self.first_run_finished: + self.first_run_finished = True + return self.compiled_graph_for_general_shape(*args) + + runtime_shape = args[self.sym_shape_indices[0]] + if runtime_shape not in self.concrete_size_entries: + # we don't need to do anything for this shape + return self.compiled_graph_for_general_shape(*args) + + entry = self.concrete_size_entries[runtime_shape] + + if entry.runnable is None: + entry.runnable = self.compiled_graph_for_general_shape + + if entry.need_to_compile and not entry.compiled: + entry.compiled = True + # args are real arguments + entry.runnable = wrap_inductor( + self.graph, + args, + self.compilation_configs.inductor_compile_config, + runtime_shape=runtime_shape, + do_logging=self.is_first_graph, + use_inductor=self.compilation_configs.use_inductor) + + if not entry.use_cudagraph: + return entry.runnable(*args) + + if entry.cudagraph is None: + if entry.num_finished_warmup < self.compilation_configs.cudagraph_num_of_warmups: # noqa + entry.num_finished_warmup += 1 + if self.is_first_graph: + logger.debug( + "Warming up %s/%s for shape %s", + entry.num_finished_warmup, + self.compilation_configs.cudagraph_num_of_warmups, + runtime_shape) + return entry.runnable(*args) + + if self.is_first_graph: + 
logger.info("Capturing a cudagraph for shape %s", + runtime_shape) + + cudagraph = torch.cuda.CUDAGraph() + with torch.cuda.graph(cudagraph, pool=self.graph_pool): + entry.output = weak_ref_tensors(entry.runnable(*args)) + + compilation_counter.num_cudagraph_caputured += 1 + + entry.cudagraph = cudagraph + return entry.output + + entry.cudagraph.replay() + return entry.output def select_default_backend(level: int) -> Union[str, Callable]: if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]: backend_str = "eager" return backend_str - assert level in [ - CompilationLevel.INDUCTOR, CompilationLevel.INDUCTOR_MAX_AUTOTUNE - ], f"Invalid level {level}" + assert level == CompilationLevel.PIECEWISE - from vllm.compilation.backends import vllm_backend - from vllm.plugins import get_inductor_additional_configs - additional_configs = get_inductor_additional_configs() - - if level == CompilationLevel.INDUCTOR_MAX_AUTOTUNE: - if "max_autotune" in additional_configs and not additional_configs[ - "max_autotune"]: - logger.warning( - "max_autotune is disabled, but is overridden by level %s", - CompilationLevel.INDUCTOR_MAX_AUTOTUNE) - additional_configs['max_autotune'] = True - - from functools import partial - backend = partial(vllm_backend, - additional_inductor_config=additional_configs) - - return backend + return VllmBackend() diff --git a/vllm/compilation/config.py b/vllm/compilation/config.py new file mode 100644 index 000000000000..514f2b93ef64 --- /dev/null +++ b/vllm/compilation/config.py @@ -0,0 +1,154 @@ +import copy +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field, PrivateAttr + +import vllm.envs as envs +from vllm.logger import init_logger + +from .compile_context import get_compile_context + +logger = init_logger(__name__) + + +class CompilationConfig(BaseModel): + """ + Configuration for compilation. + It has two parts: + - CudaGraph capture: + - use_cudagraph: whether to use cudagraph inside compilation. + - False: cudagraph inside compilation is not used. + - True: cudagraph inside compilation is used. It requires + that all input buffers have fixed addresses. + Note that this is orthogonal to the cudagraph capture out + side of compilation. + TODO: move outside cudagraph logic into compilation. + torch.compile will handle cudagraph capture logic in the future. + - cudagraph_capture_sizes: sizes to capture cudagraph. + - None: capture sizes are inferred from compilation context. + - List[int]: capture sizes are specified. + - cudagraph_num_of_warmups: number of warmup runs for cudagraph. + It means the first several runs will be treated as warmup runs. + Only after that, the execution will be recorded, and the recorded + cudagraph will be used for subsequent runs. + - Inductor compilation: + - use_inductor: whether to use inductor compilation. + - False: inductor compilation is not used. graph runs in eager. + - True: inductor compilation is used. one graph for symbolic shape + is compiled. In addition, compile for different sizes specified + in inductor_compile_sizes, using configurations + in inductor_compile_config. + - inductor_compile_sizes: sizes to compile for inductor. + - inductor_specialize_for_cudagraph_no_more_than: an optional integer + to specialize inductor for cudagraph sizes no more than the + specified size. It is useful when we want to specialize inductor + with a subset of cudagraph sizes. + - inductor_compile_config: additional configurations for inductor. + - None: use default configurations. 
+ - inductor_passes: additional passes for inductor. It is a dictionary + from pass name to pass function qualified name. We use function + name because the config uses json format. If we pass the config + from Python, functions can also be passed directly via Python object + constructor, e.g. `CompilationConfig(inductor_passes={"a": func})` + + Why we have different sizes for cudagraph and inductor: + - cudagraph: a cudagraph captured for a specific size can only be used + for the same size. We need to capture all the sizes we want to use. + - inductor: a graph compiled by inductor for a general shape can be used + for different sizes. Inductor can also compile for specific sizes, + where it can have more information to optimize the graph with fully + static shapes. However, we find the general shape compilation is + sufficient for most cases. It might be beneficial to compile for + certain small batchsizes, where inductor is good at optimizing. + """ + use_inductor: bool = True + inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None + inductor_compile_sizes: Optional[List[int]] = Field(default_factory=dict) + inductor_compile_config: Dict = Field(default_factory=dict) + inductor_passes: Dict[str, str] = Field(default_factory=dict) + + use_cudagraph: bool = False + non_cudagraph_ops: List[str] = Field(default_factory=list) + cudagraph_num_of_warmups: int = 0 + cudagraph_capture_sizes: Optional[List[int]] = None + + # not configurable, computed after init + compile_sizes: List[int] = PrivateAttr + capture_sizes: List[int] = PrivateAttr + + def model_post_init(self, __context: Any) -> None: + for k, v in self.inductor_passes.items(): + if not isinstance(v, str): + assert callable(v), ( + f"pass {k} should be a function or a qualified name") + self.inductor_passes[k] = v + continue + + # resolve function from qualified name + names = v.split(".") + module = ".".join(names[:-1]) + func_name = names[-1] + func = __import__(module).__dict__[func_name] + self.inductor_compile_config[k] = func + + from vllm.compilation.backends import fix_functionalization + from vllm.utils import combine_fx_passes + if "post_grad_custom_post_pass" in self.inductor_compile_config: + self.inductor_compile_config[ + "post_grad_custom_post_pass"] = combine_fx_passes( + fix_functionalization, + self.inductor_compile_config["post_grad_custom_post_pass"], + ) + else: + self.inductor_compile_config[ + "post_grad_custom_post_pass"] = fix_functionalization + + def init_during_runtime(self): + """To complete the initialization of config, + we need to know the compile context, which is only available + during the first run of the model. 
+ """ + context = get_compile_context() + context = copy.deepcopy(context) if context is not None else [] + sizes_to_specialize: List[int] = context + if self.cudagraph_capture_sizes is None: + self.capture_sizes = sizes_to_specialize + else: + self.capture_sizes = self.cudagraph_capture_sizes + logger.info(("cudagraph sizes specified by model runner" + " %s is overridden by config %s"), + sizes_to_specialize, self.cudagraph_capture_sizes) + if self.inductor_specialize_for_cudagraph_no_more_than is not None: + assert self.inductor_compile_sizes is None, ( + "inductor_compile_sizes should be None when " + "inductor_specialize_for_cudagraph_no_more_than is not None") + self.compile_sizes = [ + x for x in self.capture_sizes + if x <= self.inductor_specialize_for_cudagraph_no_more_than + ] + else: + assert self.inductor_compile_sizes is not None, ( + "inductor_compile_sizes should not be None when " + "inductor_specialize_for_cudagraph_no_more_than is None") + self.compile_sizes = self.inductor_compile_sizes + + @staticmethod + def select_and_init_config() -> "CompilationConfig": + """The order of selecting config is: + 1. Use the config specified in environment variable. + 2. Use the config specified in plugins. + 3. Use the default config. + """ + config_path = envs.VLLM_TORCH_COMPILE_CONFIG + if config_path is not None: + with open(config_path) as json_file: + config = CompilationConfig.model_validate_json( + json_file.read()) + else: + from vllm.plugins import get_compilation_config + predefined_config = get_compilation_config() + config = predefined_config if predefined_config is not None else ( + CompilationConfig()) + + config.init_during_runtime() + return config diff --git a/vllm/compilation/counter.py b/vllm/compilation/counter.py new file mode 100644 index 000000000000..100a49aba74a --- /dev/null +++ b/vllm/compilation/counter.py @@ -0,0 +1,30 @@ +import copy +import dataclasses +from contextlib import contextmanager + + +@dataclasses.dataclass +class CompilationCounter: + num_graphs_seen: int = 0 + # including the splitting ops + num_piecewise_graphs_seen: int = 0 + # not including the splitting ops + num_piecewise_capturable_graphs_seen: int = 0 + num_inductor_compilations: int = 0 + num_cudagraph_caputured: int = 0 + + def clone(self) -> "CompilationCounter": + return copy.deepcopy(self) + + @contextmanager + def expect(self, **kwargs): + old = self.clone() + yield + for k, v in kwargs.items(): + assert getattr(self, k) - getattr(old, k) == v, ( + f"{k} not as expected, before it is {getattr(old, k)}" + f", after it is {getattr(self, k)}, " + f"expected diff is {v}") + + +compilation_counter = CompilationCounter() diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 0449f9354d0a..3053e57e0b63 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -121,7 +121,10 @@ def _support_torch_compile(cls: type, # take care of method resolution order # make sure super().__init__ is called on the base class # other than TorchCompileWrapperWithCustomDispatcher - cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, ) + if TorchCompileWrapperWithCustomDispatcher not in cls.__bases__: + # support decorating multiple times + cls.__bases__ = cls.__bases__ + ( + TorchCompileWrapperWithCustomDispatcher, ) old_init = cls.__init__ # type: ignore @@ -160,6 +163,11 @@ def _support_torch_compile(cls: type, # compiled function and let torch.compile handle the dispatching, # with the overhead of guard evaluation and 
recompilation. if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher: + # it seems Dynamo reuse the compilation across instances, + # while we need to make sure the compiled code is not reused. + # we need to control all the compilation of the model. + torch._dynamo.eval_frame.remove_from_cache( + self.original_code_object) return self.compiled_callable(*args, **kwargs) # usually, capturing the model once is enough, and then we can diff --git a/vllm/compilation/levels.py b/vllm/compilation/levels.py index 162bf5ae6499..19a3a2b52687 100644 --- a/vllm/compilation/levels.py +++ b/vllm/compilation/levels.py @@ -5,5 +5,4 @@ class CompilationLevel: NO_COMPILATION = 0 DYNAMO_AS_IS = 1 DYNAMO_ONCE = 2 - INDUCTOR = 3 - INDUCTOR_MAX_AUTOTUNE = 4 + PIECEWISE = 3 diff --git a/vllm/envs.py b/vllm/envs.py index ae6825f28007..b4a263d1e086 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -209,6 +209,11 @@ environment_variables: Dict[str, Callable[[], Any]] = { os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"), "VLLM_TORCH_COMPILE_LEVEL": lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")), + + # Path to the config file for torch compile + "VLLM_TORCH_COMPILE_CONFIG": + lambda: os.environ.get("VLLM_TORCH_COMPILE_CONFIG", None), + # Fine-grained control over which custom ops to enable/disable. # Use 'all' to enable all, 'none' to disable all. # Also specify a list of custom op names to enable (prefixed with a '+'), diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 83910339f3c9..764f4e9c99df 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -100,7 +100,7 @@ class CustomOp(nn.Module): return (CustomOp.default_on() or enabled) and not disabled - # On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR + # On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE # Specifying 'all' or 'none' in VLLM_CUSTOM_OPS takes precedence. @staticmethod @lru_cache() @@ -108,7 +108,7 @@ class CustomOp(nn.Module): count_none = envs.VLLM_CUSTOM_OPS.count("none") count_all = envs.VLLM_CUSTOM_OPS.count("all") assert count_none + count_all <= 1, "Can only specify 'none' or 'all'" - return envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR and \ + return envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE and \ not count_none > 0 or count_all > 0 # Dictionary of all custom ops (classes, indexed by registered name). diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 8ba973b28263..8d0ce47df404 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -11,7 +11,7 @@ from .interface import Platform, PlatformEnum if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ: os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.DYNAMO_ONCE) -assert envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR,\ +assert envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE,\ "TPU does not support Inductor." 
set_torch_compile_backend("openxla") diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 211fedbc6e2e..4338cbc37f6c 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -1,7 +1,8 @@ import logging -from typing import Callable, Dict, Optional, Union +from typing import Callable, Optional, Union import vllm.envs as envs +from vllm.compilation.config import CompilationConfig logger = logging.getLogger(__name__) @@ -44,13 +45,13 @@ def get_torch_compile_backend() -> Optional[Union[Callable, str]]: return _torch_compile_backend -_inductor_additional_configs: Dict = {} +_compilation_config: Optional[CompilationConfig] = None -def set_inductor_additional_configs(configs: Dict): - global _inductor_additional_configs - _inductor_additional_configs = configs +def set_compilation_config(config: Optional[CompilationConfig]): + global _compilation_config + _compilation_config = config -def get_inductor_additional_configs() -> Dict: - return _inductor_additional_configs +def get_compilation_config() -> Optional[CompilationConfig]: + return _compilation_config diff --git a/vllm/utils.py b/vllm/utils.py index fea318ebcdf4..90c4b8475781 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1479,6 +1479,15 @@ class LazyDict(Mapping, Generic[T]): return len(self._factory) +def combine_fx_passes(passes: List[Callable]) -> Callable: + + def combined_fx(graph) -> None: + for fx in passes: + fx(graph) + + return combined_fx + + def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor: """ Create a weak reference to a tensor. @@ -1486,3 +1495,19 @@ def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor: but will not keep the original tensor alive. """ return torch.ops._C.weak_ref_tensor(tensor) + + +def weak_ref_tensors( + tensors: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]] +) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]: + """ + Convenience function to create weak references to tensors, + for single tensor, list of tensors or tuple of tensors. + """ + if isinstance(tensors, torch.Tensor): + return weak_ref_tensor(tensors) + if isinstance(tensors, list): + return [weak_ref_tensor(t) for t in tensors] + if isinstance(tensors, tuple): + return tuple(weak_ref_tensor(t) for t in tensors) + raise ValueError("Invalid type for tensors")