Add option to use torch._inductor.standalone_compile (#17057)

Signed-off-by: rzou <zou3519@gmail.com>
Richard Zou 2025-05-09 15:59:04 -04:00 committed by GitHub
parent 7d4aedae7c
commit ea2236bf95
3 changed files with 150 additions and 29 deletions
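The change is opt-in: setting VLLM_TEST_STANDALONE_COMPILE=1 routes Inductor compilation through torch._inductor.standalone_compile instead of compile_fx. A minimal sketch of enabling it from Python (the model name and compilation level are placeholders, and this assumes the LLM entrypoint's compilation_config knob; not part of this commit):

    import os

    # The flag is read lazily through vllm.envs, so setting it before the
    # engine is constructed is enough to switch compiler adaptors.
    os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"

    from vllm import LLM

    # compilation_config=3 selects CompilationLevel.PIECEWISE, i.e. the
    # torch.compile-based path that actually invokes the Inductor adaptor.
    llm = LLM(model="facebook/opt-125m", compilation_config=3)
    print(llm.generate("Hello, my name is")[0].outputs[0].text)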

vllm/compilation/backends.py

@@ -17,7 +17,8 @@ from vllm.config import CompilationConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.utils import weak_ref_tensors
-from .compiler_interface import EagerAdaptor, InductorAdaptor
+from .compiler_interface import (CompilerInterface, EagerAdaptor,
+                                 InductorAdaptor, InductorStandaloneAdaptor)
 from .counter import compilation_counter
 from .inductor_pass import InductorPass
 from .monitor import end_monitoring_torch_compile
@@ -26,6 +27,19 @@ from .pass_manager import PostGradPassManager
 logger = init_logger(__name__)
 
 
+def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
+    if compilation_config.use_inductor:
+        if envs.VLLM_TEST_STANDALONE_COMPILE:
+            logger.info("Using InductorStandaloneAdaptor")
+            return InductorStandaloneAdaptor()
+        else:
+            logger.info("Using InductorAdaptor")
+            return InductorAdaptor()
+    else:
+        logger.info("Using EagerAdaptor")
+        return EagerAdaptor()
+
+
 class CompilerManager:
     """
     A manager to manage the compilation process, including
@@ -41,11 +55,11 @@ class CompilerManager:
     support int as key.
     """
 
-    def __init__(self, use_inductor: bool):
+    def __init__(self, compilation_config: CompilationConfig):
         self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict()
-        cls = InductorAdaptor if use_inductor else EagerAdaptor
-        self.compiler = cls()
         self.is_cache_updated = False
+        self.compilation_config = compilation_config
+        self.compiler = make_compiler(compilation_config)
 
     def compute_hash(self, vllm_config: VllmConfig) -> str:
         return self.compiler.compute_hash(vllm_config)
@@ -123,8 +137,15 @@ class CompilerManager:
 
         # no compiler cached the graph, or the cache is disabled,
         # we need to compile it
+        if isinstance(self.compiler, InductorAdaptor):
+            # Let compile_fx generate a key for us
+            maybe_key = None
+        else:
+            maybe_key = \
+                f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
         compiled_graph, handle = self.compiler.compile(
-            graph, example_inputs, additional_inductor_config, runtime_shape)
+            graph, example_inputs, additional_inductor_config, runtime_shape,
+            maybe_key)
 
         assert compiled_graph is not None, "Failed to compile the graph"
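For compilers other than the compile_fx-based InductorAdaptor, the manager now derives the artifact name from the runtime shape and the piecewise subgraph index, so each (shape, subgraph) pair gets a deterministic location. A worked example with illustrative values:

    runtime_shape = None   # the general, dynamic-shape graph
    graph_index = 2        # third piecewise subgraph
    key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
    # key == "artifact_shape_None_subgraph_2"
    # InductorStandaloneAdaptor.compile() then saves the compiled artifact
    # under os.path.join(cache_dir, key).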
@@ -336,7 +357,7 @@ class VllmBackend:
         self.compilation_config = vllm_config.compilation_config
 
         self.compiler_manager: CompilerManager = CompilerManager(
-            self.compilation_config.use_inductor)
+            self.compilation_config)
 
         # `torch.compile` is JIT compiled, so we don't need to
         # do anything here

vllm/compilation/compiler_interface.py

@@ -50,7 +50,8 @@ class CompilerInterface:
         graph: fx.GraphModule,
         example_inputs: List[Any],
         compiler_config: Dict[str, Any],
-        runtime_shape: Optional[int] = None
+        runtime_shape: Optional[int] = None,
+        key: Optional[str] = None,
     ) -> Tuple[Optional[Callable], Optional[Any]]:
         """
         Compile the graph with the given example inputs and compiler config,
@@ -71,6 +72,10 @@ class CompilerInterface:
         If the compiler doesn't support caching, it should return None for the
         handle. If the compiler fails to compile the graph, it should return
         None for the compiled function as well.
+
+        `key` is required for InductorStandaloneAdaptor; it specifies where to
+        save the compiled artifact. The compiled artifact gets saved to
+        `cache_dir/key`.
         """
         return None, None
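To make the contract concrete, here is a hypothetical do-nothing adaptor (not part of this commit) that accepts the new key argument, performs no real compilation, and opts out of caching by returning None for the handle, in which case the manager never calls load() for it:

    from typing import Any, Callable, Dict, List, Optional, Tuple

    import torch.fx as fx


    class NoOpAdaptor(CompilerInterface):
        # Illustrative only; behaves like EagerAdaptor.
        name = "noop"

        def compile(
            self,
            graph: fx.GraphModule,
            example_inputs: List[Any],
            compiler_config: Dict[str, Any],
            runtime_shape: Optional[int] = None,
            key: Optional[str] = None,
        ) -> Tuple[Optional[Callable], Optional[Any]]:
            # No caching: the second element (the handle) is None.
            return graph, None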
@@ -127,13 +132,7 @@ class AlwaysHitShapeEnv:
         return ""
 
 
-class InductorAdaptor(CompilerInterface):
-    """
-    The adaptor for the Inductor compiler, version 2.5 and 2.6.
-    """
-    name = "inductor"
-
-    def compute_hash(self, vllm_config: VllmConfig) -> str:
+def get_inductor_factors() -> List[Any]:
     factors: List[Any] = []
     # summarize system state
     from torch._inductor.codecache import CacheBase
@@ -144,6 +143,97 @@ class InductorAdaptor(CompilerInterface):
     from torch._inductor.codecache import torch_key
     torch_factors = torch_key()
     factors.append(torch_factors)
+    return factors
+
+
+class InductorStandaloneAdaptor(CompilerInterface):
+    """
+    The adaptor for the Inductor compiler.
+    Requires PyTorch 2.8+.
+    This is not on by default yet, but we plan to turn it on by default for
+    PyTorch 2.8.
+
+    Use VLLM_TEST_STANDALONE_COMPILE to toggle this on or off.
+    """
+    name = "inductor_standalone"
+
+    def compute_hash(self, vllm_config: VllmConfig) -> str:
+        factors = get_inductor_factors()
+        hash_str = hashlib.md5(str(factors).encode(),
+                               usedforsecurity=False).hexdigest()[:10]
+        return hash_str
+
+    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
+        self.cache_dir = cache_dir
+
+    def compile(
+        self,
+        graph: fx.GraphModule,
+        example_inputs: List[Any],
+        compiler_config: Dict[str, Any],
+        runtime_shape: Optional[int] = None,
+        key: Optional[str] = None,
+    ) -> Tuple[Optional[Callable], Optional[Any]]:
+        current_config = {}
+        if compiler_config is not None:
+            current_config.update(compiler_config)
+        set_inductor_config(current_config, runtime_shape)
+
+        if isinstance(runtime_shape, int):
+            dynamic_shapes = "from_example_inputs"
+        else:
+            dynamic_shapes = "from_tracing_context"
+
+        from torch._inductor import standalone_compile
+        with pass_context(runtime_shape):
+            compiled_graph = standalone_compile(
+                graph,
+                example_inputs,
+                dynamic_shapes=dynamic_shapes,
+                options={"config_patches": current_config})
+
+        # Save the compiled artifact to disk in the specified path
+        assert key is not None
+        path = os.path.join(self.cache_dir, key)
+        compiled_graph.save(path=path, format="unpacked")
+        return compiled_graph, (key, path)
+
+    def load(self,
+             handle: Any,
+             graph: fx.GraphModule,
+             example_inputs: List[Any],
+             graph_index: int,
+             runtime_shape: Optional[int] = None) -> Callable:
+        assert isinstance(handle, tuple)
+        assert isinstance(handle[0], str)
+        assert isinstance(handle[1], str)
+        path = handle[1]
+        inductor_compiled_graph = torch._inductor.CompiledArtifact.load(
+            path=path, format="unpacked")
+
+        from torch._inductor.compile_fx import graph_returns_tuple
+        returns_tuple = graph_returns_tuple(graph)
+
+        def compiled_graph_wrapper(*args):
+            graph_output = inductor_compiled_graph(*args)
+            # unpack the tuple if needed
+            # TODO(rzou): the implication is that we're not
+            # reading the python bytecode correctly in vLLM?
+            if returns_tuple:
+                return graph_output
+            else:
+                return graph_output[0]
+
+        return compiled_graph_wrapper
+
+
+class InductorAdaptor(CompilerInterface):
+    """
+    The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7.
+    """
+    name = "inductor"
+
+    def compute_hash(self, vllm_config: VllmConfig) -> str:
+        factors = get_inductor_factors()
         hash_str = hashlib.md5(str(factors).encode(),
                                usedforsecurity=False).hexdigest()[:10]
         return hash_str
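Stripped of the vLLM plumbing, the save/load cycle the new adaptor performs maps onto the PyTorch 2.8 standalone_compile API roughly as follows. This is a sketch only; the toy function, backend wiring, and artifact path are placeholders, not vLLM code:

    import torch
    from torch._inductor import CompiledArtifact, standalone_compile

    artifact_path = "/tmp/standalone_artifact"  # placeholder for cache_dir/key

    def inductor_standalone_backend(gm: torch.fx.GraphModule, example_inputs):
        # Compile the dynamo-captured FX graph and persist it to disk,
        # mirroring what InductorStandaloneAdaptor.compile() does.
        artifact = standalone_compile(gm, example_inputs,
                                      dynamic_shapes="from_example_inputs",
                                      options={"config_patches": {}})
        artifact.save(path=artifact_path, format="unpacked")
        return artifact  # CompiledArtifact is callable

    @torch.compile(backend=inductor_standalone_backend)
    def f(x):
        return torch.relu(x) + 1

    out = f(torch.randn(8, 8))

    # On a later run the artifact can be reloaded without recompiling,
    # which is what InductorStandaloneAdaptor.load() does.
    reloaded = CompiledArtifact.load(path=artifact_path, format="unpacked")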
@@ -168,23 +258,19 @@ class InductorAdaptor(CompilerInterface):
         graph: fx.GraphModule,
         example_inputs: List[Any],
         compiler_config: Dict[str, Any],
-        runtime_shape: Optional[int] = None
+        runtime_shape: Optional[int] = None,
+        key: Optional[str] = None,
     ) -> Tuple[Optional[Callable], Optional[Any]]:
-        current_config = {}
         from torch._inductor.compile_fx import compile_fx
 
+        current_config = {}
+        if compiler_config is not None:
+            current_config.update(compiler_config)
+
         # disable remote cache
         current_config["fx_graph_cache"] = True
         current_config["fx_graph_remote_cache"] = False
-        if compiler_config is not None:
-            current_config.update(compiler_config)
-
-        if isinstance(runtime_shape, int):
-            # for a specific batchsize, tuning triton kernel parameters
-            # can be beneficial
-            current_config["max_autotune"] = True
-            current_config["coordinate_descent_tuning"] = True
+        set_inductor_config(current_config, runtime_shape)
 
         # inductor can inplace modify the graph, so we need to copy it
         # see https://github.com/pytorch/pytorch/issues/138980
@@ -422,6 +508,14 @@ class InductorAdaptor(CompilerInterface):
             return contextlib.nullcontext()
 
 
+def set_inductor_config(config, runtime_shape):
+    if isinstance(runtime_shape, int):
+        # for a specific batchsize, tuning triton kernel parameters
+        # can be beneficial
+        config["max_autotune"] = True
+        config["coordinate_descent_tuning"] = True
+
+
 class EagerAdaptor(CompilerInterface):
     name = "eager"
@@ -430,7 +524,8 @@ class EagerAdaptor(CompilerInterface):
         graph: fx.GraphModule,
         example_inputs: List[Any],
         compiler_config: Dict[str, Any],
-        runtime_shape: Optional[int] = None
+        runtime_shape: Optional[int] = None,
+        key: Optional[str] = None,
     ) -> Tuple[Optional[Callable], Optional[Any]]:
         # we don't need to compile the graph, just return the graph itself.
         # It does not support caching, return None for the handle.
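The per-shape tuning knobs that used to live inline in InductorAdaptor.compile are now shared by both Inductor adaptors through set_inductor_config. What it does, illustrated with throwaway dicts:

    static_cfg = {}
    set_inductor_config(static_cfg, runtime_shape=8)      # compiling for one fixed batch size
    # static_cfg == {"max_autotune": True, "coordinate_descent_tuning": True}

    dynamic_cfg = {}
    set_inductor_config(dynamic_cfg, runtime_shape=None)  # the general dynamic-shape graph
    # dynamic_cfg == {}  (no shape-specific autotuning)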

vllm/envs.py

@@ -263,6 +263,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: bool(
         os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
 
+    # Internal flag to enable/disable Inductor standalone compile
+    "VLLM_TEST_STANDALONE_COMPILE":
+    lambda: os.environ.get("VLLM_TEST_STANDALONE_COMPILE", "0") != "0",
+
     # local rank of the process in the distributed setting, used to determine
     # the GPU device id
     "LOCAL_RANK":
@@ -805,6 +809,7 @@ def compute_hash() -> str:
         "VLLM_USE_TRITON_AWQ",
         "VLLM_DP_RANK",
         "VLLM_DP_SIZE",
+        "VLLM_TEST_STANDALONE_COMPILE",
     ]
     for key in environment_variables_to_hash:
         if key in environment_variables: