mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 02:15:01 +08:00
[Misc] Refactor AllReduceFusionPass. Remove parameter (#20918)
Signed-off-by: ilmarkov <imarkov@redhat.com> Co-authored-by: ilmarkov <imarkov@redhat.com>
This commit is contained in:
parent
d4d309409f
commit
37a7d5d74a
@ -132,9 +132,7 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
|
||||
dtype=dtype,
|
||||
seed=42)
|
||||
|
||||
all_reduce_fusion_pass = AllReduceFusionPass(
|
||||
vllm_config, vllm_config.compilation_config.pass_config.
|
||||
fi_allreduce_fusion_max_token_num)
|
||||
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
|
||||
backend = TestBackend(all_reduce_fusion_pass)
|
||||
|
||||
model = test_model_cls(hidden_size)
|
||||
|
||||
@ -397,7 +397,7 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
|
||||
|
||||
class AllReduceFusionPass(VllmInductorPass):
|
||||
|
||||
def __init__(self, config: VllmConfig, max_token_num: int):
|
||||
def __init__(self, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
self.disabled = True
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
@ -429,7 +429,8 @@ class AllReduceFusionPass(VllmInductorPass):
|
||||
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
|
||||
tp_rank=rank,
|
||||
tp_size=self.tp_size,
|
||||
max_token_num=max_token_num,
|
||||
max_token_num=config.compilation_config.pass_config.
|
||||
fi_allreduce_fusion_max_token_num,
|
||||
hidden_dim=self.hidden_dim,
|
||||
group=self.group,
|
||||
use_fp32_lamport=use_fp32_lamport,
|
||||
@ -441,7 +442,8 @@ class AllReduceFusionPass(VllmInductorPass):
|
||||
rank=rank,
|
||||
world_size=self.tp_size,
|
||||
use_fp32_lamport=use_fp32_lamport,
|
||||
max_token_num=max_token_num,
|
||||
max_token_num=config.compilation_config.pass_config.
|
||||
fi_allreduce_fusion_max_token_num,
|
||||
)
|
||||
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
|
||||
@ -63,10 +63,7 @@ class PostGradPassManager(CustomGraphPass):
|
||||
if self.pass_config.enable_attn_fusion:
|
||||
self.passes += [AttnFusionPass(config)]
|
||||
if self.pass_config.enable_fi_allreduce_fusion:
|
||||
self.passes += [
|
||||
AllReduceFusionPass(
|
||||
config, self.pass_config.fi_allreduce_fusion_max_token_num)
|
||||
]
|
||||
self.passes += [AllReduceFusionPass(config)]
|
||||
self.fix_functionalization = FixFunctionalizationPass(config)
|
||||
|
||||
def add(self, pass_: InductorPass):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user