use combo kernel to fuse qk-norm and qk-rope (#26682)

Signed-off-by: Boyuan Feng <boyuan@meta.com>
This commit is contained in:
Boyuan Feng 2025-10-14 06:40:59 -07:00 committed by GitHub
parent e9f1b8c9e9
commit ca683a2a72
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -513,6 +513,16 @@ class CompilationConfig:
if isinstance(self.pass_config, dict):
self.pass_config = PassConfig(**self.pass_config)
if (
is_torch_equal_or_newer("2.9.0.dev")
and "combo_kernels" not in self.inductor_compile_config
and "benchmark_combo_kernel" not in self.inductor_compile_config
):
# use horizontal fusion, which is useful for fusing qk-norm and
# qk-rope when query and key have different shapes.
self.inductor_compile_config["combo_kernels"] = True
self.inductor_compile_config["benchmark_combo_kernel"] = True
# migrate the deprecated flags
if not self.use_cudagraph:
logger.warning(