Revert #26113 "[Frontend] CompilationConfig overhaul (#20283): deprecate use_inductor in favor of backend, simplify custom_ops" (#26472)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Authored by Jiangyun Zhu on 2025-10-09 20:43:55 +08:00, committed by GitHub
parent 92be3f3517
commit 5728da11ea
7 changed files with 63 additions and 126 deletions

View File

@@ -258,13 +258,13 @@ def tractable_computation(
@torch.inference_mode
def run_model(
llama_config, use_compile: bool, backend: str, split_attn: bool = False
llama_config, use_compile: bool, use_inductor: bool, split_attn: bool = False
) -> torch.Tensor:
if use_compile:
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
backend=backend,
use_inductor=use_inductor,
cudagraph_capture_sizes=[1, 2],
)
if split_attn:
@@ -338,8 +338,8 @@ def run_model(
return output.cpu()
@pytest.mark.parametrize("backend", ["inductor", "eager"])
def test_toy_llama(backend: str):
@pytest.mark.parametrize("use_inductor", [True, False])
def test_toy_llama(use_inductor: bool):
# compare output with and without piecewise compilation
llama_config = LlamaConfig(
@@ -358,10 +358,10 @@ def test_toy_llama(backend: str):
num_backend_compilations=0,
num_cudagraph_captured=0,
):
outputs.append(run_model(llama_config, backend="eager", use_compile=False))
run_model(tractable_config, backend="eager", use_compile=False)
outputs.append(run_model(llama_config, use_inductor=False, use_compile=False))
run_model(tractable_config, use_inductor=False, use_compile=False)
if backend == "inductor":
if use_inductor:
kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
else:
kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
@@ -377,8 +377,10 @@ def test_toy_llama(backend: str):
num_cudagraph_captured=2,
**kwargs,
):
outputs.append(run_model(llama_config, backend=backend, use_compile=True))
run_model(tractable_config, backend=backend, use_compile=True)
outputs.append(
run_model(llama_config, use_inductor=use_inductor, use_compile=True)
)
run_model(tractable_config, use_inductor=use_inductor, use_compile=True)
with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model
@@ -393,9 +395,16 @@ def test_toy_llama(backend: str):
), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
outputs.append(
run_model(llama_config, backend=backend, use_compile=True, split_attn=True)
run_model(
llama_config,
use_inductor=use_inductor,
use_compile=True,
split_attn=True,
)
)
run_model(tractable_config, backend=backend, use_compile=True, split_attn=True)
run_model(
tractable_config, use_inductor=use_inductor, use_compile=True, split_attn=True
)
for i in range(1, len(outputs)):
assert torch.allclose(outputs[0], outputs[i])
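
Because the hunks above interleave the removed (backend-based) and restored (use_inductor-based) lines without +/- markers, here is a minimal sketch of how the restored test helper builds its compilation config after the revert, assuming the standard vLLM import path; the field names and values are taken directly from the hunk:

    from vllm.config import CompilationConfig, CompilationLevel

    def make_piecewise_config(use_inductor: bool) -> CompilationConfig:
        # The revert swaps the `backend: str` parameter back to a plain boolean:
        # True compiles the piecewise graphs with Inductor, False runs them eagerly.
        return CompilationConfig(
            level=CompilationLevel.PIECEWISE,
            use_cudagraph=True,
            use_inductor=use_inductor,
            cudagraph_capture_sizes=[1, 2],
        )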

View File

@@ -37,59 +37,57 @@ class Relu3(ReLUSquaredActivation):
@pytest.mark.parametrize(
"env, torch_level, backend, ops_enabled, default_on",
"env, torch_level, use_inductor, ops_enabled, default_on",
[
# Default values based on compile level
# - All by default (no Inductor compilation)
(None, 0, "eager", [True] * 4, True),
(None, 1, "eager", [True] * 4, True),
(None, 2, "eager", [True] * 4, True),
(None, 3, "eager", [True] * 4, True),
(None, 0, False, [True] * 4, True),
(None, 1, True, [True] * 4, True),
(None, 2, False, [True] * 4, True),
# - None by default (with Inductor)
(None, 0, "inductor", [True] * 4, True),
# - None by default (with Inductor)
(None, 1, "inductor", [False] * 4, False),
(None, 2, "inductor", [False] * 4, False),
(None, 3, "inductor", [False] * 4, False),
(None, 3, True, [False] * 4, False),
(None, 4, True, [False] * 4, False),
# - All by default (without Inductor)
(None, 3, False, [True] * 4, True),
(None, 4, False, [True] * 4, True),
# Explicitly enabling/disabling
#
# Default: all
#
# All but SiluAndMul
("+rms_norm,-silu_and_mul", 0, "inductor", [1, 0, 1, 1], True),
("+rms_norm,-silu_and_mul", 0, True, [1, 0, 1, 1], True),
# Only ReLU3
("none,-rms_norm,+relu3", 1, "eager", [0, 0, 0, 1], False),
("none,-rms_norm,+relu3", 1, False, [0, 0, 0, 1], False),
# All but SiluAndMul
("all,-silu_and_mul", 2, "inductor", [1, 0, 1, 1], True),
("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True),
# All but ReLU3 (even if ReLU2 is on)
("-relu3,+relu2", 3, "eager", [1, 1, 1, 0], True),
("-relu3,+relu2", 3, False, [1, 1, 1, 0], True),
# RMSNorm and SiluAndMul
("none,-relu3,+rms_norm,+silu_and_mul", 3, "eager", [1, 1, 0, 0], False),
("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False),
# All but RMSNorm
("-rms_norm", 3, "eager", [0, 1, 1, 1], True),
("-rms_norm", 3, False, [0, 1, 1, 1], True),
#
# Default: none
#
# Only ReLU3
("none,+relu3", 3, "inductor", [0, 0, 0, 1], False),
("-silu_and_mul,+relu3", 3, True, [0, 0, 0, 1], False),
# All but RMSNorm
("all,-rms_norm", 3, "inductor", [0, 1, 1, 1], True),
("all,-rms_norm", 4, True, [0, 1, 1, 1], True),
],
)
def test_enabled_ops(
env: Optional[str],
torch_level: int,
backend: str,
use_inductor: bool,
ops_enabled: list[int],
default_on: bool,
):
custom_ops = env.split(",") if env else []
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
backend=backend, level=torch_level, custom_ops=custom_ops
use_inductor=bool(use_inductor), level=torch_level, custom_ops=custom_ops
)
)
# breakpoint()
with set_current_vllm_config(vllm_config):
assert CustomOp.default_on() == default_on
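
The `env` strings in the parametrization use the same override grammar before and after the revert. An illustrative sketch of that grammar (not vLLM's implementation; `default_on` stands for the level/use_inductor-derived default restored in the custom_op.py hunk further below):

    def op_enabled(custom_ops: list[str], op_name: str, default_on: bool) -> bool:
        # Per-op overrides ('+name' / '-name') win over the blanket
        # 'all' / 'none' entries, which in turn win over the default.
        if "+" + op_name in custom_ops:
            return True
        if "-" + op_name in custom_ops:
            return False
        if "all" in custom_ops:
            return True
        if "none" in custom_ops:
            return False
        return default_on

    assert op_enabled(["+rms_norm", "-silu_and_mul"], "rms_norm", default_on=True)
    assert not op_enabled(["none", "+relu3"], "rms_norm", default_on=True)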

View File

@@ -34,7 +34,7 @@ logger = init_logger(__name__)
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
if compilation_config.backend == "inductor":
if compilation_config.use_inductor:
# Use standalone compile only if requested, version is new enough,
# and the symbol actually exists in this PyTorch build.
if (
@@ -48,10 +48,6 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
logger.debug("Using InductorAdaptor")
return InductorAdaptor()
else:
assert compilation_config.backend == "eager", (
"Custom backends not supported with CompilationLevel.PIECEWISE"
)
logger.debug("Using EagerAdaptor")
return EagerAdaptor()
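
Read together, the two hunks above restore a purely boolean dispatch and drop the assertion that rejected custom backends for piecewise compilation. A condensed, self-contained restatement of the restored control flow (it returns adaptor names rather than the real adaptor objects; `standalone_ok` stands in for the version/symbol check elided from the hunk, and the standalone adaptor name is assumed from the "standalone compile" comment):

    def pick_compiler(use_inductor: bool, standalone_ok: bool) -> str:
        if use_inductor:
            # Standalone compile only if requested and supported by this torch build.
            return "InductorStandaloneAdaptor" if standalone_ok else "InductorAdaptor"
        # No `backend == "eager"` assertion any more: everything non-Inductor is eager.
        return "EagerAdaptor"

    assert pick_compiler(use_inductor=True, standalone_ok=False) == "InductorAdaptor"
    assert pick_compiler(use_inductor=False, standalone_ok=True) == "EagerAdaptor"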

View File

@@ -180,11 +180,10 @@ class CompilationConfig:
"""The directory to store the compiled graph, to accelerate Inductor
compilation. By default, it will use model-related information to generate
a cache directory."""
backend: str = "inductor"
backend: str = ""
"""The backend for compilation. It needs to be a string:
- "" (empty string): use the default backend ("inductor" on CUDA-alike
platforms).
- "" (empty string): use the default backend.
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
- "full.module.name": a qualified name which can be used to import the
@@ -193,11 +192,7 @@ class CompilationConfig:
distributed setting. When the compilation level is 1 or 2, the backend is
used for the compilation directly (it sees the whole graph). When the
compilation level is 3, the backend is used for the piecewise compilation
(it sees a part of the graph). The backend can not be custom for compilation
level 3. Furthermore, compilation is only piecewise if splitting ops is set
accordingly and use_inductor_cudagraphs_partition is off. Note that the
default options for splitting ops are sufficient for piecewise compilation.
"""
(it sees a part of the graph)."""
custom_ops: list[str] = field(default_factory=list)
"""Fine-grained control over which custom ops to enable/disable. Use 'all'
to enable all, 'none' to disable all. Also specify a list of custom op
@@ -215,12 +210,8 @@ class CompilationConfig:
compilation."""
# Inductor capture
use_inductor: Optional[bool] = None
"""
Whether to use inductor compilation.
This flag is deprecated and will be removed.
Please use the 'backend' option instead.
use_inductor: bool = True
"""Whether to use inductor compilation:
- False: inductor compilation is not used. graph runs in eager
(custom_ops enabled by default).
@@ -228,11 +219,7 @@ class CompilationConfig:
One graph for symbolic shape and one graph per size in compile_sizes
are compiled using configurations in inductor_compile_config.
This setting is ignored if level<PIECEWISE.
For future compatibility:
If use_inductor is True, backend="inductor" otherwise backend="eager".
"""
This setting is ignored if level<PIECEWISE."""
compile_sizes: Optional[list[Union[int, str]]] = None
"""Sizes to compile for inductor. In addition
to integers, it also supports "cudagraph_capture_sizes" to
@@ -538,43 +525,7 @@ class CompilationConfig:
"(where 'op' is the registered op name)"
)
# Currently only eager and inductor backend are supported.
# for piecewise compilation. Custom backends are not suppported for
# piecewise compilation. Update when more backends are supported.
if self.level == CompilationLevel.PIECEWISE and self.backend not in [
"",
"eager",
"inductor",
]:
raise ValueError(
f"Invalid backend for piecewise compilation: {self.backend}"
)
if self.use_inductor is not None:
logger.warning_once(
"The 'use_inductor' flag is deprecated and will be\
removed in a future release."
"Please use the 'backend' option instead.",
)
self.backend = "inductor" if self.use_inductor else "eager"
if self.backend == "":
self.backend = "inductor"
def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
"""
Initialize the backend for the compilation config from a vllm config.
Arguments:
vllm_config: The vllm config to initialize the backend from.
Returns:
The backend for the compilation config.
"""
if self.level is None:
raise ValueError(
"No compilation level is set. This method should only be \
called via vllm config where the level is set if none is \
provided."
)
if self.level == CompilationLevel.NO_COMPILATION:
raise ValueError("No compilation level is set.")
@@ -582,15 +533,15 @@ class CompilationConfig:
torch_backends = list_backends(exclude_tags=tuple())
if self.level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
if self.backend == "":
return "eager"
if self.backend in torch_backends:
return self.backend
return resolve_obj_by_qualname(self.backend)
# TODO: pass user-specified backend to piecewise compilation
# merge with the config use_inductor
assert self.level == CompilationLevel.PIECEWISE
if self.backend not in ["eager", "inductor"]:
raise ValueError(
f"Invalid backend for piecewise compilation: {self.backend}"
)
from vllm.compilation.backends import VllmBackend
@@ -743,7 +694,7 @@ class CompilationConfig:
)
inductor_used = (
self.level == CompilationLevel.PIECEWISE and self.backend == "inductor"
self.level == CompilationLevel.PIECEWISE and self.use_inductor
) or (
self.level >= CompilationLevel.DYNAMO_AS_IS and self.backend == "inductor"
)
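
Net effect on the config surface: `backend` goes back to an empty-string default and only steers torch.compile at the DYNAMO_AS_IS/DYNAMO_ONCE levels, while the restored boolean `use_inductor` (default True) decides Inductor vs. eager for PIECEWISE compilation. A minimal sketch of the two restored knobs, assuming the standard vLLM import path:

    from vllm.config import CompilationConfig, CompilationLevel

    # Level 3 (PIECEWISE): the boolean flag alone picks Inductor or eager.
    piecewise_cfg = CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_inductor=False,  # run the piecewise graphs eagerly
    )

    # Levels 1-2 (DYNAMO_AS_IS / DYNAMO_ONCE): `backend` is handed to
    # torch.compile; per init_backend() above, "" falls back to "eager".
    dynamo_cfg = CompilationConfig(
        level=CompilationLevel.DYNAMO_ONCE,
        backend="openxla",  # any backend registered with PyTorch
    )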

View File

@@ -318,25 +318,6 @@ class VllmConfig:
# NB: Passing both --enforce-eager and a compilation level
# in V0 means the compilation level wins out.
self.compilation_config.level = CompilationLevel.NO_COMPILATION
else:
assert self.compilation_config.level >= CompilationLevel.NO_COMPILATION
assert self.compilation_config.level <= CompilationLevel.PIECEWISE
assert self.compilation_config.level <= 3
# If user does not set custom ops via none or all set it here based on
# compilation level and backend.
if (
self.compilation_config.custom_ops.count("none")
+ self.compilation_config.custom_ops.count("all")
== 0
):
if (
self.compilation_config.level > 0
and self.compilation_config.backend != "eager"
):
self.compilation_config.custom_ops.append("none")
else:
self.compilation_config.custom_ops.append("all")
# async tp is built on top of sequence parallelism
# and requires it to be enabled.

View File

@@ -114,9 +114,7 @@ class CustomOp(nn.Module):
custom_ops = compilation_config.custom_ops
if not hasattr(cls, "name"):
logger.warning_once(
"Custom op %s was not registered, which means it won't appear\
in the op registry. It will be enabled/disabled based on the\
global settings.", # noqa: E501
"Custom op %s was not registered, which means it won't appear in the op registry. It will be enabled/disabled based on the global settings.", # noqa: E501
cls.__name__,
)
return CustomOp.default_on()
@@ -130,17 +128,19 @@ class CustomOp(nn.Module):
@staticmethod
def default_on() -> bool:
"""
Behavior controlled by `CompilationConfig.custom_ops`: On by default if
'all', off by default if 'none'.
When PyTorch Inductor is used, 'none' is the default value,
otherwise 'all'.
On by default if PyTorch Inductor is not used.
Specifying 'all' or 'none' in custom_op takes precedence.
"""
from vllm.config import CompilationLevel
compilation_config = get_cached_compilation_config()
default_on = (
compilation_config.level < CompilationLevel.PIECEWISE
or not compilation_config.use_inductor
)
count_none = compilation_config.custom_ops.count("none")
count_all = compilation_config.custom_ops.count("all")
assert count_none + count_all == 1
return not count_none > 0 or count_all > 0
return default_on and not count_none > 0 or count_all > 0
# Dictionary of all custom ops (classes, indexed by registered name).
# To check if an op with a name is enabled, call .enabled() on the class.
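
The restored `default_on()` packs its decision into one boolean expression whose precedence is easy to misread (`and` binds tighter than `or`). A spelled-out, runnable restatement, with `PIECEWISE = 3` assumed to match `CompilationLevel.PIECEWISE`:

    def default_on(level: int, use_inductor: bool, custom_ops: list[str]) -> bool:
        PIECEWISE = 3  # assumed value of CompilationLevel.PIECEWISE
        # Custom ops default to on below PIECEWISE, or whenever Inductor is off...
        base = level < PIECEWISE or not use_inductor
        # ...then 'none' switches them off and 'all' forces them on:
        return (base and "none" not in custom_ops) or "all" in custom_ops

    # Consistent with the restored parametrization in the test file above:
    assert default_on(3, True, []) is False
    assert default_on(3, False, []) is True
    assert default_on(2, True, ["all", "-silu_and_mul"]) is True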

View File

@@ -274,6 +274,8 @@ class CpuPlatform(Platform):
"epilogue_fusion": True,
}
)
if compilation_config.use_inductor:
compilation_config.custom_ops = ["none"]
if vllm_config.lora_config is not None:
compilation_config.level = CompilationLevel.NO_COMPILATION