mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 09:24:28 +08:00
[Bugfix] fix fuse_allreduce_rms when tp =1 (#30178)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
parent
c6df05ebb4
commit
d143271234
@ -1076,11 +1076,15 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
|
||||
self.disabled = True
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
if self.tp_size <= 1:
|
||||
logger.warning_once("AllReduce fusion pass is disabled for tp_size <= 1.")
|
||||
return
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="all_reduce_fusion_pass"
|
||||
)
|
||||
if config.model_config is None:
|
||||
logger.warning_once(
|
||||
"AllReduce fusion pass is disabled for missing model_config."
|
||||
)
|
||||
return
|
||||
self.hidden_dim = config.model_config.get_hidden_size()
|
||||
self.group = get_tp_group().device_group
|
||||
@ -1188,6 +1192,9 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
|
||||
self.disabled = False
|
||||
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
if self.disabled:
|
||||
logger.warning_once("AllReduce fusion pass is disabled.")
|
||||
return False
|
||||
return compile_range.end <= self.max_token_num
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user