From d143271234454026454c5ee6a55fc516dd298dac Mon Sep 17 00:00:00 2001
From: Jiangyun Zhu
Date: Mon, 8 Dec 2025 14:43:47 +0800
Subject: [PATCH] [Bugfix] fix fuse_allreduce_rms when tp=1 (#30178)

Signed-off-by: zjy0516
---
 vllm/compilation/collective_fusion.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py
index 2717738dd7c29..57bd94c7e8ad6 100644
--- a/vllm/compilation/collective_fusion.py
+++ b/vllm/compilation/collective_fusion.py
@@ -1076,11 +1076,15 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
         self.disabled = True
         self.tp_size = get_tensor_model_parallel_world_size()
         if self.tp_size <= 1:
+            logger.warning_once("AllReduce fusion pass is disabled for tp_size <= 1.")
             return
         self.patterns: PatternMatcherPass = PatternMatcherPass(
             pass_name="all_reduce_fusion_pass"
         )
         if config.model_config is None:
+            logger.warning_once(
+                "AllReduce fusion pass is disabled for missing model_config."
+            )
             return
         self.hidden_dim = config.model_config.get_hidden_size()
         self.group = get_tp_group().device_group
@@ -1188,6 +1192,9 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
         self.disabled = False
 
     def is_applicable_for_range(self, compile_range: Range) -> bool:
+        if self.disabled:
+            logger.warning_once("AllReduce fusion pass is disabled.")
+            return False
         return compile_range.end <= self.max_token_num
 
     @VllmInductorPass.time_and_log
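
Note: below is a minimal, self-contained sketch of the failure mode this patch
appears to guard against. The class name, the tp_size parameter, and the
max_token_num value are hypothetical, not vLLM's actual API; the point is that
an early return in __init__ (here, for tp_size <= 1) leaves max_token_num
unset, so is_applicable_for_range must check self.disabled before touching
that attribute.

    # Hypothetical stand-in for AllReduceFusionPass; names and values are
    # illustrative only, not vLLM's real implementation.
    class FusionPassSketch:
        def __init__(self, tp_size: int) -> None:
            self.disabled = True
            if tp_size <= 1:
                # Early return: max_token_num is never assigned, mirroring
                # the tp_size <= 1 path in the patched __init__ above.
                return
            self.max_token_num = 8192  # placeholder value for illustration
            self.disabled = False

        def is_applicable_for_range(self, range_end: int) -> bool:
            # Without this guard, a tp_size == 1 instance would raise
            # AttributeError on self.max_token_num below.
            if self.disabled:
                return False
            return range_end <= self.max_token_num

    # Usage: with tp_size=1 the pass now reports itself as not applicable
    # instead of crashing on the unset attribute.
    pass_ = FusionPassSketch(tp_size=1)
    assert pass_.is_applicable_for_range(1024) is False

Checking self.disabled first covers every early-return path in __init__ with a
single guard, so each attribute access does not need to be defended
individually.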