mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:15:42 +08:00
Signed-off-by: luka <luka@neuralmagic.com> Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
43 lines
1.4 KiB
Python
43 lines
1.4 KiB
Python
import operator
|
|
from typing import Iterable, Optional
|
|
|
|
from torch import fx
|
|
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
|
from torch._ops import OpOverload
|
|
|
|
|
|
def is_func(node: fx.Node, target) -> bool:
    """Check whether ``node`` is a ``call_function`` node calling ``target``.

    Args:
        node: The FX graph node to inspect.
        target: The callable to compare against ``node.target``.

    Returns:
        ``False`` when ``node.op`` is not ``"call_function"``; otherwise the
        result of comparing ``node.target`` with ``target``.
    """
    # Only call_function nodes can match; reject every other opcode early.
    if node.op != "call_function":
        return False
    return node.target == target
|
|
|
|
|
|
def find_auto_fn_maybe(nodes: Iterable[fx.Node],
                       op: OpOverload) -> Optional[fx.Node]:
    """Return the first auto_functionalized node with the given op, if any.

    A node matches when it is a ``call_function`` node whose target is
    ``auto_functionalized`` and whose first positional argument equals ``op``.

    Args:
        nodes: The FX nodes to search, in iteration order.
        op: The op expected as the first argument of the matching node.

    Returns:
        The first matching node, or ``None`` if no node matches.
    """
    # Lazy generator: iteration stops as soon as the first match is produced.
    candidates = (
        candidate for candidate in nodes
        if is_func(candidate, auto_functionalized)
        and candidate.args[0] == op
    )
    return next(candidates, None)
|
|
|
|
|
|
def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
    """Return the first auto_functionalized node with the given op.

    Delegates the search to ``find_auto_fn_maybe`` and asserts that a node
    was found.

    Args:
        nodes: The FX nodes to search.
        op: The op the auto_functionalized node must carry.

    Returns:
        The node returned by ``find_auto_fn_maybe``.

    Raises:
        AssertionError: If ``find_auto_fn_maybe`` returns ``None``.
    """
    found = find_auto_fn_maybe(nodes, op)
    assert found is not None, f"Could not find {op} in nodes {nodes}"
    return found
|
|
|
|
|
|
def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]:
    """Return the getitem node extracting element ``idx`` of ``node``, if any.

    Searches ``node.users`` for a ``call_function`` node whose target is
    ``operator.getitem`` and whose second positional argument equals ``idx``.

    Args:
        node: The node whose users are searched.
        idx: The index the getitem user must extract.

    Returns:
        The first matching user, or ``None`` if there is none.
    """
    # Lazy generator: iteration stops as soon as the first match is produced.
    extractors = (
        consumer for consumer in node.users
        if is_func(consumer, operator.getitem) and consumer.args[1] == idx
    )
    return next(extractors, None)
|
|
|
|
|
|
def find_getitem(node: fx.Node, idx: int) -> fx.Node:
    """Return the getitem node that extracts element ``idx`` from ``node``.

    Delegates the search to ``find_getitem_maybe`` and asserts that a node
    was found.

    Args:
        node: The node whose users are searched.
        idx: The index the getitem user must extract.

    Returns:
        The node returned by ``find_getitem_maybe``.

    Raises:
        AssertionError: If ``find_getitem_maybe`` returns ``None``.
    """
    getitem_node = find_getitem_maybe(node, idx)
    assert getitem_node is not None, f"Could not find getitem {idx} in node {node}"
    return getitem_node
|