Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Misc] Refactor AllReduceFusionPass. Remove parameter (#20918)
Signed-off-by: ilmarkov <imarkov@redhat.com>
Co-authored-by: ilmarkov <imarkov@redhat.com>
commit 37a7d5d74a
parent d4d309409f
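The diff below removes the explicit max_token_num constructor argument from AllReduceFusionPass: the pass now takes only a VllmConfig and reads the budget from config.compilation_config.pass_config.fi_allreduce_fusion_max_token_num itself. A minimal before/after sketch of the call-site change, assuming vllm_config is an already-constructed VllmConfig:

```python
# Before this commit: the token budget had to be threaded through explicitly.
all_reduce_fusion_pass = AllReduceFusionPass(
    vllm_config,
    vllm_config.compilation_config.pass_config.fi_allreduce_fusion_max_token_num)

# After this commit: the pass reads the budget from the config on its own.
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
```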
@@ -132,9 +132,7 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
                                            dtype=dtype,
                                            seed=42)
 
-    all_reduce_fusion_pass = AllReduceFusionPass(
-        vllm_config, vllm_config.compilation_config.pass_config.
-        fi_allreduce_fusion_max_token_num)
+    all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
     backend = TestBackend(all_reduce_fusion_pass)
 
     model = test_model_cls(hidden_size)
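In the test, the budget is therefore carried entirely by the config the pass is built from. The exact config plumbing is not part of this hunk; the sketch below assumes PassConfig exposes enable_fi_allreduce_fusion and fi_allreduce_fusion_max_token_num (both names appear in this commit) and that the config objects can be composed as shown:

```python
# Sketch only: carrying the fusion settings in the config so that
# AllReduceFusionPass(vllm_config) can pick them up by itself.
# PassConfig field names come from this commit; the constructor
# composition here is an assumption for illustration.
from vllm.config import CompilationConfig, PassConfig, VllmConfig

vllm_config = VllmConfig(compilation_config=CompilationConfig(
    pass_config=PassConfig(
        enable_fi_allreduce_fusion=True,
        fi_allreduce_fusion_max_token_num=1024,  # illustrative value
    )))

all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
backend = TestBackend(all_reduce_fusion_pass)
```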
@@ -397,7 +397,7 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
 
 class AllReduceFusionPass(VllmInductorPass):
 
-    def __init__(self, config: VllmConfig, max_token_num: int):
+    def __init__(self, config: VllmConfig):
         super().__init__(config)
         self.disabled = True
         self.tp_size = get_tensor_model_parallel_world_size()
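Correspondingly, the value is now looked up inside the pass rather than handed in. An abbreviated sketch of the constructor pattern this commit moves to (the class body is truncated; the attribute path is the one used in the workspace-creation hunks below):

```python
# Abbreviated sketch, not the full implementation: the pass derives its
# token budget from the config it is constructed with.
class AllReduceFusionPass(VllmInductorPass):

    def __init__(self, config: VllmConfig):
        super().__init__(config)
        self.disabled = True
        self.tp_size = get_tensor_model_parallel_world_size()
        # Attribute path taken from the hunks below.
        max_token_num = (config.compilation_config.pass_config.
                         fi_allreduce_fusion_max_token_num)
```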
@@ -429,7 +429,8 @@ class AllReduceFusionPass(VllmInductorPass):
             flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
                 tp_rank=rank,
                 tp_size=self.tp_size,
-                max_token_num=max_token_num,
+                max_token_num=config.compilation_config.pass_config.
+                fi_allreduce_fusion_max_token_num,
                 hidden_dim=self.hidden_dim,
                 group=self.group,
                 use_fp32_lamport=use_fp32_lamport,
@@ -441,7 +442,8 @@ class AllReduceFusionPass(VllmInductorPass):
                 rank=rank,
                 world_size=self.tp_size,
                 use_fp32_lamport=use_fp32_lamport,
-                max_token_num=max_token_num,
+                max_token_num=config.compilation_config.pass_config.
+                fi_allreduce_fusion_max_token_num,
             )
 
         for epsilon in [1e-5, 1e-6]:
@@ -63,10 +63,7 @@ class PostGradPassManager(CustomGraphPass):
         if self.pass_config.enable_attn_fusion:
             self.passes += [AttnFusionPass(config)]
         if self.pass_config.enable_fi_allreduce_fusion:
-            self.passes += [
-                AllReduceFusionPass(
-                    config, self.pass_config.fi_allreduce_fusion_max_token_num)
-            ]
+            self.passes += [AllReduceFusionPass(config)]
         self.fix_functionalization = FixFunctionalizationPass(config)
 
     def add(self, pass_: InductorPass):
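With the extra argument gone, PostGradPassManager registers this pass the same way it registers the other optional passes. A hedged sketch of that now-uniform pattern (illustrative only; the actual manager keeps the explicit if-blocks shown above):

```python
# Illustrative only: every conditionally-enabled pass now shares the
# single-argument (config) constructor, so registration is uniform.
optional_passes = [
    (self.pass_config.enable_attn_fusion, AttnFusionPass),
    (self.pass_config.enable_fi_allreduce_fusion, AllReduceFusionPass),
]
for enabled, pass_cls in optional_passes:
    if enabled:
        self.passes += [pass_cls(config)]
```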