[Frontend] CompilationConfig overhaul (#20283): deprecate use_inductor in favor of backend, simplify custom_ops (#26113)

Signed-off-by: morrison-turnansky <mturnans@redhat.com>
Signed-off-by: Morrison Turnansky <mturnans@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
This commit is contained in:
Morrison Turnansky 2025-10-07 15:53:43 -04:00 committed by GitHub
parent eb577e4655
commit 0c824fc46f
7 changed files with 126 additions and 63 deletions

View File

@@ -258,13 +258,13 @@ def tractable_computation(
@torch.inference_mode
def run_model(
llama_config, use_compile: bool, use_inductor: bool, split_attn: bool = False
llama_config, use_compile: bool, backend: str, split_attn: bool = False
) -> torch.Tensor:
if use_compile:
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
use_inductor=use_inductor,
backend=backend,
cudagraph_capture_sizes=[1, 2],
)
if split_attn:
@@ -338,8 +338,8 @@ def run_model(
return output.cpu()
@pytest.mark.parametrize("use_inductor", [True, False])
def test_toy_llama(use_inductor: bool):
@pytest.mark.parametrize("backend", ["inductor", "eager"])
def test_toy_llama(backend: str):
# compare output with and without piecewise compilation
llama_config = LlamaConfig(
@@ -358,10 +358,10 @@ def test_toy_llama(use_inductor: bool):
num_backend_compilations=0,
num_cudagraph_captured=0,
):
outputs.append(run_model(llama_config, use_inductor=False, use_compile=False))
run_model(tractable_config, use_inductor=False, use_compile=False)
outputs.append(run_model(llama_config, backend="eager", use_compile=False))
run_model(tractable_config, backend="eager", use_compile=False)
if use_inductor:
if backend == "inductor":
kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
else:
kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
@@ -377,10 +377,8 @@ def test_toy_llama(use_inductor: bool):
num_cudagraph_captured=2,
**kwargs,
):
outputs.append(
run_model(llama_config, use_inductor=use_inductor, use_compile=True)
)
run_model(tractable_config, use_inductor=use_inductor, use_compile=True)
outputs.append(run_model(llama_config, backend=backend, use_compile=True))
run_model(tractable_config, backend=backend, use_compile=True)
with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model
@@ -395,16 +393,9 @@ def test_toy_llama(use_inductor: bool):
), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
outputs.append(
run_model(
llama_config,
use_inductor=use_inductor,
use_compile=True,
split_attn=True,
)
run_model(llama_config, backend=backend, use_compile=True, split_attn=True)
)
run_model(
tractable_config, use_inductor=use_inductor, use_compile=True, split_attn=True
)
run_model(tractable_config, backend=backend, use_compile=True, split_attn=True)
for i in range(1, len(outputs)):
assert torch.allclose(outputs[0], outputs[i])
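
For context, a minimal sketch (not part of the diff) of how the piecewise config used by run_model above is now built: backend="inductor" takes the place of use_inductor=True, and backend="eager" the place of use_inductor=False. The field names come from the hunk above; everything else is illustrative.

from vllm.config import CompilationConfig, CompilationLevel

def make_piecewise_config(backend: str) -> CompilationConfig:
    # Mirrors the updated run_model() above; backend is "inductor" or "eager".
    return CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_cudagraph=True,
        backend=backend,
        cudagraph_capture_sizes=[1, 2],
    )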

View File

@@ -37,57 +37,59 @@ class Relu3(ReLUSquaredActivation):
@pytest.mark.parametrize(
"env, torch_level, use_inductor, ops_enabled, default_on",
"env, torch_level, backend, ops_enabled, default_on",
[
# Default values based on compile level
# - All by default (no Inductor compilation)
(None, 0, False, [True] * 4, True),
(None, 1, True, [True] * 4, True),
(None, 2, False, [True] * 4, True),
(None, 0, "eager", [True] * 4, True),
(None, 1, "eager", [True] * 4, True),
(None, 2, "eager", [True] * 4, True),
(None, 3, "eager", [True] * 4, True),
# - None by default (with Inductor)
(None, 3, True, [False] * 4, False),
(None, 4, True, [False] * 4, False),
# - All by default (without Inductor)
(None, 3, False, [True] * 4, True),
(None, 4, False, [True] * 4, True),
(None, 0, "inductor", [True] * 4, True),
# - None by default (with Inductor)
(None, 1, "inductor", [False] * 4, False),
(None, 2, "inductor", [False] * 4, False),
(None, 3, "inductor", [False] * 4, False),
# Explicitly enabling/disabling
#
# Default: all
#
# All but SiluAndMul
("+rms_norm,-silu_and_mul", 0, True, [1, 0, 1, 1], True),
("+rms_norm,-silu_and_mul", 0, "inductor", [1, 0, 1, 1], True),
# Only ReLU3
("none,-rms_norm,+relu3", 1, False, [0, 0, 0, 1], False),
("none,-rms_norm,+relu3", 1, "eager", [0, 0, 0, 1], False),
# All but SiluAndMul
("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True),
("all,-silu_and_mul", 2, "inductor", [1, 0, 1, 1], True),
# All but ReLU3 (even if ReLU2 is on)
("-relu3,+relu2", 3, False, [1, 1, 1, 0], True),
("-relu3,+relu2", 3, "eager", [1, 1, 1, 0], True),
# RMSNorm and SiluAndMul
("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False),
("none,-relu3,+rms_norm,+silu_and_mul", 3, "eager", [1, 1, 0, 0], False),
# All but RMSNorm
("-rms_norm", 3, False, [0, 1, 1, 1], True),
("-rms_norm", 3, "eager", [0, 1, 1, 1], True),
#
# Default: none
#
# Only ReLU3
("-silu_and_mul,+relu3", 3, True, [0, 0, 0, 1], False),
("none,+relu3", 3, "inductor", [0, 0, 0, 1], False),
# All but RMSNorm
("all,-rms_norm", 4, True, [0, 1, 1, 1], True),
("all,-rms_norm", 3, "inductor", [0, 1, 1, 1], True),
],
)
def test_enabled_ops(
env: Optional[str],
torch_level: int,
use_inductor: bool,
backend: str,
ops_enabled: list[int],
default_on: bool,
):
custom_ops = env.split(",") if env else []
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
use_inductor=bool(use_inductor), level=torch_level, custom_ops=custom_ops
backend=backend, level=torch_level, custom_ops=custom_ops
)
)
with set_current_vllm_config(vllm_config):
assert CustomOp.default_on() == default_on
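
As a usage sketch (not part of the diff), here is one row of the parametrized table above spelled out: with an Inductor backend, custom ops default to off, so only the explicitly enabled ops remain on. Import paths follow the test; the expected result is noted as a comment.

from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config

vllm_config = VllmConfig(
    compilation_config=CompilationConfig(
        backend="inductor", level=3, custom_ops=["none", "+relu3"]
    )
)
with set_current_vllm_config(vllm_config):
    # Expected: only ReLU3 is enabled and CustomOp.default_on() returns False.
    pass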

View File

@@ -34,7 +34,7 @@ logger = init_logger(__name__)
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
if compilation_config.use_inductor:
if compilation_config.backend == "inductor":
# Use standalone compile only if requested, version is new enough,
# and the symbol actually exists in this PyTorch build.
if (
@@ -48,6 +48,10 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
logger.debug("Using InductorAdaptor")
return InductorAdaptor()
else:
assert compilation_config.backend == "eager", (
"Custom backends not supported with CompilationLevel.PIECEWISE"
)
logger.debug("Using EagerAdaptor")
return EagerAdaptor()
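
A condensed sketch of the dispatch above (not part of the diff; the adaptor classes here are stand-ins, not vLLM's): the piecewise path only accepts "inductor" or "eager", and anything else trips the assertion.

class InductorAdaptor:  # placeholder stand-in
    pass

class EagerAdaptor:  # placeholder stand-in
    pass

def make_compiler_sketch(backend: str):
    if backend == "inductor":
        return InductorAdaptor()
    assert backend == "eager", (
        "Custom backends not supported with CompilationLevel.PIECEWISE"
    )
    return EagerAdaptor()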

View File

@@ -180,10 +180,11 @@ class CompilationConfig:
"""The directory to store the compiled graph, to accelerate Inductor
compilation. By default, it will use model-related information to generate
a cache directory."""
backend: str = ""
backend: str = "inductor"
"""The backend for compilation. It needs to be a string:
- "" (empty string): use the default backend.
- "" (empty string): use the default backend ("inductor" on CUDA-alike
platforms).
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
- "full.module.name": a qualified name which can be used to import the
@@ -192,7 +193,11 @@ class CompilationConfig:
distributed setting. When the compilation level is 1 or 2, the backend is
used for the compilation directly (it sees the whole graph). When the
compilation level is 3, the backend is used for the piecewise compilation
(it sees a part of the graph)."""
(it sees a part of the graph). The backend cannot be a custom backend for
compilation level 3. Furthermore, compilation is only piecewise if splitting
ops is set accordingly and use_inductor_cudagraphs_partition is off. Note that
the default splitting ops are sufficient for piecewise compilation.
"""
custom_ops: list[str] = field(default_factory=list)
"""Fine-grained control over which custom ops to enable/disable. Use 'all'
to enable all, 'none' to disable all. Also specify a list of custom op
@@ -210,8 +215,12 @@ class CompilationConfig:
compilation."""
# Inductor capture
use_inductor: bool = True
"""Whether to use inductor compilation:
use_inductor: Optional[bool] = None
"""
Whether to use inductor compilation.
This flag is deprecated and will be removed.
Please use the 'backend' option instead.
- False: inductor compilation is not used. graph runs in eager
(custom_ops enabled by default).
@@ -219,7 +228,11 @@ class CompilationConfig:
One graph for symbolic shape and one graph per size in compile_sizes
are compiled using configurations in inductor_compile_config.
This setting is ignored if level<PIECEWISE."""
This setting is ignored if level<PIECEWISE.
For backwards compatibility: if use_inductor is True, backend is set to
"inductor", otherwise backend is set to "eager".
"""
compile_sizes: Optional[list[Union[int, str]]] = None
"""Sizes to compile for inductor. In addition
to integers, it also supports "cudagraph_capture_sizes" to
@@ -523,7 +536,43 @@ class CompilationConfig:
"(where 'op' is the registered op name)"
)
# Currently only the eager and inductor backends are supported
# for piecewise compilation. Custom backends are not supported;
# update this check when more backends are supported.
if self.level == CompilationLevel.PIECEWISE and self.backend not in [
"",
"eager",
"inductor",
]:
raise ValueError(
f"Invalid backend for piecewise compilation: {self.backend}"
)
if self.use_inductor is not None:
logger.warning_once(
"The 'use_inductor' flag is deprecated and will be\
removed in a future release."
"Please use the 'backend' option instead.",
)
self.backend = "inductor" if self.use_inductor else "eager"
if self.backend == "":
self.backend = "inductor"
def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
"""
Initialize the backend for the compilation config from a vllm config.
Arguments:
vllm_config: The vllm config to initialize the backend from.
Returns:
The backend for the compilation config.
"""
if self.level is None:
raise ValueError(
"No compilation level is set. This method should only be \
called via vllm config where the level is set if none is \
provided."
)
if self.level == CompilationLevel.NO_COMPILATION:
raise ValueError("No compilation level is set.")
@@ -531,15 +580,15 @@
torch_backends = list_backends(exclude_tags=tuple())
if self.level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
if self.backend == "":
return "eager"
if self.backend in torch_backends:
return self.backend
return resolve_obj_by_qualname(self.backend)
# TODO: pass user-specified backend to piecewise compilation
# merge with the config use_inductor
assert self.level == CompilationLevel.PIECEWISE
if self.backend not in ["eager", "inductor"]:
raise ValueError(
f"Invalid backend for piecewise compilation: {self.backend}"
)
from vllm.compilation.backends import VllmBackend
@@ -692,7 +741,7 @@ class CompilationConfig:
)
inductor_used = (
self.level == CompilationLevel.PIECEWISE and self.use_inductor
self.level == CompilationLevel.PIECEWISE and self.backend == "inductor"
) or (
self.level >= CompilationLevel.DYNAMO_AS_IS and self.backend == "inductor"
)
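
Taken together, the config changes above amount to the following migration sketch (not part of the diff); the expected values follow the __post_init__ logic in this hunk and are noted as comments.

from vllm.config import CompilationConfig

# Deprecated spelling: still accepted, emits a deprecation warning, and is
# mapped onto the new field ("inductor" if True, "eager" if False).
old_style = CompilationConfig(use_inductor=False)
print(old_style.backend)  # expected: "eager"

# Preferred spelling going forward; an empty backend falls back to "inductor".
new_style = CompilationConfig(backend="eager")
default_style = CompilationConfig()
print(default_style.backend)  # expected: "inductor"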

View File

@@ -318,6 +318,25 @@ class VllmConfig:
# NB: Passing both --enforce-eager and a compilation level
# in V0 means the compilation level wins out.
self.compilation_config.level = CompilationLevel.NO_COMPILATION
else:
assert self.compilation_config.level >= CompilationLevel.NO_COMPILATION
assert self.compilation_config.level <= CompilationLevel.PIECEWISE
assert self.compilation_config.level <= 3
# If the user does not set custom ops via "none" or "all", set the
# default here based on the compilation level and backend.
if (
self.compilation_config.custom_ops.count("none")
+ self.compilation_config.custom_ops.count("all")
== 0
):
if (
self.compilation_config.level > 0
and self.compilation_config.backend != "eager"
):
self.compilation_config.custom_ops.append("none")
else:
self.compilation_config.custom_ops.append("all")
# async tp is built on top of sequence parallelism
# and requires it to be enabled.
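
A standalone sketch of the custom_ops defaulting rule added above (default_custom_ops is a hypothetical helper, not part of the diff): when neither "none" nor "all" is given, ops default to "none" for a compiled non-eager backend and to "all" otherwise.

def default_custom_ops(level: int, backend: str, custom_ops: list[str]) -> list[str]:
    # Mirrors the logic added to VllmConfig above.
    if custom_ops.count("none") + custom_ops.count("all") == 0:
        if level > 0 and backend != "eager":
            custom_ops.append("none")
        else:
            custom_ops.append("all")
    return custom_ops

assert default_custom_ops(3, "inductor", []) == ["none"]
assert default_custom_ops(0, "inductor", []) == ["all"]
assert default_custom_ops(3, "eager", ["+rms_norm"]) == ["+rms_norm", "all"]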

View File

@@ -114,7 +114,9 @@ class CustomOp(nn.Module):
custom_ops = compilation_config.custom_ops
if not hasattr(cls, "name"):
logger.warning_once(
"Custom op %s was not registered, which means it won't appear in the op registry. It will be enabled/disabled based on the global settings.", # noqa: E501
"Custom op %s was not registered, which means it won't appear\
in the op registry. It will be enabled/disabled based on the\
global settings.", # noqa: E501
cls.__name__,
)
return CustomOp.default_on()
@@ -128,19 +130,17 @@
@staticmethod
def default_on() -> bool:
"""
On by default if PyTorch Inductor is not used.
Specifying 'all' or 'none' in custom_op takes precedence.
Behavior controlled by `CompilationConfig.custom_ops`: On by default if
'all', off by default if 'none'.
When PyTorch Inductor is used, 'none' is the default value,
otherwise 'all'.
"""
from vllm.config import CompilationLevel
compilation_config = get_cached_compilation_config()
default_on = (
compilation_config.level < CompilationLevel.PIECEWISE
or not compilation_config.use_inductor
)
count_none = compilation_config.custom_ops.count("none")
count_all = compilation_config.custom_ops.count("all")
return default_on and not count_none > 0 or count_all > 0
assert count_none + count_all == 1
return not count_none > 0 or count_all > 0
# Dictionary of all custom ops (classes, indexed by registered name).
# To check if an op with a name is enabled, call .enabled() on the class.
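
Finally, a sketch of the simplified default_on() above (default_on_sketch is illustrative, not the vLLM method): once VllmConfig has normalized custom_ops to contain exactly one of "none"/"all", the decision reduces to which sentinel is present.

def default_on_sketch(custom_ops: list[str]) -> bool:
    count_none = custom_ops.count("none")
    count_all = custom_ops.count("all")
    assert count_none + count_all == 1
    return not count_none > 0 or count_all > 0

assert default_on_sketch(["all"]) is True
assert default_on_sketch(["none", "+relu3"]) is False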

View File

@@ -274,8 +274,6 @@ class CpuPlatform(Platform):
"epilogue_fusion": True,
}
)
if compilation_config.use_inductor:
compilation_config.custom_ops = ["none"]
if vllm_config.lora_config is not None:
compilation_config.level = CompilationLevel.NO_COMPILATION