[Misc] Refactor AllReduceFusionPass. Remove parameter (#20918)

Signed-off-by: ilmarkov <imarkov@redhat.com>
Co-authored-by: ilmarkov <imarkov@redhat.com>
This commit is contained in:
Ilya Markov 2025-07-15 08:57:40 +02:00 committed by GitHub
parent d4d309409f
commit 37a7d5d74a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 7 additions and 10 deletions

View File

@@ -132,9 +132,7 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
dtype=dtype, dtype=dtype,
seed=42) seed=42)
-    all_reduce_fusion_pass = AllReduceFusionPass(
-        vllm_config, vllm_config.compilation_config.pass_config.
-        fi_allreduce_fusion_max_token_num)
+    all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
backend = TestBackend(all_reduce_fusion_pass) backend = TestBackend(all_reduce_fusion_pass)
model = test_model_cls(hidden_size) model = test_model_cls(hidden_size)

View File

@@ -397,7 +397,7 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
class AllReduceFusionPass(VllmInductorPass): class AllReduceFusionPass(VllmInductorPass):
-    def __init__(self, config: VllmConfig, max_token_num: int):
+    def __init__(self, config: VllmConfig):
super().__init__(config) super().__init__(config)
self.disabled = True self.disabled = True
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
@@ -429,7 +429,8 @@ class AllReduceFusionPass(VllmInductorPass):
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
tp_rank=rank, tp_rank=rank,
tp_size=self.tp_size, tp_size=self.tp_size,
-                    max_token_num=max_token_num,
+                    max_token_num=config.compilation_config.pass_config.
+                    fi_allreduce_fusion_max_token_num,
hidden_dim=self.hidden_dim, hidden_dim=self.hidden_dim,
group=self.group, group=self.group,
use_fp32_lamport=use_fp32_lamport, use_fp32_lamport=use_fp32_lamport,
@@ -441,7 +442,8 @@ class AllReduceFusionPass(VllmInductorPass):
rank=rank, rank=rank,
world_size=self.tp_size, world_size=self.tp_size,
use_fp32_lamport=use_fp32_lamport, use_fp32_lamport=use_fp32_lamport,
-                max_token_num=max_token_num,
+                max_token_num=config.compilation_config.pass_config.
+                fi_allreduce_fusion_max_token_num,
) )
for epsilon in [1e-5, 1e-6]: for epsilon in [1e-5, 1e-6]:

View File

@@ -63,10 +63,7 @@ class PostGradPassManager(CustomGraphPass):
if self.pass_config.enable_attn_fusion: if self.pass_config.enable_attn_fusion:
self.passes += [AttnFusionPass(config)] self.passes += [AttnFusionPass(config)]
if self.pass_config.enable_fi_allreduce_fusion: if self.pass_config.enable_fi_allreduce_fusion:
-            self.passes += [
-                AllReduceFusionPass(
-                    config, self.pass_config.fi_allreduce_fusion_max_token_num)
-            ]
+            self.passes += [AllReduceFusionPass(config)]
self.fix_functionalization = FixFunctionalizationPass(config) self.fix_functionalization = FixFunctionalizationPass(config)
def add(self, pass_: InductorPass): def add(self, pass_: InductorPass):