Allow oot custom compiler extension via CompilerInterface (#28623)

Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: Mengqing Cao <cmq0113@163.com>
Signed-off-by: Icey <1790571317@qq.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
Icey 2025-11-25 15:25:15 +08:00 committed by GitHub
parent fe3a4f5b34
commit 888152bf87
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 42 additions and 24 deletions

View File

@ -63,13 +63,14 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
else:
logger.debug("Using InductorAdaptor")
return InductorAdaptor()
else:
assert compilation_config.backend == "eager", (
"Custom backends not supported with CompilationMode.VLLM_COMPILE"
)
elif compilation_config.backend == "eager":
logger.debug("Using EagerAdaptor")
return EagerAdaptor()
else:
logger.debug("Using custom backend: %s", compilation_config.backend)
compiler = resolve_obj_by_qualname(current_platform.get_compile_backend())()
assert isinstance(compiler, CompilerInterface)
return compiler
class CompilerManager:
@ -545,7 +546,10 @@ class VllmBackend:
self.prefix = prefix or model_tag
# Passes to run on the graph post-grad.
self.post_grad_pass_manager = PostGradPassManager()
self.pass_manager = resolve_obj_by_qualname(
current_platform.get_pass_manager_cls()
)()
self.pass_key = current_platform.pass_key
self.sym_tensor_indices = []
self.input_buffers = []
@ -562,24 +566,20 @@ class VllmBackend:
def configure_post_pass(self):
config = self.compilation_config
self.post_grad_pass_manager.configure(self.vllm_config)
self.pass_manager.configure(self.vllm_config)
# Post-grad custom passes are run using the post_grad_custom_post_pass
# hook. If a pass for that hook exists, add it to the pass manager.
inductor_config = config.inductor_compile_config
PASS_KEY = "post_grad_custom_post_pass"
if PASS_KEY in inductor_config:
if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
if self.pass_key in inductor_config:
if isinstance(inductor_config[self.pass_key], PostGradPassManager):
# PassManager already added to config, make sure it's correct
assert (
inductor_config[PASS_KEY].uuid()
== self.post_grad_pass_manager.uuid()
)
assert inductor_config[self.pass_key].uuid() == self.pass_manager.uuid()
else:
# Config should automatically wrap all inductor passes
assert isinstance(inductor_config[PASS_KEY], InductorPass)
self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
inductor_config[PASS_KEY] = self.post_grad_pass_manager
assert isinstance(inductor_config[self.pass_key], InductorPass)
self.pass_manager.add(inductor_config[self.pass_key])
inductor_config[self.pass_key] = self.pass_manager
def __call__(
self, graph: fx.GraphModule, example_inputs

View File

@ -331,9 +331,9 @@ class CompilationConfig:
We use string to avoid serialization issues when using compilation in a
distributed setting. When the compilation mode is 1 or 2, the backend is
used for the compilation directly (it sees the whole graph). When the
compilation mode is 3, the backend is used for the piecewise compilation
(it sees a part of the graph). The backend can not be custom for compilation
mode 3, i.e. the backend must be either eager or inductor. Furthermore,
compilation mode is 3, the backend supports both whole graph and piecewise
compilation, available backends include eager, inductor, and custom backends,
the latter of which can be defined via `get_compile_backend`. Furthermore,
compilation is only piecewise if splitting ops is set accordingly and
use_inductor_graph_partition is off. Note that the default options for
splitting ops are sufficient for piecewise compilation.
@ -768,7 +768,7 @@ class CompilationConfig:
self.backend = "inductor" if self.use_inductor else "eager"
if self.backend == "":
self.backend = current_platform.simple_compile_backend
self.backend = current_platform.get_compile_backend()
def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:
"""
@ -800,9 +800,7 @@ class CompilationConfig:
assert self.mode == CompilationMode.VLLM_COMPILE
if self.backend not in ["eager", "inductor"]:
raise ValueError(
f"Invalid backend for piecewise compilation: {self.backend}"
)
logger.info("Using OOT custom backend for compilation.")
from vllm.compilation.backends import VllmBackend

View File

@ -134,6 +134,11 @@ class Platform:
_global_graph_pool: Any | None = None
@property
def pass_key(self) -> str:
"""Inductor config key for the PassManager custom pass"""
return "post_grad_custom_post_pass"
@property
def supported_dtypes(self) -> list[torch.dtype]:
"""Returns the supported dtypes for the current platform."""
@ -177,6 +182,21 @@ class Platform:
# all ROCm platforms for now.
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
@classmethod
def get_pass_manager_cls(cls) -> str:
"""
Get the pass manager class for this platform.
It will be registered as a custom pass under the current_platform.pass_key.
"""
return "vllm.compilation.pass_manager.PostGradPassManager"
@classmethod
def get_compile_backend(cls) -> str:
"""
Get the custom compile backend for current platform.
"""
return cls.simple_compile_backend
@classmethod
def device_id_to_physical_device_id(cls, device_id: int):
# Treat empty device control env var as unset. This is a valid