mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 06:34:58 +08:00
[ROCm] Guard group quant RMS norm fusion patterns (#30239)
This commit is contained in:
parent
80433e225e
commit
eb1051fb95
@ -490,23 +490,25 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
|
||||
# as the latter is a subset of the former in torch ops
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
# Fuse fused_add_rms_norm + fp8 group quant
|
||||
FusedAddRMSNormGroupQuantPattern(
|
||||
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
|
||||
).register(self.patterns)
|
||||
# Only register group quant patterns on CUDA where the C++ op exists
|
||||
if current_platform.is_cuda():
|
||||
FusedAddRMSNormGroupQuantPattern(
|
||||
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
|
||||
).register(self.patterns)
|
||||
|
||||
# Fuse rms_norm + fp8 group quant
|
||||
RMSNormGroupQuantPattern(
|
||||
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
|
||||
).register(self.patterns)
|
||||
# Fuse rms_norm + fp8 group quant
|
||||
RMSNormGroupQuantPattern(
|
||||
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
|
||||
).register(self.patterns)
|
||||
|
||||
FusedAddRMSNormGroupQuantPattern(
|
||||
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
|
||||
).register(self.patterns)
|
||||
FusedAddRMSNormGroupQuantPattern(
|
||||
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
|
||||
).register(self.patterns)
|
||||
|
||||
# Fuse rms_norm + fp8 group quant
|
||||
RMSNormGroupQuantPattern(
|
||||
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
|
||||
).register(self.patterns)
|
||||
# Fuse rms_norm + fp8 group quant
|
||||
RMSNormGroupQuantPattern(
|
||||
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
|
||||
).register(self.patterns)
|
||||
|
||||
# Fuse fused_add_rms_norm + static fp8 quant
|
||||
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user