[ROCm] Guard group quant RMS norm fusion patterns (#30239)

This commit is contained in:
Ye (Charlotte) Qi 2025-12-08 06:44:48 -08:00 committed by GitHub
parent 80433e225e
commit eb1051fb95
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -490,23 +490,25 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
# as the latter is a subset of the former in torch ops
for epsilon in [1e-5, 1e-6]:
# Fuse fused_add_rms_norm + fp8 group quant
FusedAddRMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
).register(self.patterns)
# Only register group quant patterns on CUDA where the C++ op exists
if current_platform.is_cuda():
FusedAddRMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
).register(self.patterns)
# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
).register(self.patterns)
# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
).register(self.patterns)
FusedAddRMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
).register(self.patterns)
FusedAddRMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
).register(self.patterns)
# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
).register(self.patterns)
# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
).register(self.patterns)
# Fuse fused_add_rms_norm + static fp8 quant
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(