mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 02:55:40 +08:00
Allow oot custom compiler extension via CompilerInterface (#28623)
Signed-off-by: wxsIcey <1790571317@qq.com> Signed-off-by: Mengqing Cao <cmq0113@163.com> Signed-off-by: Icey <1790571317@qq.com> Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
parent
fe3a4f5b34
commit
888152bf87
@ -63,13 +63,14 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
|
||||
else:
|
||||
logger.debug("Using InductorAdaptor")
|
||||
return InductorAdaptor()
|
||||
else:
|
||||
assert compilation_config.backend == "eager", (
|
||||
"Custom backends not supported with CompilationMode.VLLM_COMPILE"
|
||||
)
|
||||
|
||||
elif compilation_config.backend == "eager":
|
||||
logger.debug("Using EagerAdaptor")
|
||||
return EagerAdaptor()
|
||||
else:
|
||||
logger.debug("Using custom backend: %s", compilation_config.backend)
|
||||
compiler = resolve_obj_by_qualname(current_platform.get_compile_backend())()
|
||||
assert isinstance(compiler, CompilerInterface)
|
||||
return compiler
|
||||
|
||||
|
||||
class CompilerManager:
|
||||
@ -545,7 +546,10 @@ class VllmBackend:
|
||||
self.prefix = prefix or model_tag
|
||||
|
||||
# Passes to run on the graph post-grad.
|
||||
self.post_grad_pass_manager = PostGradPassManager()
|
||||
self.pass_manager = resolve_obj_by_qualname(
|
||||
current_platform.get_pass_manager_cls()
|
||||
)()
|
||||
self.pass_key = current_platform.pass_key
|
||||
|
||||
self.sym_tensor_indices = []
|
||||
self.input_buffers = []
|
||||
@ -562,24 +566,20 @@ class VllmBackend:
|
||||
|
||||
def configure_post_pass(self):
|
||||
config = self.compilation_config
|
||||
self.post_grad_pass_manager.configure(self.vllm_config)
|
||||
self.pass_manager.configure(self.vllm_config)
|
||||
|
||||
# Post-grad custom passes are run using the post_grad_custom_post_pass
|
||||
# hook. If a pass for that hook exists, add it to the pass manager.
|
||||
inductor_config = config.inductor_compile_config
|
||||
PASS_KEY = "post_grad_custom_post_pass"
|
||||
if PASS_KEY in inductor_config:
|
||||
if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
|
||||
if self.pass_key in inductor_config:
|
||||
if isinstance(inductor_config[self.pass_key], PostGradPassManager):
|
||||
# PassManager already added to config, make sure it's correct
|
||||
assert (
|
||||
inductor_config[PASS_KEY].uuid()
|
||||
== self.post_grad_pass_manager.uuid()
|
||||
)
|
||||
assert inductor_config[self.pass_key].uuid() == self.pass_manager.uuid()
|
||||
else:
|
||||
# Config should automatically wrap all inductor passes
|
||||
assert isinstance(inductor_config[PASS_KEY], InductorPass)
|
||||
self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
|
||||
inductor_config[PASS_KEY] = self.post_grad_pass_manager
|
||||
assert isinstance(inductor_config[self.pass_key], InductorPass)
|
||||
self.pass_manager.add(inductor_config[self.pass_key])
|
||||
inductor_config[self.pass_key] = self.pass_manager
|
||||
|
||||
def __call__(
|
||||
self, graph: fx.GraphModule, example_inputs
|
||||
|
||||
@ -331,9 +331,9 @@ class CompilationConfig:
|
||||
We use string to avoid serialization issues when using compilation in a
|
||||
distributed setting. When the compilation mode is 1 or 2, the backend is
|
||||
used for the compilation directly (it sees the whole graph). When the
|
||||
compilation mode is 3, the backend is used for the piecewise compilation
|
||||
(it sees a part of the graph). The backend can not be custom for compilation
|
||||
mode 3, i.e. the backend must be either eager or inductor. Furthermore,
|
||||
compilation mode is 3, the backend supports both whole graph and piecewise
|
||||
compilation, available backends include eager, inductor, and custom backends,
|
||||
the latter of which can be defined via `get_compile_backend`. Furthermore,
|
||||
compilation is only piecewise if splitting ops is set accordingly and
|
||||
use_inductor_graph_partition is off. Note that the default options for
|
||||
splitting ops are sufficient for piecewise compilation.
|
||||
@ -768,7 +768,7 @@ class CompilationConfig:
|
||||
self.backend = "inductor" if self.use_inductor else "eager"
|
||||
|
||||
if self.backend == "":
|
||||
self.backend = current_platform.simple_compile_backend
|
||||
self.backend = current_platform.get_compile_backend()
|
||||
|
||||
def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:
|
||||
"""
|
||||
@ -800,9 +800,7 @@ class CompilationConfig:
|
||||
|
||||
assert self.mode == CompilationMode.VLLM_COMPILE
|
||||
if self.backend not in ["eager", "inductor"]:
|
||||
raise ValueError(
|
||||
f"Invalid backend for piecewise compilation: {self.backend}"
|
||||
)
|
||||
logger.info("Using OOT custom backend for compilation.")
|
||||
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
|
||||
|
||||
@ -134,6 +134,11 @@ class Platform:
|
||||
|
||||
_global_graph_pool: Any | None = None
|
||||
|
||||
@property
|
||||
def pass_key(self) -> str:
|
||||
"""Inductor config key for the PassManager custom pass"""
|
||||
return "post_grad_custom_post_pass"
|
||||
|
||||
@property
|
||||
def supported_dtypes(self) -> list[torch.dtype]:
|
||||
"""Returns the supported dtypes for the current platform."""
|
||||
@ -177,6 +182,21 @@ class Platform:
|
||||
# all ROCm platforms for now.
|
||||
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
|
||||
|
||||
@classmethod
|
||||
def get_pass_manager_cls(cls) -> str:
|
||||
"""
|
||||
Get the pass manager class for this platform.
|
||||
It will be registered as a custom pass under the current_platform.pass_key.
|
||||
"""
|
||||
return "vllm.compilation.pass_manager.PostGradPassManager"
|
||||
|
||||
@classmethod
|
||||
def get_compile_backend(cls) -> str:
|
||||
"""
|
||||
Get the custom compile backend for current platform.
|
||||
"""
|
||||
return cls.simple_compile_backend
|
||||
|
||||
@classmethod
|
||||
def device_id_to_physical_device_id(cls, device_id: int):
|
||||
# Treat empty device control env var as unset. This is a valid
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user