From ea2236bf95d25c517ae6afbda3a16fe92ee73e7a Mon Sep 17 00:00:00 2001
From: Richard Zou
Date: Fri, 9 May 2025 15:59:04 -0400
Subject: [PATCH] Add option to use torch._inductor.standalone_compile (#17057)

Signed-off-by: rzou
---
 vllm/compilation/backends.py           |  33 ++++--
 vllm/compilation/compiler_interface.py | 141 +++++++++++++++++++++----
 vllm/envs.py                           |   5 +
 3 files changed, 150 insertions(+), 29 deletions(-)

diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index a1ff5fb1196b..c2e8c726c943 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -17,7 +17,8 @@ from vllm.config import CompilationConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.utils import weak_ref_tensors
 
-from .compiler_interface import EagerAdaptor, InductorAdaptor
+from .compiler_interface import (CompilerInterface, EagerAdaptor,
+                                 InductorAdaptor, InductorStandaloneAdaptor)
 from .counter import compilation_counter
 from .inductor_pass import InductorPass
 from .monitor import end_monitoring_torch_compile
@@ -26,6 +27,19 @@ from .pass_manager import PostGradPassManager
 logger = init_logger(__name__)
 
 
+def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
+    if compilation_config.use_inductor:
+        if envs.VLLM_TEST_STANDALONE_COMPILE:
+            logger.info("Using InductorStandaloneAdaptor")
+            return InductorStandaloneAdaptor()
+        else:
+            logger.info("Using InductorAdaptor")
+            return InductorAdaptor()
+    else:
+        logger.info("Using EagerAdaptor")
+        return EagerAdaptor()
+
+
 class CompilerManager:
     """
     A manager to manage the compilation process, including
@@ -41,11 +55,11 @@ class CompilerManager:
     support int as key.
     """
 
-    def __init__(self, use_inductor: bool):
+    def __init__(self, compilation_config: CompilationConfig):
         self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict()
-        cls = InductorAdaptor if use_inductor else EagerAdaptor
-        self.compiler = cls()
         self.is_cache_updated = False
+        self.compilation_config = compilation_config
+        self.compiler = make_compiler(compilation_config)
 
     def compute_hash(self, vllm_config: VllmConfig) -> str:
         return self.compiler.compute_hash(vllm_config)
@@ -123,8 +137,15 @@ class CompilerManager:
 
         # no compiler cached the graph, or the cache is disabled,
         # we need to compile it
+        if isinstance(self.compiler, InductorAdaptor):
+            # Let compile_fx generate a key for us
+            maybe_key = None
+        else:
+            maybe_key = \
+                f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
         compiled_graph, handle = self.compiler.compile(
-            graph, example_inputs, additional_inductor_config, runtime_shape)
+            graph, example_inputs, additional_inductor_config, runtime_shape,
+            maybe_key)
 
         assert compiled_graph is not None, "Failed to compile the graph"
 
@@ -336,7 +357,7 @@ class VllmBackend:
         self.compilation_config = vllm_config.compilation_config
 
         self.compiler_manager: CompilerManager = CompilerManager(
-            self.compilation_config.use_inductor)
+            self.compilation_config)
 
         # `torch.compile` is JIT compiled, so we don't need to
         # do anything here
diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py
index b7e7a79bef0b..423581784f7a 100644
--- a/vllm/compilation/compiler_interface.py
+++ b/vllm/compilation/compiler_interface.py
@@ -50,7 +50,8 @@ class CompilerInterface:
         graph: fx.GraphModule,
         example_inputs: List[Any],
         compiler_config: Dict[str, Any],
-        runtime_shape: Optional[int] = None
+        runtime_shape: Optional[int] = None,
+        key: Optional[str] = None,
     ) -> Tuple[Optional[Callable], Optional[Any]]:
         """
         Compile the graph with the given example inputs and compiler config,
@@ -71,6 +72,10 @@ class CompilerInterface:
         If the compiler doesn't support caching, it should return None for the
         handle. If the compiler fails to compile the graph, it should return
         None for the compiled function as well.
+
+        `key` is required for InductorStandaloneAdaptor; it specifies where
+        to save the compiled artifact. The compiled artifact gets saved to
+        `cache_dir/key`.
         """
         return None, None
 
@@ -127,23 +132,108 @@ class AlwaysHitShapeEnv:
         return ""
 
 
+def get_inductor_factors() -> List[Any]:
+    factors: List[Any] = []
+    # summarize system state
+    from torch._inductor.codecache import CacheBase
+    system_factors = CacheBase.get_system()
+    factors.append(system_factors)
+
+    # summarize pytorch state
+    from torch._inductor.codecache import torch_key
+    torch_factors = torch_key()
+    factors.append(torch_factors)
+    return factors
+
+
+class InductorStandaloneAdaptor(CompilerInterface):
+    """
+    The adaptor for the Inductor compiler.
+    Requires PyTorch 2.8+.
+    This is not enabled by default yet, but we plan to turn it on by default
+    for PyTorch 2.8.
+
+    Use VLLM_TEST_STANDALONE_COMPILE to toggle this on or off.
+    """
+    name = "inductor_standalone"
+
+    def compute_hash(self, vllm_config: VllmConfig) -> str:
+        factors = get_inductor_factors()
+        hash_str = hashlib.md5(str(factors).encode(),
+                               usedforsecurity=False).hexdigest()[:10]
+        return hash_str
+
+    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
+        self.cache_dir = cache_dir
+
+    def compile(
+        self,
+        graph: fx.GraphModule,
+        example_inputs: List[Any],
+        compiler_config: Dict[str, Any],
+        runtime_shape: Optional[int] = None,
+        key: Optional[str] = None,
+    ) -> Tuple[Optional[Callable], Optional[Any]]:
+        current_config = {}
+        if compiler_config is not None:
+            current_config.update(compiler_config)
+        set_inductor_config(current_config, runtime_shape)
+
+        if isinstance(runtime_shape, int):
+            dynamic_shapes = "from_example_inputs"
+        else:
+            dynamic_shapes = "from_tracing_context"
+
+        from torch._inductor import standalone_compile
+        with pass_context(runtime_shape):
+            compiled_graph = standalone_compile(
+                graph,
+                example_inputs,
+                dynamic_shapes=dynamic_shapes,
+                options={"config_patches": current_config})
+
+        # Save the compiled artifact to disk in the specified path
+        assert key is not None
+        path = os.path.join(self.cache_dir, key)
+        compiled_graph.save(path=path, format="unpacked")
+        return compiled_graph, (key, path)
+
+    def load(self,
+             handle: Any,
+             graph: fx.GraphModule,
+             example_inputs: List[Any],
+             graph_index: int,
+             runtime_shape: Optional[int] = None) -> Callable:
+        assert isinstance(handle, tuple)
+        assert isinstance(handle[0], str)
+        assert isinstance(handle[1], str)
+        path = handle[1]
+        inductor_compiled_graph = torch._inductor.CompiledArtifact.load(
+            path=path, format="unpacked")
+        from torch._inductor.compile_fx import graph_returns_tuple
+        returns_tuple = graph_returns_tuple(graph)
+
+        def compiled_graph_wrapper(*args):
+            graph_output = inductor_compiled_graph(*args)
+            # unpack the tuple if needed
+            # TODO(rzou): the implication is that we're not
+            # reading the python bytecode correctly in vLLM?
+            if returns_tuple:
+                return graph_output
+            else:
+                return graph_output[0]
+
+        return compiled_graph_wrapper
+
+
 class InductorAdaptor(CompilerInterface):
     """
-    The adaptor for the Inductor compiler, version 2.5 and 2.6.
+    The adaptor for the Inductor compiler, versions 2.5, 2.6, and 2.7.
     """
     name = "inductor"
 
     def compute_hash(self, vllm_config: VllmConfig) -> str:
-        factors: List[Any] = []
-        # summarize system state
-        from torch._inductor.codecache import CacheBase
-        system_factors = CacheBase.get_system()
-        factors.append(system_factors)
-
-        # summarize pytorch state
-        from torch._inductor.codecache import torch_key
-        torch_factors = torch_key()
-        factors.append(torch_factors)
+        factors = get_inductor_factors()
         hash_str = hashlib.md5(str(factors).encode(),
                                usedforsecurity=False).hexdigest()[:10]
         return hash_str
@@ -168,23 +258,19 @@ class InductorAdaptor(CompilerInterface):
         graph: fx.GraphModule,
         example_inputs: List[Any],
         compiler_config: Dict[str, Any],
-        runtime_shape: Optional[int] = None
+        runtime_shape: Optional[int] = None,
+        key: Optional[str] = None,
     ) -> Tuple[Optional[Callable], Optional[Any]]:
-        current_config = {}
         from torch._inductor.compile_fx import compile_fx
+        current_config = {}
+        if compiler_config is not None:
+            current_config.update(compiler_config)
 
         # disable remote cache
         current_config["fx_graph_cache"] = True
         current_config["fx_graph_remote_cache"] = False
 
-        if compiler_config is not None:
-            current_config.update(compiler_config)
-
-        if isinstance(runtime_shape, int):
-            # for a specific batchsize, tuning triton kernel parameters
-            # can be beneficial
-            current_config["max_autotune"] = True
-            current_config["coordinate_descent_tuning"] = True
+        set_inductor_config(current_config, runtime_shape)
 
         # inductor can inplace modify the graph, so we need to copy it
         # see https://github.com/pytorch/pytorch/issues/138980
@@ -422,6 +508,14 @@ class InductorAdaptor(CompilerInterface):
         return contextlib.nullcontext()
 
 
+def set_inductor_config(config, runtime_shape):
+    if isinstance(runtime_shape, int):
+        # for a specific batchsize, tuning triton kernel parameters
+        # can be beneficial
+        config["max_autotune"] = True
+        config["coordinate_descent_tuning"] = True
+
+
 class EagerAdaptor(CompilerInterface):
     name = "eager"
 
@@ -430,7 +524,8 @@ class EagerAdaptor(CompilerInterface):
         graph: fx.GraphModule,
         example_inputs: List[Any],
         compiler_config: Dict[str, Any],
-        runtime_shape: Optional[int] = None
+        runtime_shape: Optional[int] = None,
+        key: Optional[str] = None,
     ) -> Tuple[Optional[Callable], Optional[Any]]:
         # we don't need to compile the graph, just return the graph itself.
         # It does not support caching, return None for the handle.
diff --git a/vllm/envs.py b/vllm/envs.py
index 134cdf9905fa..d7f332cb0a73 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -263,6 +263,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: bool(
         os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
 
+    # Internal flag to enable/disable Inductor standalone compile
+    "VLLM_TEST_STANDALONE_COMPILE":
+    lambda: os.environ.get("VLLM_TEST_STANDALONE_COMPILE", "0") != "0",
+
     # local rank of the process in the distributed setting, used to determine
     # the GPU device id
     "LOCAL_RANK":
@@ -805,6 +809,7 @@ def compute_hash() -> str:
         "VLLM_USE_TRITON_AWQ",
         "VLLM_DP_RANK",
         "VLLM_DP_SIZE",
+        "VLLM_TEST_STANDALONE_COMPILE",
     ]
     for key in environment_variables_to_hash:
         if key in environment_variables: