From eb1051fb95323493d6d950c03dabac8ee56cb33e Mon Sep 17 00:00:00 2001 From: "Ye (Charlotte) Qi" Date: Mon, 8 Dec 2025 06:44:48 -0800 Subject: [PATCH] [ROCm] Guard group quant RMS norm fusion patterns (#30239) --- vllm/compilation/fusion.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index de083a2e5e3c..a7e6a69e64c9 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -490,23 +490,25 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass): # as the latter is a subset of the former in torch ops for epsilon in [1e-5, 1e-6]: # Fuse fused_add_rms_norm + fp8 group quant - FusedAddRMSNormGroupQuantPattern( - epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128) - ).register(self.patterns) + # Only register group quant patterns on CUDA where the C++ op exists + if current_platform.is_cuda(): + FusedAddRMSNormGroupQuantPattern( + epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128) + ).register(self.patterns) - # Fuse rms_norm + fp8 group quant - RMSNormGroupQuantPattern( - epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128) - ).register(self.patterns) + # Fuse rms_norm + fp8 group quant + RMSNormGroupQuantPattern( + epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128) + ).register(self.patterns) - FusedAddRMSNormGroupQuantPattern( - epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64) - ).register(self.patterns) + FusedAddRMSNormGroupQuantPattern( + epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64) + ).register(self.patterns) - # Fuse rms_norm + fp8 group quant - RMSNormGroupQuantPattern( - epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64) - ).register(self.patterns) + # Fuse rms_norm + fp8 group quant + RMSNormGroupQuantPattern( + epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64) + ).register(self.patterns) # Fuse fused_add_rms_norm + static fp8 quant FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(