Update deprecated type hinting in vllm/compilation (#18072)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Authored by Harry Mellor, 2025-05-13 16:32:48 +01:00; committed by GitHub
parent fc407a1425
commit 19324d660c
GPG Key ID: B5690EEEBB952194
13 changed files with 70 additions and 69 deletions
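
This commit replaces the typing aliases deprecated by PEP 585 (`Dict`, `List`, `Set`, `Tuple`) with the builtin generics, and imports abstract container types such as `Sequence` and `Iterable` from `collections.abc` instead of `typing`. A minimal before/after sketch of the pattern (illustrative names, not taken from the diff):

```python
# Illustrative sketch of the migration applied throughout vllm/compilation.
# The function names below are made up; only the typing style matters.

# Before (deprecated since Python 3.9, flagged by ruff UP006/UP035):
from typing import Dict, List, Optional, Sequence, Tuple


def old_style(shapes: List[int],
              config: Dict[str, int],
              passes: Sequence[str]) -> Tuple[int, Optional[str]]:
    return len(shapes), None


# After (builtin generics per PEP 585, ABCs from collections.abc):
from collections.abc import Sequence
from typing import Optional


def new_style(shapes: list[int],
              config: dict[str, int],
              passes: Sequence[str]) -> tuple[int, Optional[str]]:
    return len(shapes), None
```

Note that `Optional[...]` is left untouched: rewriting it as `X | None` is covered by a different ruff rule and needs Python 3.10 at runtime, whereas the builtin generics only need 3.9.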

View File

@@ -74,7 +74,6 @@ exclude = [
 # Python 3.8 typing. TODO: Remove these excludes after v1.0.0
 "vllm/adapter_commons/**/*.py" = ["UP006", "UP035"]
 "vllm/attention/**/*.py" = ["UP006", "UP035"]
-"vllm/compilation/**/*.py" = ["UP006", "UP035"]
 "vllm/core/**/*.py" = ["UP006", "UP035"]
 "vllm/device_allocator/**/*.py" = ["UP006", "UP035"]
 "vllm/distributed/**/*.py" = ["UP006", "UP035"]

View File

@@ -5,8 +5,9 @@ import dataclasses
 import os
 import pprint
 import time
+from collections.abc import Sequence
 from contextlib import ExitStack
-from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
+from typing import Any, Callable, Optional
 from unittest.mock import patch
 
 import torch
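
`Sequence` now comes from `collections.abc` rather than `typing`; on Python 3.9+ the abc classes are subscriptable at runtime, so nothing else has to change. A small, self-contained illustration (not from the diff):

```python
# Demonstration (illustrative only) that collections.abc types are generic at
# runtime on Python 3.9+, so the typing aliases are unnecessary.
from collections.abc import Iterable, Sequence


def total(values: Sequence[int]) -> int:
    return sum(values)


def squares(values: Iterable[int]) -> list[int]:
    return [v * v for v in values]


assert total([1, 2, 3]) == 6
assert squares(range(4)) == [0, 1, 4, 9]
```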
@@ -56,7 +57,7 @@ class CompilerManager:
     """
 
     def __init__(self, compilation_config: CompilationConfig):
-        self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict()
+        self.cache: dict[tuple[Optional[int], int, str], Any] = dict()
         self.is_cache_updated = False
         self.compilation_config = compilation_config
         self.compiler = make_compiler(compilation_config)
@@ -90,7 +91,7 @@ class CompilerManager:
     def load(self,
              graph: fx.GraphModule,
-             example_inputs: List[Any],
+             example_inputs: list[Any],
              graph_index: int,
              runtime_shape: Optional[int] = None) -> Optional[Callable]:
         if (runtime_shape, graph_index, self.compiler.name) not in self.cache:
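
The cache is keyed by `(runtime_shape, graph_index, compiler_name)` tuples, now spelled with the builtin `dict`/`tuple` generics. A minimal sketch of a cache with the same shape, using hypothetical names rather than vLLM's actual implementation:

```python
# Minimal sketch (not vLLM's real CompilerManager) of a cache keyed by
# (runtime_shape, graph_index, compiler_name).
from typing import Any, Optional


class CompiledArtifactCache:

    def __init__(self) -> None:
        self.cache: dict[tuple[Optional[int], int, str], Any] = {}

    def load(self, runtime_shape: Optional[int], graph_index: int,
             compiler_name: str) -> Optional[Any]:
        # A miss returns None, mirroring "not compiled for this shape yet".
        return self.cache.get((runtime_shape, graph_index, compiler_name))

    def store(self, runtime_shape: Optional[int], graph_index: int,
              compiler_name: str, handle: Any) -> None:
        self.cache[(runtime_shape, graph_index, compiler_name)] = handle
```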
@@ -186,7 +187,7 @@ class SplitItem:
 def split_graph(graph: fx.GraphModule,
-                ops: List[str]) -> Tuple[fx.GraphModule, List[SplitItem]]:
+                ops: list[str]) -> tuple[fx.GraphModule, list[SplitItem]]:
     # split graph by ops
     subgraph_id = 0
     node_to_subgraph_id = {}
@@ -252,7 +253,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
     """
 
     def __init__(self, module: torch.fx.GraphModule,
-                 compile_submod_names: List[str], vllm_config: VllmConfig,
+                 compile_submod_names: list[str], vllm_config: VllmConfig,
                  graph_pool, vllm_backend: "VllmBackend"):
         super().__init__(module)
         from torch._guards import detect_fake_mode
@@ -274,8 +275,8 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
         return super().run(*fake_args)
 
     def call_module(self, target: torch.fx.node.Target,
-                    args: Tuple[torch.fx.node.Argument,
-                                ...], kwargs: Dict[str, Any]) -> Any:
+                    args: tuple[torch.fx.node.Argument,
+                                ...], kwargs: dict[str, Any]) -> Any:
         assert isinstance(target, str)
         output = super().call_module(target, args, kwargs)
@@ -326,12 +327,12 @@ class VllmBackend:
     graph: fx.GraphModule
     # the stiching graph module for all the piecewise graphs
     split_gm: fx.GraphModule
-    piecewise_graphs: List[SplitItem]
+    piecewise_graphs: list[SplitItem]
     returned_callable: Callable
     # Inductor passes to run on the graph pre-defunctionalization
     post_grad_passes: Sequence[Callable]
-    sym_tensor_indices: List[int]
-    input_buffers: List[torch.Tensor]
+    sym_tensor_indices: list[int]
+    input_buffers: list[torch.Tensor]
     compiler_manager: CompilerManager
 
     def __init__(
@@ -573,14 +574,14 @@ class ConcreteSizeEntry:
     # for cudagraph debugging, track the input addresses
     # during capture, and check if they are the same during replay
-    input_addresses: Optional[List[int]] = None
+    input_addresses: Optional[list[int]] = None
 
 
 class PiecewiseBackend:
 
     def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                  graph_pool: Any, piecewise_compile_index: int,
-                 total_piecewise_compiles: int, sym_shape_indices: List[int],
+                 total_piecewise_compiles: int, sym_shape_indices: list[int],
                  compiled_graph_for_general_shape: Callable,
                  vllm_backend: VllmBackend):
         """
@@ -608,9 +609,9 @@ class PiecewiseBackend:
         self.is_last_graph = (
             piecewise_compile_index == total_piecewise_compiles - 1)
 
-        self.compile_sizes: Set[int] = set(
+        self.compile_sizes: set[int] = set(
             self.compilation_config.compile_sizes)
-        self.cudagraph_capture_sizes: Set[int] = set(
+        self.cudagraph_capture_sizes: set[int] = set(
             self.compilation_config.cudagraph_capture_sizes
         ) if self.compilation_config.use_cudagraph else set()
@@ -624,11 +625,11 @@ class PiecewiseBackend:
         # the entries for different shapes that we need to either
         # compile or capture cudagraph
-        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
+        self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
 
         # to_be_compiled_sizes tracks the remaining sizes to compile,
         # and updates during the compilation process, so we need to copy it
-        self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()
+        self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
 
         for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
             self.concrete_size_entries[shape] = ConcreteSizeEntry(
                 runtime_shape=shape,

View File

@@ -4,7 +4,7 @@ import copy
 import hashlib
 import os
 from contextlib import ExitStack
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Optional
 from unittest.mock import patch
 
 import torch
@@ -48,11 +48,11 @@ class CompilerInterface:
     def compile(
         self,
         graph: fx.GraphModule,
-        example_inputs: List[Any],
-        compiler_config: Dict[str, Any],
+        example_inputs: list[Any],
+        compiler_config: dict[str, Any],
         runtime_shape: Optional[int] = None,
         key: Optional[str] = None,
-    ) -> Tuple[Optional[Callable], Optional[Any]]:
+    ) -> tuple[Optional[Callable], Optional[Any]]:
         """
         Compile the graph with the given example inputs and compiler config,
         with a runtime shape. If the `runtime_shape` is None, it means
@@ -82,7 +82,7 @@ class CompilerInterface:
     def load(self,
              handle: Any,
              graph: fx.GraphModule,
-             example_inputs: List[Any],
+             example_inputs: list[Any],
              graph_index: int,
              runtime_shape: Optional[int] = None) -> Callable:
         """
@@ -120,7 +120,7 @@ class AlwaysHitShapeEnv:
     """
 
    def __init__(self) -> None:
-        self.guards: List[Any] = []
+        self.guards: list[Any] = []
 
     def evaluate_guards_expression(self, *args, **kwargs):
         return True
@@ -132,8 +132,8 @@ class AlwaysHitShapeEnv:
         return ""
 
 
-def get_inductor_factors() -> List[Any]:
-    factors: List[Any] = []
+def get_inductor_factors() -> list[Any]:
+    factors: list[Any] = []
     # summarize system state
     from torch._inductor.codecache import CacheBase
     system_factors = CacheBase.get_system()
@@ -169,11 +169,11 @@ class InductorStandaloneAdaptor(CompilerInterface):
     def compile(
         self,
         graph: fx.GraphModule,
-        example_inputs: List[Any],
-        compiler_config: Dict[str, Any],
+        example_inputs: list[Any],
+        compiler_config: dict[str, Any],
         runtime_shape: Optional[int] = None,
         key: Optional[str] = None,
-    ) -> Tuple[Optional[Callable], Optional[Any]]:
+    ) -> tuple[Optional[Callable], Optional[Any]]:
         current_config = {}
         if compiler_config is not None:
             current_config.update(compiler_config)
@@ -201,7 +201,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
     def load(self,
              handle: Any,
              graph: fx.GraphModule,
-             example_inputs: List[Any],
+             example_inputs: list[Any],
              graph_index: int,
              runtime_shape: Optional[int] = None) -> Callable:
         assert isinstance(handle, tuple)
@@ -256,11 +256,11 @@ class InductorAdaptor(CompilerInterface):
     def compile(
         self,
         graph: fx.GraphModule,
-        example_inputs: List[Any],
-        compiler_config: Dict[str, Any],
+        example_inputs: list[Any],
+        compiler_config: dict[str, Any],
         runtime_shape: Optional[int] = None,
         key: Optional[str] = None,
-    ) -> Tuple[Optional[Callable], Optional[Any]]:
+    ) -> tuple[Optional[Callable], Optional[Any]]:
         from torch._inductor.compile_fx import compile_fx
         current_config = {}
         if compiler_config is not None:
@@ -420,7 +420,7 @@ class InductorAdaptor(CompilerInterface):
     def load(self,
              handle: Any,
              graph: fx.GraphModule,
-             example_inputs: List[Any],
+             example_inputs: list[Any],
              graph_index: int,
              runtime_shape: Optional[int] = None) -> Callable:
         assert isinstance(handle, tuple)
@@ -522,11 +522,11 @@ class EagerAdaptor(CompilerInterface):
     def compile(
         self,
         graph: fx.GraphModule,
-        example_inputs: List[Any],
-        compiler_config: Dict[str, Any],
+        example_inputs: list[Any],
+        compiler_config: dict[str, Any],
         runtime_shape: Optional[int] = None,
         key: Optional[str] = None,
-    ) -> Tuple[Optional[Callable], Optional[Any]]:
+    ) -> tuple[Optional[Callable], Optional[Any]]:
         # we don't need to compile the graph, just return the graph itself.
         # It does not support caching, return None for the handle.
         return graph, None

View File

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import inspect
-from typing import Callable, Dict, List, Optional, TypeVar, Union, overload
+from typing import Callable, Optional, TypeVar, Union, overload
 from unittest.mock import patch
 
 import torch
@@ -25,7 +25,7 @@ _T = TypeVar("_T", bound=type[nn.Module])
 @overload
 def support_torch_compile(
     *,
-    dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]],
+    dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]],
 ) -> Callable[[_T], _T]:
     ...
@@ -38,7 +38,7 @@ def support_torch_compile(cls: _T) -> _T:
 def support_torch_compile(
     cls: Optional[_T] = None,
     *,
-    dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None,
+    dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None,
 ) -> Union[Callable[[_T], _T], _T]:
     """
     A decorator to add support for compiling the forward method of a class.
@@ -131,7 +131,7 @@ def support_torch_compile(
 def _support_torch_compile(
     cls: _T,
-    dynamic_arg_dims: Dict[str, Union[int, List[int]]],
+    dynamic_arg_dims: dict[str, Union[int, list[int]]],
 ) -> _T:
     """
     A decorator to add support for compiling the forward method of a class.
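
`support_torch_compile` follows the usual "decorator with optional arguments" pattern: bare `@support_torch_compile` and parameterized `@support_torch_compile(dynamic_arg_dims=...)` are both supported via `@overload`. A hedged, self-contained sketch of that pattern with the modernized annotations (the decorator below, `configure`, is a stand-in that only records its configuration on the class):

```python
from typing import Callable, Optional, TypeVar, Union, overload

_T = TypeVar("_T", bound=type)


@overload
def configure(cls: _T) -> _T:
    ...


@overload
def configure(
    *,
    dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]],
) -> Callable[[_T], _T]:
    ...


def configure(
    cls: Optional[_T] = None,
    *,
    dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None,
) -> Union[Callable[[_T], _T], _T]:
    def wrap(klass: _T) -> _T:
        # Illustrative side effect: stash the config on the class.
        klass._dynamic_arg_dims = dynamic_arg_dims
        return klass

    if cls is not None:
        return wrap(cls)  # bare usage: @configure
    return wrap  # parameterized usage: @configure(dynamic_arg_dims={...})


@configure(dynamic_arg_dims={"input_ids": 0, "positions": [0, 1]})
class Dummy:
    pass


assert Dummy._dynamic_arg_dims == {"input_ids": 0, "positions": [0, 1]}
```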

View File

@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import operator
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from collections.abc import Iterable
+from typing import Optional, Union
 
 import torch
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
@@ -27,7 +28,7 @@ class FixFunctionalizationPass(VllmInductorPass):
         self.begin()
         self.dump_graph(graph, "before_fix_functionalization")
 
-        self.nodes_to_remove: List[torch.fx.Node] = []
+        self.nodes_to_remove: list[torch.fx.Node] = []
         count = 0
         for node in graph.nodes:
             if not is_func(node, auto_functionalized):
@@ -117,8 +118,8 @@ class FixFunctionalizationPass(VllmInductorPass):
     def defunctionalize(self,
                         graph: torch.fx.Graph,
                         node: torch.fx.Node,
-                        mutated_args: Dict[int, Union[torch.fx.Node, str]],
-                        args: Optional[Tuple[Union[torch.fx.Node, str],
+                        mutated_args: dict[int, Union[torch.fx.Node, str]],
+                        args: Optional[tuple[Union[torch.fx.Node, str],
                                              ...]] = None):
         """
         De-functionalize a node by replacing it with a call to the original.
@@ -130,7 +131,7 @@ class FixFunctionalizationPass(VllmInductorPass):
         self._remove(node)
 
     def replace_users_with_mutated_args(self, node: torch.fx.Node,
-                                        mutated_args: Dict[int,
+                                        mutated_args: dict[int,
                                                            Union[torch.fx.Node,
                                                                  str]]):
         """
@@ -146,7 +147,7 @@ class FixFunctionalizationPass(VllmInductorPass):
             user.replace_all_uses_with(arg)
             self._remove(user)
 
-    def getitem_users(self, node: torch.fx.Node) -> Dict[int, torch.fx.Node]:
+    def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
         """
         Returns the operator.getitem users of the auto-functionalized node,
         indexed by the index they are getting.
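
`getitem_users` returns a `dict[int, torch.fx.Node]` mapping tuple indices to the `operator.getitem` nodes that read them. A generic torch.fx illustration of that kind of mapping (not vLLM's implementation):

```python
# Illustrative only: collect operator.getitem users of each node, keyed by the
# index they extract, from a traced function that returns a tuple.
import operator

import torch
import torch.fx as fx


def f(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    parts = torch.split(x, 2)
    return parts[0] + 1, parts[1] * 2


gm = fx.symbolic_trace(f)
for node in gm.graph.nodes:
    getitem_users: dict[int, fx.Node] = {
        user.args[1]: user
        for user in node.users
        if user.op == "call_function" and user.target is operator.getitem
    }
    if getitem_users:
        print(node.name, sorted(getitem_users))  # e.g. "split [0, 1]"
```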
@@ -161,7 +162,7 @@ class FixFunctionalizationPass(VllmInductorPass):
     def insert_defunctionalized(self,
                                 graph: torch.fx.Graph,
                                 node: torch.fx.Node,
-                                args: Optional[Tuple[Union[torch.fx.Node, str],
+                                args: Optional[tuple[Union[torch.fx.Node, str],
                                                      ...]] = None):
         """
         Insert a new defunctionalized node into the graph before node.

View File

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
+from typing import Callable, NamedTuple, Optional
 
 import torch
 import torch._inductor.pattern_matcher as pm
@@ -57,7 +57,7 @@ kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, True, True)
 kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, True, True)
 kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, False, True)
 
-QUANT_OPS: Dict[QuantKey, OpOverload] = {
+QUANT_OPS: dict[QuantKey, OpOverload] = {
     kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,  # noqa
     kFp8DynamicTensorSym:
     torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa
@@ -80,7 +80,7 @@ class FusedRMSQuantKey(NamedTuple):
                 f"{'' if self.fused_add else 'out'} residual)")
 
 
-FUSED_OPS: Dict[FusedRMSQuantKey, OpOverload] = {
+FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
     FusedRMSQuantKey(kFp8StaticTensorSym, False):
     torch.ops._C.rms_norm_static_fp8_quant.default,  # noqa
     FusedRMSQuantKey(kFp8StaticTensorSym, True):
@@ -101,7 +101,7 @@ class QuantMultiOutputMatch(MultiOutputMatch):
         self.QUANT_OP = quant_op  # in-place quant op
         self.FUSED_OP = fused_op  # in-place fused quant op
 
-    def insert_fused_node(self, fused_return_mapping: Dict[int, Tuple[fx.Node,
+    def insert_fused_node(self, fused_return_mapping: dict[int, tuple[fx.Node,
                                                                       int]],
                           **kwargs):
         """
@@ -548,7 +548,7 @@ class FusionPass(VllmInductorPass):
             "FusionPass singleton instance already exists"
         super().__init__(config)
 
-        self.matches: List[MultiOutputMatch] = []
+        self.matches: list[MultiOutputMatch] = []
         self.patterns: PatternMatcherPass = PatternMatcherPass(
             pass_name="fusion_pass")
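
Module-level annotated constants such as `QUANT_OPS: dict[QuantKey, OpOverload]` have their annotations evaluated at import time, so the builtin-generic spelling relies on `dict[...]` being subscriptable at runtime, which it is from Python 3.9 onward, rather than on `from __future__ import annotations`. A toy illustration with stand-in types (not the real QuantKey or op table):

```python
from typing import Callable

# Stand-in for the real QuantKey NamedTuple; only the mapping's shape matters.
QuantKey = tuple[str, bool, bool, bool]

# Evaluated at import time: dict[...] must be subscriptable (Python 3.9+).
QUANT_OPS: dict[QuantKey, Callable[[], str]] = {
    ("fp8", True, True, True): lambda: "static_scaled_fp8_quant",
    ("fp8", False, True, True): lambda: "dynamic_scaled_fp8_quant",
}

assert QUANT_OPS[("fp8", True, True, True)]() == "static_scaled_fp8_quant"
```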

View File

@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import operator
-from typing import Iterable, Optional
+from collections.abc import Iterable
+from typing import Optional
 
 from torch import fx
 from torch._higher_order_ops.auto_functionalize import auto_functionalized

View File

@@ -5,7 +5,7 @@ import inspect
 import json
 import types
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 from torch import fx
@@ -83,7 +83,7 @@ class InductorPass(CustomGraphPass):
         return hasher.hexdigest()
 
     @staticmethod
-    def hash_dict(dict_: Dict[Any, Any]):
+    def hash_dict(dict_: dict[Any, Any]):
         """
         Utility method to hash a dictionary, can alternatively be used for uuid.
         :return: A sha256 hash of the json rep of the dictionary.

View File

@@ -3,7 +3,7 @@
 import abc
 import operator
 from abc import abstractmethod
-from typing import Iterable, List, Tuple
+from collections.abc import Iterable
 
 from torch import fx
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
@@ -56,7 +56,7 @@ class MultiOutputMatch(abc.ABC):
         raise NotImplementedError
 
     @property
-    def nodes(self) -> List[fx.Node]:
+    def nodes(self) -> list[fx.Node]:
         return self.match.nodes
 
     @property
@@ -87,7 +87,7 @@ class MultiOutputMatch(abc.ABC):
         return self.graph.inserting_after(last_node_in_match)
 
     def insert_getitems(self, tuple_node: fx.Node,
-                        indices: Iterable[int]) -> Tuple[fx.Node, ...]:
+                        indices: Iterable[int]) -> tuple[fx.Node, ...]:
         """
         Insert operator.getitem nodes to extract elements from a tuple node.

View File

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Iterable, Union
+from collections.abc import Iterable
+from typing import Union
 
 import torch.fx
 from torch import SymInt

View File

@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import List
-
 from torch import fx as fx
 
 from vllm.config import VllmConfig
@@ -34,7 +32,7 @@ class PostGradPassManager(CustomGraphPass):
     """
 
     def __init__(self):
-        self.passes: List[VllmInductorPass] = []
+        self.passes: list[VllmInductorPass] = []
 
     def __call__(self, graph: fx.Graph):
         shape = get_pass_context().runtime_shape

View File

@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import Optional, Tuple
+from typing import Optional
 
 import torch
 import torch._inductor.pattern_matcher as pm
@@ -125,7 +125,7 @@ class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern):
             residual: torch.Tensor,
             mm_1: torch.Tensor,
             rms_norm_weights: torch.Tensor,
-        ) -> Tuple[torch.Tensor, torch.Tensor]:
+        ) -> tuple[torch.Tensor, torch.Tensor]:
             all_reduce = tensor_model_parallel_all_reduce(mm_1)
 
             rmsnorm = torch.ops.higher_order.auto_functionalized(
@@ -142,7 +142,7 @@ class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern):
             residual: torch.Tensor,
             mm_1: torch.Tensor,
             rms_norm_weights: torch.Tensor,
-        ) -> Tuple[torch.Tensor, torch.Tensor]:
+        ) -> tuple[torch.Tensor, torch.Tensor]:
             tp = get_tp_group()
             tp_size = get_tensor_model_parallel_world_size()
             reduce_scatter = torch.ops.vllm.reduce_scatter.default(
@@ -190,7 +190,7 @@ class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
             residual: torch.Tensor,
             mm_1: torch.Tensor,
             rms_norm_weights: torch.Tensor,
-        ) -> Tuple[torch.Tensor, torch.Tensor]:
+        ) -> tuple[torch.Tensor, torch.Tensor]:
             all_reduce = tensor_model_parallel_all_reduce(mm_1)
 
             rmsnorm = torch.ops.higher_order.auto_functionalized(
@@ -207,7 +207,7 @@ class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
             residual: torch.Tensor,
             mm_1: torch.Tensor,
             rms_norm_weights: torch.Tensor,
-        ) -> Tuple[torch.Tensor, torch.Tensor]:
+        ) -> tuple[torch.Tensor, torch.Tensor]:
             tp = get_tp_group()
             tp_size = get_tensor_model_parallel_world_size()
             reduce_scatter = torch.ops.vllm.reduce_scatter.default(

View File

@@ -5,7 +5,7 @@ import sys
 from abc import abstractmethod
 from contextlib import contextmanager
 from types import CodeType
-from typing import Callable, List, Optional
+from typing import Callable, Optional
 
 import torch
@@ -48,7 +48,7 @@ class TorchCompileWrapperWithCustomDispatcher:
         self.compiled_callable = compiled_callable
         self.original_code_object = self.__class__.forward.__code__
-        self.compiled_codes: List[CodeType] = []
+        self.compiled_codes: list[CodeType] = []
         torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
 
         # read the env var to determine whether to use the custom dispatcher