mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 06:45:01 +08:00
[Bugfix] Defunctionalize TRTLLM AR+Norm op for avoiding extra clone kernel before it (#29631)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
parent
b08025a83b
commit
c719c40540
@ -103,6 +103,18 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
]:
|
||||
mutated_args = {1: "result"}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
elif (
|
||||
at_target
|
||||
== torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
|
||||
):
|
||||
mutated_args = {
|
||||
1: "allreduce_in",
|
||||
2: "residual",
|
||||
3: "norm_out",
|
||||
4: "quant_out",
|
||||
5: "scale_out",
|
||||
}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
# For some reason we need to specify the args for both
|
||||
# silu_and_mul and silu_and_mul_quant. The kwargs
|
||||
# pathway gets the wrong answer.
|
||||
|
||||
@ -75,7 +75,7 @@ def find_op_nodes(
|
||||
return
|
||||
|
||||
assert isinstance(op, OpOverload)
|
||||
if not op._schema.is_mutable:
|
||||
|
||||
yield from graph.find_nodes(op="call_function", target=op)
|
||||
|
||||
for n in graph.find_nodes(op="call_function", target=auto_functionalized):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user