Mypy checking for vllm/compilation (#11496)

Signed-off-by: lucast2021 <lucast2021@headroyce.org>
Co-authored-by: lucast2021 <lucast2021@headroyce.org>
This commit is contained in:
Lucas Tucker 2024-12-25 23:04:07 -06:00 committed by GitHub
parent 51a624bf02
commit dbeac95dbb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 9 additions and 8 deletions

View File

@@ -141,14 +141,14 @@ class AlwaysHitShapeEnv:
 return ""
-def wrap_inductor(graph,
+def wrap_inductor(graph: fx.GraphModule,
 example_inputs,
 additional_inductor_config,
 compilation_config: CompilationConfig,
 graph_index: int = 0,
 num_graphs: int = 1,
 runtime_shape: Optional[int] = None,
-use_inductor: bool = True):
+use_inductor: bool = True) -> Any:
 if graph_index == 0:
 # before compiling the first graph, record the start time
 global compilation_start_time
@@ -209,7 +209,7 @@ def wrap_inductor(graph,
 returns_tuple = graph_returns_tuple(graph)
 # this is the graph we return to Dynamo to run
-def compiled_graph(*args):
+def compiled_graph(*args) -> Optional[fx.CompiledFxGraph]:
 # convert args to list
 list_args = list(args)
 graph_output = inductor_compiled_graph(list_args)
@@ -247,7 +247,7 @@ def wrap_inductor(graph,
 # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
 return
-def _get_shape_env():
+def _get_shape_env() -> AlwaysHitShapeEnv:
 return AlwaysHitShapeEnv()
 with patch(# for hijacking the hash of the compiled graph
@@ -537,7 +537,7 @@ class VllmBackend:
 example_inputs[x].clone() for x in self.sym_tensor_indices
 ]
-def copy_and_call(*args):
+def copy_and_call(*args) -> fx.GraphModule:
 list_args = list(args)
 for i, index in enumerate(self.sym_tensor_indices):
 runtime_tensor = list_args[index]

View File

@@ -7,6 +7,7 @@ from torch import fx
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
 from torch._inductor import pattern_matcher as pm
 from torch._ops import OpOverload
+from torch.fx import Node
 from vllm.compilation.fx_utils import find_auto_fn
@@ -97,7 +98,7 @@ class MultiOutputMatch(abc.ABC):
 self.graph.call_function(operator.getitem, (tuple_node, idx))
 for idx in indices)
-def insert_auto_fn(self, op: OpOverload, kwargs):
+def insert_auto_fn(self, op: OpOverload, kwargs) -> Node:
 """
 Insert an auto_functionalized node with the given op and kwargs.
 """

View File

@@ -1,4 +1,4 @@
-from typing import List
+from typing import Any, Dict, List
 from torch import fx as fx
@@ -53,7 +53,7 @@ class PostGradPassManager:
 assert isinstance(pass_, InductorPass)
 self.passes.append(pass_)
-def __getstate__(self):
+def __getstate__(self) -> Dict[str, List[Any]]:
 """
 Custom pickling for the pass manager, as some passes cannot be pickled.
 Pickling occurs because the pass manager is set as the value of