diff --git a/pyproject.toml b/pyproject.toml
index 408841845bde..6f5c560e800f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -74,7 +74,6 @@ exclude = [
 # Python 3.8 typing. TODO: Remove these excludes after v1.0.0
 "vllm/adapter_commons/**/*.py" = ["UP006", "UP035"]
 "vllm/attention/**/*.py" = ["UP006", "UP035"]
-"vllm/compilation/**/*.py" = ["UP006", "UP035"]
 "vllm/core/**/*.py" = ["UP006", "UP035"]
 "vllm/device_allocator/**/*.py" = ["UP006", "UP035"]
 "vllm/distributed/**/*.py" = ["UP006", "UP035"]
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index c2e8c726c943..0c1381a565c1 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -5,8 +5,9 @@ import dataclasses
 import os
 import pprint
 import time
+from collections.abc import Sequence
 from contextlib import ExitStack
-from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
+from typing import Any, Callable, Optional
 from unittest.mock import patch
 
 import torch
@@ -56,7 +57,7 @@ class CompilerManager:
     """
 
     def __init__(self, compilation_config: CompilationConfig):
-        self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict()
+        self.cache: dict[tuple[Optional[int], int, str], Any] = dict()
         self.is_cache_updated = False
         self.compilation_config = compilation_config
         self.compiler = make_compiler(compilation_config)
@@ -90,7 +91,7 @@ class CompilerManager:
 
     def load(self,
              graph: fx.GraphModule,
-             example_inputs: List[Any],
+             example_inputs: list[Any],
              graph_index: int,
             runtime_shape: Optional[int] = None) -> Optional[Callable]:
         if (runtime_shape, graph_index, self.compiler.name) not in self.cache:
@@ -186,7 +187,7 @@ class SplitItem:
 
 
 def split_graph(graph: fx.GraphModule,
-                ops: List[str]) -> Tuple[fx.GraphModule, List[SplitItem]]:
+                ops: list[str]) -> tuple[fx.GraphModule, list[SplitItem]]:
     # split graph by ops
     subgraph_id = 0
     node_to_subgraph_id = {}
@@ -252,7 +253,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
     """
 
     def __init__(self, module: torch.fx.GraphModule,
-                 compile_submod_names: List[str], vllm_config: VllmConfig,
+                 compile_submod_names: list[str], vllm_config: VllmConfig,
                  graph_pool, vllm_backend: "VllmBackend"):
         super().__init__(module)
         from torch._guards import detect_fake_mode
@@ -274,8 +275,8 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
         return super().run(*fake_args)
 
     def call_module(self, target: torch.fx.node.Target,
-                    args: Tuple[torch.fx.node.Argument,
-                                ...], kwargs: Dict[str, Any]) -> Any:
+                    args: tuple[torch.fx.node.Argument,
+                                ...], kwargs: dict[str, Any]) -> Any:
         assert isinstance(target, str)
         output = super().call_module(target, args, kwargs)
 
@@ -326,12 +327,12 @@ class VllmBackend:
     graph: fx.GraphModule
     # the stiching graph module for all the piecewise graphs
     split_gm: fx.GraphModule
-    piecewise_graphs: List[SplitItem]
+    piecewise_graphs: list[SplitItem]
     returned_callable: Callable
     # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]
-    sym_tensor_indices: List[int]
-    input_buffers: List[torch.Tensor]
+    sym_tensor_indices: list[int]
+    input_buffers: list[torch.Tensor]
     compiler_manager: CompilerManager
 
     def __init__(
@@ -573,14 +574,14 @@ class ConcreteSizeEntry:
 
     # for cudagraph debugging, track the input addresses
     # during capture, and check if they are the same during replay
-    input_addresses: Optional[List[int]] = None
+    input_addresses: Optional[list[int]] = None
 
 
 class PiecewiseBackend:
 
     def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                  graph_pool: Any, piecewise_compile_index: int,
-                 total_piecewise_compiles: int, sym_shape_indices: List[int],
+                 total_piecewise_compiles: int, sym_shape_indices: list[int],
                  compiled_graph_for_general_shape: Callable,
                  vllm_backend: VllmBackend):
         """
@@ -608,9 +609,9 @@ class PiecewiseBackend:
         self.is_last_graph = (
             piecewise_compile_index == total_piecewise_compiles - 1)
 
-        self.compile_sizes: Set[int] = set(
+        self.compile_sizes: set[int] = set(
             self.compilation_config.compile_sizes)
-        self.cudagraph_capture_sizes: Set[int] = set(
+        self.cudagraph_capture_sizes: set[int] = set(
             self.compilation_config.cudagraph_capture_sizes
         ) if self.compilation_config.use_cudagraph else set()
 
@@ -624,11 +625,11 @@ class PiecewiseBackend:
 
         # the entries for different shapes that we need to either
         # compile or capture cudagraph
-        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
+        self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
 
         # to_be_compiled_sizes tracks the remaining sizes to compile,
         # and updates during the compilation process, so we need to copy it
-        self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()
+        self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
         for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
             self.concrete_size_entries[shape] = ConcreteSizeEntry(
                 runtime_shape=shape,
diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py
index 423581784f7a..89a131e8ea24 100644
--- a/vllm/compilation/compiler_interface.py
+++ b/vllm/compilation/compiler_interface.py
@@ -4,7 +4,7 @@ import copy
 import hashlib
 import os
 from contextlib import ExitStack
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Optional
 from unittest.mock import patch
 
 import torch
@@ -48,11 +48,11 @@ class CompilerInterface:
     def compile(
         self,
         graph: fx.GraphModule,
-        example_inputs: List[Any],
-        compiler_config: Dict[str, Any],
+        example_inputs: list[Any],
+        compiler_config: dict[str, Any],
         runtime_shape: Optional[int] = None,
         key: Optional[str] = None,
-    ) -> Tuple[Optional[Callable], Optional[Any]]:
+    ) -> tuple[Optional[Callable], Optional[Any]]:
         """
         Compile the graph with the given example inputs and compiler config,
         with a runtime shape. If the `runtime_shape` is None, it means
@@ -82,7 +82,7 @@ class CompilerInterface:
     def load(self,
              handle: Any,
              graph: fx.GraphModule,
-             example_inputs: List[Any],
+             example_inputs: list[Any],
              graph_index: int,
              runtime_shape: Optional[int] = None) -> Callable:
         """
@@ -120,7 +120,7 @@ class AlwaysHitShapeEnv:
     """
 
     def __init__(self) -> None:
-        self.guards: List[Any] = []
+        self.guards: list[Any] = []
 
     def evaluate_guards_expression(self, *args, **kwargs):
         return True
@@ -132,8 +132,8 @@ class AlwaysHitShapeEnv:
         return ""
 
 
-def get_inductor_factors() -> List[Any]:
-    factors: List[Any] = []
+def get_inductor_factors() -> list[Any]:
+    factors: list[Any] = []
     # summarize system state
     from torch._inductor.codecache import CacheBase
     system_factors = CacheBase.get_system()
@@ -169,11 +169,11 @@ class InductorStandaloneAdaptor(CompilerInterface):
     def compile(
         self,
         graph: fx.GraphModule,
-        example_inputs: List[Any],
-        compiler_config: Dict[str, Any],
+        example_inputs: list[Any],
+        compiler_config: dict[str, Any],
         runtime_shape: Optional[int] = None,
         key: Optional[str] = None,
-    ) -> Tuple[Optional[Callable], Optional[Any]]:
+    ) -> tuple[Optional[Callable], Optional[Any]]:
         current_config = {}
         if compiler_config is not None:
             current_config.update(compiler_config)
@@ -201,7 +201,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
     def load(self,
              handle: Any,
             graph: fx.GraphModule,
-             example_inputs: List[Any],
+             example_inputs: list[Any],
              graph_index: int,
              runtime_shape: Optional[int] = None) -> Callable:
         assert isinstance(handle, tuple)
@@ -256,11 +256,11 @@ class InductorAdaptor(CompilerInterface):
     def compile(
         self,
         graph: fx.GraphModule,
-        example_inputs: List[Any],
-        compiler_config: Dict[str, Any],
+        example_inputs: list[Any],
+        compiler_config: dict[str, Any],
         runtime_shape: Optional[int] = None,
         key: Optional[str] = None,
-    ) -> Tuple[Optional[Callable], Optional[Any]]:
+    ) -> tuple[Optional[Callable], Optional[Any]]:
         from torch._inductor.compile_fx import compile_fx
         current_config = {}
         if compiler_config is not None:
@@ -420,7 +420,7 @@ class InductorAdaptor(CompilerInterface):
     def load(self,
              handle: Any,
             graph: fx.GraphModule,
-             example_inputs: List[Any],
+             example_inputs: list[Any],
              graph_index: int,
              runtime_shape: Optional[int] = None) -> Callable:
         assert isinstance(handle, tuple)
@@ -522,11 +522,11 @@ class EagerAdaptor(CompilerInterface):
     def compile(
         self,
         graph: fx.GraphModule,
-        example_inputs: List[Any],
-        compiler_config: Dict[str, Any],
+        example_inputs: list[Any],
+        compiler_config: dict[str, Any],
         runtime_shape: Optional[int] = None,
         key: Optional[str] = None,
-    ) -> Tuple[Optional[Callable], Optional[Any]]:
+    ) -> tuple[Optional[Callable], Optional[Any]]:
         # we don't need to compile the graph, just return the graph itself.
         # It does not support caching, return None for the handle.
         return graph, None
diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py
index 20afe6967df3..f02994c55527 100644
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import inspect
-from typing import Callable, Dict, List, Optional, TypeVar, Union, overload
+from typing import Callable, Optional, TypeVar, Union, overload
 from unittest.mock import patch
 
 import torch
@@ -25,7 +25,7 @@ _T = TypeVar("_T", bound=type[nn.Module])
 
 @overload
 def support_torch_compile(
     *,
-    dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]],
+    dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]],
 ) -> Callable[[_T], _T]:
     ...
@@ -38,7 +38,7 @@ def support_torch_compile(cls: _T) -> _T:
 def support_torch_compile(
     cls: Optional[_T] = None,
     *,
-    dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None,
+    dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None,
 ) -> Union[Callable[[_T], _T], _T]:
     """
     A decorator to add support for compiling the forward method of a class.
@@ -131,7 +131,7 @@ def support_torch_compile(
 
 def _support_torch_compile(
     cls: _T,
-    dynamic_arg_dims: Dict[str, Union[int, List[int]]],
+    dynamic_arg_dims: dict[str, Union[int, list[int]]],
 ) -> _T:
     """
     A decorator to add support for compiling the forward method of a class.
diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py
index 7f3120660329..70f3b8b6df94 100644
--- a/vllm/compilation/fix_functionalization.py
+++ b/vllm/compilation/fix_functionalization.py
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import operator
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from collections.abc import Iterable
+from typing import Optional, Union
 
 import torch
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
@@ -27,7 +28,7 @@ class FixFunctionalizationPass(VllmInductorPass):
         self.begin()
         self.dump_graph(graph, "before_fix_functionalization")
 
-        self.nodes_to_remove: List[torch.fx.Node] = []
+        self.nodes_to_remove: list[torch.fx.Node] = []
         count = 0
         for node in graph.nodes:
             if not is_func(node, auto_functionalized):
@@ -117,8 +118,8 @@ class FixFunctionalizationPass(VllmInductorPass):
     def defunctionalize(self,
                         graph: torch.fx.Graph,
                         node: torch.fx.Node,
-                        mutated_args: Dict[int, Union[torch.fx.Node, str]],
-                        args: Optional[Tuple[Union[torch.fx.Node, str],
+                        mutated_args: dict[int, Union[torch.fx.Node, str]],
+                        args: Optional[tuple[Union[torch.fx.Node, str],
                                              ...]] = None):
         """
         De-functionalize a node by replacing it with a call to the original.
@@ -130,7 +131,7 @@ class FixFunctionalizationPass(VllmInductorPass):
         self._remove(node)
 
     def replace_users_with_mutated_args(self, node: torch.fx.Node,
-                                        mutated_args: Dict[int,
+                                        mutated_args: dict[int,
                                                            Union[torch.fx.Node,
                                                                  str]]):
         """
@@ -146,7 +147,7 @@ class FixFunctionalizationPass(VllmInductorPass):
                 user.replace_all_uses_with(arg)
                 self._remove(user)
 
-    def getitem_users(self, node: torch.fx.Node) -> Dict[int, torch.fx.Node]:
+    def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
         """
         Returns the operator.getitem users of the auto-functionalized node,
         indexed by the index they are getting.
@@ -161,7 +162,7 @@ class FixFunctionalizationPass(VllmInductorPass):
     def insert_defunctionalized(self,
                                 graph: torch.fx.Graph,
                                 node: torch.fx.Node,
-                                args: Optional[Tuple[Union[torch.fx.Node, str],
+                                args: Optional[tuple[Union[torch.fx.Node, str],
                                                      ...]] = None):
         """
         Insert a new defunctionalized node into the graph before node.
diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py
index 8f32fdb03f8b..618b2fe94d3a 100644
--- a/vllm/compilation/fusion.py
+++ b/vllm/compilation/fusion.py
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
+from typing import Callable, NamedTuple, Optional
 
 import torch
 import torch._inductor.pattern_matcher as pm
@@ -57,7 +57,7 @@
 kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, True, True)
 kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, True, True)
 kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, False, True)
 
-QUANT_OPS: Dict[QuantKey, OpOverload] = {
+QUANT_OPS: dict[QuantKey, OpOverload] = {
     kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,  # noqa
     kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa
@@ -80,7 +80,7 @@ class FusedRMSQuantKey(NamedTuple):
                 f"{'' if self.fused_add else 'out'} residual)")
 
 
-FUSED_OPS: Dict[FusedRMSQuantKey, OpOverload] = {
+FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
     FusedRMSQuantKey(kFp8StaticTensorSym, False):
     torch.ops._C.rms_norm_static_fp8_quant.default,  # noqa
     FusedRMSQuantKey(kFp8StaticTensorSym, True):
@@ -101,7 +101,7 @@ class QuantMultiOutputMatch(MultiOutputMatch):
         self.QUANT_OP = quant_op  # in-place quant op
         self.FUSED_OP = fused_op  # in-place fused quant op
 
-    def insert_fused_node(self, fused_return_mapping: Dict[int, Tuple[fx.Node,
+    def insert_fused_node(self, fused_return_mapping: dict[int, tuple[fx.Node,
                                                                       int]],
                           **kwargs):
         """
@@ -548,7 +548,7 @@ class FusionPass(VllmInductorPass):
             "FusionPass singleton instance already exists"
         super().__init__(config)
 
-        self.matches: List[MultiOutputMatch] = []
+        self.matches: list[MultiOutputMatch] = []
         self.patterns: PatternMatcherPass = PatternMatcherPass(
             pass_name="fusion_pass")
 
diff --git a/vllm/compilation/fx_utils.py b/vllm/compilation/fx_utils.py
index f9427e48ac31..b9eeb0c8d2af 100644
--- a/vllm/compilation/fx_utils.py
+++ b/vllm/compilation/fx_utils.py
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import operator
-from typing import Iterable, Optional
+from collections.abc import Iterable
+from typing import Optional
 
 from torch import fx
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py
index 4f5c82776839..a9359fe1e117 100644
--- a/vllm/compilation/inductor_pass.py
+++ b/vllm/compilation/inductor_pass.py
@@ -5,7 +5,7 @@ import inspect
 import json
 import types
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 from torch import fx
@@ -83,7 +83,7 @@ class InductorPass(CustomGraphPass):
         return hasher.hexdigest()
 
     @staticmethod
-    def hash_dict(dict_: Dict[Any, Any]):
+    def hash_dict(dict_: dict[Any, Any]):
         """
         Utility method to hash a dictionary, can alternatively be used for uuid.
         :return: A sha256 hash of the json rep of the dictionary.
diff --git a/vllm/compilation/multi_output_match.py b/vllm/compilation/multi_output_match.py
index e6f6a60b2595..cef19f9257ed 100644
--- a/vllm/compilation/multi_output_match.py
+++ b/vllm/compilation/multi_output_match.py
@@ -3,7 +3,7 @@
 import abc
 import operator
 from abc import abstractmethod
-from typing import Iterable, List, Tuple
+from collections.abc import Iterable
 
 from torch import fx
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
@@ -56,7 +56,7 @@ class MultiOutputMatch(abc.ABC):
         raise NotImplementedError
 
     @property
-    def nodes(self) -> List[fx.Node]:
+    def nodes(self) -> list[fx.Node]:
         return self.match.nodes
 
     @property
@@ -87,7 +87,7 @@ class MultiOutputMatch(abc.ABC):
         return self.graph.inserting_after(last_node_in_match)
 
     def insert_getitems(self, tuple_node: fx.Node,
-                        indices: Iterable[int]) -> Tuple[fx.Node, ...]:
+                        indices: Iterable[int]) -> tuple[fx.Node, ...]:
         """
         Insert operator.getitem nodes to extract elements from a tuple node.
 
diff --git a/vllm/compilation/noop_elimination.py b/vllm/compilation/noop_elimination.py
index 19127e933ec4..13e4cd73f8ce 100644
--- a/vllm/compilation/noop_elimination.py
+++ b/vllm/compilation/noop_elimination.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Iterable, Union
+from collections.abc import Iterable
+from typing import Union
 
 import torch.fx
 from torch import SymInt
diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py
index b1646914c7ed..f4d3fd9b457f 100644
--- a/vllm/compilation/pass_manager.py
+++ b/vllm/compilation/pass_manager.py
@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import List
-
 from torch import fx as fx
 
 from vllm.config import VllmConfig
@@ -34,7 +32,7 @@ class PostGradPassManager(CustomGraphPass):
     """
 
     def __init__(self):
-        self.passes: List[VllmInductorPass] = []
+        self.passes: list[VllmInductorPass] = []
 
     def __call__(self, graph: fx.Graph):
         shape = get_pass_context().runtime_shape
diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py
index 95db63d34f7e..f0476bfcb65a 100644
--- a/vllm/compilation/sequence_parallelism.py
+++ b/vllm/compilation/sequence_parallelism.py
@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import Optional, Tuple
+from typing import Optional
 
 import torch
 import torch._inductor.pattern_matcher as pm
@@ -125,7 +125,7 @@ class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern):
         residual: torch.Tensor,
         mm_1: torch.Tensor,
         rms_norm_weights: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         all_reduce = tensor_model_parallel_all_reduce(mm_1)
 
         rmsnorm = torch.ops.higher_order.auto_functionalized(
@@ -142,7 +142,7 @@ class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern):
         residual: torch.Tensor,
         mm_1: torch.Tensor,
         rms_norm_weights: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         tp = get_tp_group()
         tp_size = get_tensor_model_parallel_world_size()
         reduce_scatter = torch.ops.vllm.reduce_scatter.default(
@@ -190,7 +190,7 @@ class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
         residual: torch.Tensor,
         mm_1: torch.Tensor,
         rms_norm_weights: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         all_reduce = tensor_model_parallel_all_reduce(mm_1)
 
         rmsnorm = torch.ops.higher_order.auto_functionalized(
@@ -207,7 +207,7 @@ class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
         residual: torch.Tensor,
         mm_1: torch.Tensor,
         rms_norm_weights: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         tp = get_tp_group()
         tp_size = get_tensor_model_parallel_world_size()
         reduce_scatter = torch.ops.vllm.reduce_scatter.default(
diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py
index a8a283ddd8c0..1a8211f0ab7c 100644
--- a/vllm/compilation/wrapper.py
+++ b/vllm/compilation/wrapper.py
@@ -5,7 +5,7 @@ import sys
 from abc import abstractmethod
 from contextlib import contextmanager
 from types import CodeType
-from typing import Callable, List, Optional
+from typing import Callable, Optional
 
 import torch
 
@@ -48,7 +48,7 @@ class TorchCompileWrapperWithCustomDispatcher:
 
         self.compiled_callable = compiled_callable
         self.original_code_object = self.__class__.forward.__code__
-        self.compiled_codes: List[CodeType] = []
+        self.compiled_codes: list[CodeType] = []
         torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
 
         # read the env var to determine whether to use the custom dispatcher
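
Note on the pattern applied throughout this patch: removing the "vllm/compilation/**/*.py" entry from the ruff per-file-ignores in pyproject.toml enables the UP006 (deprecated typing generics) and UP035 (deprecated imports) rules for this package, and the rest of the diff brings the code into compliance. Dict/List/Set/Tuple become the built-in dict/list/set/tuple (PEP 585, usable as annotations on Python 3.9+), and abstract collection types such as Iterable and Sequence move from typing to collections.abc. Below is a minimal before/after sketch of the migration; the names (cache, load, passes) are illustrative, not taken from the patch:

# Before: Python 3.8-era annotations, flagged by ruff UP006/UP035.
#   from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
#   cache: Dict[Tuple[Optional[int], int, str], Any] = {}
#   def load(example_inputs: List[Any],
#            passes: Sequence[Callable]) -> Optional[Callable]: ...

# After: built-in generics plus collections.abc, as in this diff.
from collections.abc import Sequence
from typing import Any, Callable, Optional

cache: dict[tuple[Optional[int], int, str], Any] = {}

def load(example_inputs: list[Any],
         passes: Sequence[Callable]) -> Optional[Callable]:
    # Toy stand-in mirroring the annotation shapes used in the patch.
    return None

Note that Optional, Union, and Callable are still imported from typing in the patch; only the container generics and the abstract collection types (Iterable, Sequence) change their spelling or import location.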