[ROCm] Guard group quant RMS norm fusion patterns (#30239)

This commit is contained in:
Ye (Charlotte) Qi 2025-12-08 06:44:48 -08:00 committed by GitHub
parent 80433e225e
commit eb1051fb95
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -490,23 +490,25 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
# as the latter is a subset of the former in torch ops # as the latter is a subset of the former in torch ops
for epsilon in [1e-5, 1e-6]: for epsilon in [1e-5, 1e-6]:
# Fuse fused_add_rms_norm + fp8 group quant # Fuse fused_add_rms_norm + fp8 group quant
FusedAddRMSNormGroupQuantPattern( # Only register group quant patterns on CUDA where the C++ op exists
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128) if current_platform.is_cuda():
).register(self.patterns) FusedAddRMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
).register(self.patterns)
# Fuse rms_norm + fp8 group quant # Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern( RMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128) epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
).register(self.patterns) ).register(self.patterns)
FusedAddRMSNormGroupQuantPattern( FusedAddRMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64) epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
).register(self.patterns) ).register(self.patterns)
# Fuse rms_norm + fp8 group quant # Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern( RMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64) epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
).register(self.patterns) ).register(self.patterns)
# Fuse fused_add_rms_norm + static fp8 quant # Fuse fused_add_rms_norm + static fp8 quant
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(