diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py
index 7101857210ab..492e90f2a75f 100644
--- a/tests/compile/test_fusion_all_reduce.py
+++ b/tests/compile/test_fusion_all_reduce.py
@@ -132,9 +132,7 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
                                            dtype=dtype,
                                            seed=42)
 
-    all_reduce_fusion_pass = AllReduceFusionPass(
-        vllm_config, vllm_config.compilation_config.pass_config.
-        fi_allreduce_fusion_max_token_num)
+    all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
     backend = TestBackend(all_reduce_fusion_pass)
 
     model = test_model_cls(hidden_size)
diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py
index 97cb2995cb34..a8b00aaf0842 100644
--- a/vllm/compilation/collective_fusion.py
+++ b/vllm/compilation/collective_fusion.py
@@ -397,7 +397,7 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
 
 class AllReduceFusionPass(VllmInductorPass):
 
-    def __init__(self, config: VllmConfig, max_token_num: int):
+    def __init__(self, config: VllmConfig):
         super().__init__(config)
         self.disabled = True
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -429,7 +429,8 @@ class AllReduceFusionPass(VllmInductorPass):
             flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
                 tp_rank=rank,
                 tp_size=self.tp_size,
-                max_token_num=max_token_num,
+                max_token_num=config.compilation_config.pass_config.
+                fi_allreduce_fusion_max_token_num,
                 hidden_dim=self.hidden_dim,
                 group=self.group,
                 use_fp32_lamport=use_fp32_lamport,
@@ -441,7 +442,8 @@ class AllReduceFusionPass(VllmInductorPass):
                 rank=rank,
                 world_size=self.tp_size,
                 use_fp32_lamport=use_fp32_lamport,
-                max_token_num=max_token_num,
+                max_token_num=config.compilation_config.pass_config.
+                fi_allreduce_fusion_max_token_num,
             )
 
         for epsilon in [1e-5, 1e-6]:
diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py
index 078188854f05..58216a1f0ed3 100644
--- a/vllm/compilation/pass_manager.py
+++ b/vllm/compilation/pass_manager.py
@@ -63,10 +63,7 @@ class PostGradPassManager(CustomGraphPass):
         if self.pass_config.enable_attn_fusion:
             self.passes += [AttnFusionPass(config)]
         if self.pass_config.enable_fi_allreduce_fusion:
-            self.passes += [
-                AllReduceFusionPass(
-                    config, self.pass_config.fi_allreduce_fusion_max_token_num)
-            ]
+            self.passes += [AllReduceFusionPass(config)]
         self.fix_functionalization = FixFunctionalizationPass(config)
 
     def add(self, pass_: InductorPass):
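
The net effect of the diff above, as a minimal self-contained sketch: the
max-token budget is no longer threaded through as a separate constructor
argument, because it is derivable from the VllmConfig the pass already
receives. The class and config types below are simplified stand-ins for
illustration, not vLLM's real definitions.

    from dataclasses import dataclass, field

    @dataclass
    class PassConfig:
        fi_allreduce_fusion_max_token_num: int = 8192

    @dataclass
    class CompilationConfig:
        pass_config: PassConfig = field(default_factory=PassConfig)

    @dataclass
    class Config:
        compilation_config: CompilationConfig = field(
            default_factory=CompilationConfig)

    class AllReduceFusionPassSketch:
        # Before this change, __init__(self, config, max_token_num) forced
        # every call site to re-derive max_token_num from the same config
        # object it was already passing in.
        def __init__(self, config: Config):
            self.max_token_num = (config.compilation_config.pass_config.
                                  fi_allreduce_fusion_max_token_num)

    assert AllReduceFusionPassSketch(Config()).max_token_num == 8192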