diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py
index de083a2e5e3c..a7e6a69e64c9 100644
--- a/vllm/compilation/fusion.py
+++ b/vllm/compilation/fusion.py
@@ -490,23 +490,25 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
         # as the latter is a subset of the former in torch ops
         for epsilon in [1e-5, 1e-6]:
             # Fuse fused_add_rms_norm + fp8 group quant
-            FusedAddRMSNormGroupQuantPattern(
-                epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
-            ).register(self.patterns)
+            # Only register group quant patterns on CUDA where the C++ op exists
+            if current_platform.is_cuda():
+                FusedAddRMSNormGroupQuantPattern(
+                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
+                ).register(self.patterns)
 
-            # Fuse rms_norm + fp8 group quant
-            RMSNormGroupQuantPattern(
-                epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
-            ).register(self.patterns)
+                # Fuse rms_norm + fp8 group quant
+                RMSNormGroupQuantPattern(
+                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
+                ).register(self.patterns)
 
-            FusedAddRMSNormGroupQuantPattern(
-                epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
-            ).register(self.patterns)
+                FusedAddRMSNormGroupQuantPattern(
+                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
+                ).register(self.patterns)
 
-            # Fuse rms_norm + fp8 group quant
-            RMSNormGroupQuantPattern(
-                epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
-            ).register(self.patterns)
+                # Fuse rms_norm + fp8 group quant
+                RMSNormGroupQuantPattern(
+                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
+                ).register(self.patterns)
 
         # Fuse fused_add_rms_norm + static fp8 quant
         FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(