[Bugfix] fix fuse_allreduce_rms when tp=1 (#30178)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Author: Jiangyun Zhu, 2025-12-08 14:43:47 +08:00 (committed by GitHub)
parent c6df05ebb4
commit d143271234


@@ -1076,11 +1076,15 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
        self.disabled = True
        self.tp_size = get_tensor_model_parallel_world_size()
        if self.tp_size <= 1:
            logger.warning_once("AllReduce fusion pass is disabled for tp_size <= 1.")
            return
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="all_reduce_fusion_pass"
        )
        if config.model_config is None:
            logger.warning_once(
                "AllReduce fusion pass is disabled for missing model_config."
            )
            return
        self.hidden_dim = config.model_config.get_hidden_size()
        self.group = get_tp_group().device_group
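With tp_size <= 1 there is nothing to fuse, so __init__ now warns and returns early. That leaves self.disabled set to True and later attributes (such as self.max_token_num, read in the second hunk below) unassigned; the second hunk adds the matching guard, and a combined sketch follows it.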
@@ -1188,6 +1192,9 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
        self.disabled = False

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        if self.disabled:
            logger.warning_once("AllReduce fusion pass is disabled.")
            return False
        return compile_range.end <= self.max_token_num

    @VllmInductorPass.time_and_log
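Taken together, the two hunks make the disabled state safe to query. Below is a minimal, self-contained sketch of the behavior, not vLLM's actual implementation: only disabled, tp_size, max_token_num, and is_applicable_for_range mirror the diff, while range_end stands in for compile_range.end and the config, pattern-matching, and logging machinery are elided.

class AllReduceFusionPass:
    """Minimal sketch of the fixed control flow; not vLLM's real class."""

    def __init__(self, tp_size: int, max_token_num: int = 8192) -> None:
        self.disabled = True
        self.tp_size = tp_size
        if self.tp_size <= 1:
            # Hunk 1: early return; self.max_token_num is never assigned.
            return
        self.max_token_num = max_token_num
        self.disabled = False

    def is_applicable_for_range(self, range_end: int) -> bool:
        # Hunk 2: check the guard first, so a disabled pass never reads
        # self.max_token_num (which may not exist).
        if self.disabled:
            return False
        return range_end <= self.max_token_num


# With tp_size == 1 the pass now reports itself inapplicable instead of
# failing on the unassigned max_token_num.
assert AllReduceFusionPass(tp_size=1).is_applicable_for_range(1024) is False
assert AllReduceFusionPass(tp_size=2).is_applicable_for_range(1024) is True

Before this fix, a pass constructed with tp_size == 1 would reach self.max_token_num in is_applicable_for_range without __init__ ever having assigned it, raising AttributeError; the new guard turns that case into a clean False, and warning_once keeps the disabled-pass message from repeating across compile ranges.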