Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 19:25:23 +08:00)
Update deprecated type hinting in vllm/compilation (#18072)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent fc407a1425
commit 19324d660c
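
The diff below applies one mechanical change across vllm/compilation: PEP 585 deprecated the `typing.List`/`Dict`/`Tuple`/`Set` aliases in favor of the builtin generics available since Python 3.9, and abstract containers such as `Sequence` and `Iterable` now come from `collections.abc`. A minimal before/after sketch of the pattern (the function and variable names here are illustrative, not from the diff):

    # Before: Python 3.8-era aliases, flagged by ruff's UP006/UP035 rules.
    from typing import Dict, List, Optional, Tuple

    def lookup(table: Dict[str, Tuple[int, int]],
               keys: List[str]) -> Optional[int]:
        return table[keys[0]][0] if keys else None

    # After: builtin generics; Optional (and Callable/Union elsewhere in
    # this commit) still comes from typing, since it has no builtin form.
    from typing import Optional

    def lookup(table: dict[str, tuple[int, int]],
               keys: list[str]) -> Optional[int]:
        return table[keys[0]][0] if keys else None
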
@@ -74,7 +74,6 @@ exclude = [
 # Python 3.8 typing. TODO: Remove these excludes after v1.0.0
 "vllm/adapter_commons/**/*.py" = ["UP006", "UP035"]
 "vllm/attention/**/*.py" = ["UP006", "UP035"]
-"vllm/compilation/**/*.py" = ["UP006", "UP035"]
 "vllm/core/**/*.py" = ["UP006", "UP035"]
 "vllm/device_allocator/**/*.py" = ["UP006", "UP035"]
 "vllm/distributed/**/*.py" = ["UP006", "UP035"]
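
Dropping the per-file exclude above turns ruff's pyupgrade rules back on for vllm/compilation: UP006 reports builtin-replaceable annotations such as `typing.List`, and UP035 reports the deprecated imports themselves. A hedged sketch of code that both rules would now flag (names are hypothetical):

    from typing import Dict, List  # UP035: deprecated alias imports

    def sizes_by_name(names: List[str]) -> Dict[str, int]:  # UP006 on both hints
        return {name: len(name) for name in names}
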
@@ -5,8 +5,9 @@ import dataclasses
 import os
 import pprint
 import time
+from collections.abc import Sequence
 from contextlib import ExitStack
-from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
+from typing import Any, Callable, Optional
 from unittest.mock import patch
 
 import torch
@@ -56,7 +57,7 @@ class CompilerManager:
     """
 
     def __init__(self, compilation_config: CompilationConfig):
-        self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict()
+        self.cache: dict[tuple[Optional[int], int, str], Any] = dict()
         self.is_cache_updated = False
         self.compilation_config = compilation_config
         self.compiler = make_compiler(compilation_config)
@@ -90,7 +91,7 @@ class CompilerManager:
 
     def load(self,
              graph: fx.GraphModule,
-             example_inputs: List[Any],
+             example_inputs: list[Any],
             graph_index: int,
              runtime_shape: Optional[int] = None) -> Optional[Callable]:
         if (runtime_shape, graph_index, self.compiler.name) not in self.cache:
@@ -186,7 +187,7 @@ class SplitItem:
 
 
 def split_graph(graph: fx.GraphModule,
                ops: List[str]) -> Tuple[fx.GraphModule, List[SplitItem]]:
@@ -4,7 +4,7 @@ import copy
 import hashlib
 import os
 from contextlib import ExitStack
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Optional
 from unittest.mock import patch
 
 import torch
@@ -48,11 +48,11 @@ class CompilerInterface:
     def compile(
         self,
         graph: fx.GraphModule,
-        example_inputs: List[Any],
-        compiler_config: Dict[str, Any],
+        example_inputs: list[Any],
+        compiler_config: dict[str, Any],
         runtime_shape: Optional[int] = None,
         key: Optional[str] = None,
-    ) -> Tuple[Optional[Callable], Optional[Any]]:
+    ) -> tuple[Optional[Callable], Optional[Any]]:
         """
         Compile the graph with the given example inputs and compiler config,
         with a runtime shape. If the `runtime_shape` is None, it means
@@ -82,7 +82,7 @@ class CompilerInterface:
     def load(self,
              handle: Any,
             graph: fx.GraphModule,
-             example_inputs: List[Any],
+             example_inputs: list[Any],
             graph_index: int,
              runtime_shape: Optional[int] = None) -> Callable:
         """
@@ -120,7 +120,7 @@ class AlwaysHitShapeEnv:
     """
 
     def __init__(self) -> None:
-        self.guards: List[Any] = []
+        self.guards: list[Any] = []
 
     def evaluate_guards_expression(self, *args, **kwargs):
         return True
@@ -132,8 +132,8 @@ class AlwaysHitShapeEnv:
         return ""
 
 
-def get_inductor_factors() -> List[Any]:
-    factors: List[Any] = []
+def get_inductor_factors() -> list[Any]:
+    factors: list[Any] = []
     # summarize system state
     from torch._inductor.codecache import CacheBase
     system_factors = CacheBase.get_system()
@@ -169,11 +169,11 @@ class InductorStandaloneAdaptor(CompilerInterface):
     def compile(
         self,
         graph: fx.GraphModule,
-        example_inputs: List[Any],
-        compiler_config: Dict[str, Any],
+        example_inputs: list[Any],
+        compiler_config: dict[str, Any],
         runtime_shape: Optional[int] = None,
         key: Optional[str] = None,
-    ) -> Tuple[Optional[Callable], Optional[Any]]:
+    ) -> tuple[Optional[Callable], Optional[Any]]:
         current_config = {}
         if compiler_config is not None:
             current_config.update(compiler_config)
@@ -201,7 +201,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
     def load(self,
              handle: Any,
             graph: fx.GraphModule,
-             example_inputs: List[Any],
+             example_inputs: list[Any],
             graph_index: int,
              runtime_shape: Optional[int] = None) -> Callable:
         assert isinstance(handle, tuple)
@@ -256,11 +256,11 @@ class InductorAdaptor(CompilerInterface):
     def compile(
         self,
         graph: fx.GraphModule,
-        example_inputs: List[Any],
-        compiler_config: Dict[str, Any],
+        example_inputs: list[Any],
+        compiler_config: dict[str, Any],
         runtime_shape: Optional[int] = None,
         key: Optional[str] = None,
-    ) -> Tuple[Optional[Callable], Optional[Any]]:
+    ) -> tuple[Optional[Callable], Optional[Any]]:
         from torch._inductor.compile_fx import compile_fx
         current_config = {}
         if compiler_config is not None:
@@ -420,7 +420,7 @@ class InductorAdaptor(CompilerInterface):
     def load(self,
              handle: Any,
             graph: fx.GraphModule,
-             example_inputs: List[Any],
+             example_inputs: list[Any],
             graph_index: int,
              runtime_shape: Optional[int] = None) -> Callable:
         assert isinstance(handle, tuple)
@@ -522,11 +522,11 @@ class EagerAdaptor(CompilerInterface):
     def compile(
         self,
        graph: fx.GraphModule,
-        example_inputs: List[Any],
-        compiler_config: Dict[str, Any],
+        example_inputs: list[Any],
+        compiler_config: dict[str, Any],
         runtime_shape: Optional[int] = None,
         key: Optional[str] = None,
-    ) -> Tuple[Optional[Callable], Optional[Any]]:
+    ) -> tuple[Optional[Callable], Optional[Any]]:
         # we don't need to compile the graph, just return the graph itself.
         # It does not support caching, return None for the handle.
         return graph, None
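
`CompilerInterface.compile` returns a `(compiled_callable, handle)` pair, where the handle is whatever the compiler needs to reload the artifact later; `EagerAdaptor` above shows the degenerate case. A hedged sketch of that contract as a hypothetical free function (not vLLM's API):

    from typing import Any, Callable, Optional

    def eager_compile(
        graph: Callable,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        runtime_shape: Optional[int] = None,
    ) -> tuple[Optional[Callable], Optional[Any]]:
        # Like EagerAdaptor.compile in spirit: no compilation happens, and
        # the None handle signals that nothing can be cached or reloaded.
        return graph, None

    fn, handle = eager_compile(lambda x: x * 2, example_inputs=[3],
                               compiler_config={})
    assert fn(3) == 6 and handle is None
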
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import inspect
-from typing import Callable, Dict, List, Optional, TypeVar, Union, overload
+from typing import Callable, Optional, TypeVar, Union, overload
 from unittest.mock import patch
 
 import torch
@@ -25,7 +25,7 @@ _T = TypeVar("_T", bound=type[nn.Module])
 @overload
 def support_torch_compile(
     *,
-    dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]],
+    dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]],
 ) -> Callable[[_T], _T]:
     ...
 
@@ -38,7 +38,7 @@ def support_torch_compile(cls: _T) -> _T:
 def support_torch_compile(
     cls: Optional[_T] = None,
     *,
-    dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None,
+    dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None,
 ) -> Union[Callable[[_T], _T], _T]:
     """
     A decorator to add support for compiling the forward method of a class.
@@ -131,7 +131,7 @@ def support_torch_compile(
 
 def _support_torch_compile(
     cls: _T,
-    dynamic_arg_dims: Dict[str, Union[int, List[int]]],
+    dynamic_arg_dims: dict[str, Union[int, list[int]]],
 ) -> _T:
     """
     A decorator to add support for compiling the forward method of a class.
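
The `@overload` pair above lets `support_torch_compile` be used both bare (`@support_torch_compile`) and parameterized (`@support_torch_compile(dynamic_arg_dims=...)`). A hypothetical decorator, `mark_compiled`, sketching the same dual-form pattern with the updated annotations:

    from typing import Callable, Optional, TypeVar, Union, overload

    _T = TypeVar("_T", bound=type)

    @overload
    def mark_compiled(cls: _T) -> _T: ...

    @overload
    def mark_compiled(
        *, dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]]
    ) -> Callable[[_T], _T]: ...

    def mark_compiled(cls=None, *, dynamic_arg_dims=None):
        def wrap(inner: _T) -> _T:
            inner._dynamic_arg_dims = dynamic_arg_dims  # record for later use
            return inner
        # Bare form receives the class directly; the parameterized form
        # returns a decorator to be applied to the class.
        return wrap(cls) if cls is not None else wrap

    @mark_compiled(dynamic_arg_dims={"input_ids": 0})
    class ToyModel:
        pass
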
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import operator
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from collections.abc import Iterable
+from typing import Optional, Union
 
 import torch
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
@@ -27,7 +28,7 @@ class FixFunctionalizationPass(VllmInductorPass):
         self.begin()
         self.dump_graph(graph, "before_fix_functionalization")
 
-        self.nodes_to_remove: List[torch.fx.Node] = []
+        self.nodes_to_remove: list[torch.fx.Node] = []
         count = 0
         for node in graph.nodes:
             if not is_func(node, auto_functionalized):
@@ -117,8 +118,8 @@ class FixFunctionalizationPass(VllmInductorPass):
     def defunctionalize(self,
                         graph: torch.fx.Graph,
                         node: torch.fx.Node,
-                        mutated_args: Dict[int, Union[torch.fx.Node, str]],
-                        args: Optional[Tuple[Union[torch.fx.Node, str],
+                        mutated_args: dict[int, Union[torch.fx.Node, str]],
+                        args: Optional[tuple[Union[torch.fx.Node, str],
                                              ...]] = None):
         """
         De-functionalize a node by replacing it with a call to the original.
@@ -130,7 +131,7 @@ class FixFunctionalizationPass(VllmInductorPass):
         self._remove(node)
 
     def replace_users_with_mutated_args(self, node: torch.fx.Node,
-                                        mutated_args: Dict[int,
+                                        mutated_args: dict[int,
                                                            Union[torch.fx.Node,
                                                                  str]]):
         """
@@ -146,7 +147,7 @@ class FixFunctionalizationPass(VllmInductorPass):
             user.replace_all_uses_with(arg)
             self._remove(user)
 
-    def getitem_users(self, node: torch.fx.Node) -> Dict[int, torch.fx.Node]:
+    def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
         """
         Returns the operator.getitem users of the auto-functionalized node,
         indexed by the index they are getting.
@@ -161,7 +162,7 @@ class FixFunctionalizationPass(VllmInductorPass):
     def insert_defunctionalized(self,
                                 graph: torch.fx.Graph,
                                 node: torch.fx.Node,
-                                args: Optional[Tuple[Union[torch.fx.Node, str],
+                                args: Optional[tuple[Union[torch.fx.Node, str],
                                                      ...]] = None):
         """
         Insert a new defunctionalized node into the graph before node.
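
Several import blocks in this commit split one `typing` line in two, because the abstract container types (`Iterable`, `Sequence`) moved to `collections.abc` (subscriptable at runtime since Python 3.9), while `Optional` and `Union` have no builtin replacement. A small illustration, with a made-up function:

    from collections.abc import Iterable
    from typing import Optional, Union

    def first(items: Iterable[Union[int, str]]) -> Optional[Union[int, str]]:
        for item in items:
            return item
        return None

    assert first([]) is None
    assert first(("a", 1)) == "a"
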
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
+from typing import Callable, NamedTuple, Optional
 
 import torch
 import torch._inductor.pattern_matcher as pm
@@ -57,7 +57,7 @@ kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, True, True)
 kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, True, True)
 kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, False, True)
 
-QUANT_OPS: Dict[QuantKey, OpOverload] = {
+QUANT_OPS: dict[QuantKey, OpOverload] = {
     kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,  # noqa
     kFp8DynamicTensorSym:
     torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa
@@ -80,7 +80,7 @@ class FusedRMSQuantKey(NamedTuple):
                 f"{'' if self.fused_add else 'out'} residual)")
 
 
-FUSED_OPS: Dict[FusedRMSQuantKey, OpOverload] = {
+FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
     FusedRMSQuantKey(kFp8StaticTensorSym, False):
     torch.ops._C.rms_norm_static_fp8_quant.default,  # noqa
     FusedRMSQuantKey(kFp8StaticTensorSym, True):
@@ -101,7 +101,7 @@ class QuantMultiOutputMatch(MultiOutputMatch):
         self.QUANT_OP = quant_op  # in-place quant op
         self.FUSED_OP = fused_op  # in-place fused quant op
 
-    def insert_fused_node(self, fused_return_mapping: Dict[int, Tuple[fx.Node,
+    def insert_fused_node(self, fused_return_mapping: dict[int, tuple[fx.Node,
                                                                       int]],
                           **kwargs):
         """
@@ -548,7 +548,7 @@ class FusionPass(VllmInductorPass):
             "FusionPass singleton instance already exists"
         super().__init__(config)
 
-        self.matches: List[MultiOutputMatch] = []
+        self.matches: list[MultiOutputMatch] = []
         self.patterns: PatternMatcherPass = PatternMatcherPass(
             pass_name="fusion_pass")
 
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import operator
-from typing import Iterable, Optional
+from collections.abc import Iterable
+from typing import Optional
 
 from torch import fx
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
@@ -5,7 +5,7 @@ import inspect
 import json
 import types
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 from torch import fx
@@ -83,7 +83,7 @@ class InductorPass(CustomGraphPass):
         return hasher.hexdigest()
 
     @staticmethod
-    def hash_dict(dict_: Dict[Any, Any]):
+    def hash_dict(dict_: dict[Any, Any]):
         """
         Utility method to hash a dictionary, can alternatively be used for uuid.
         :return: A sha256 hash of the json rep of the dictionary.
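
`InductorPass.hash_dict`'s docstring promises a sha256 of the dictionary's JSON representation. One plausible reading, sketched from the docstring alone (the real body may differ, e.g. in how it orders keys):

    import hashlib
    import json
    from typing import Any

    def hash_dict(dict_: dict[Any, Any]) -> str:
        # Serialize deterministically, then hash; sort_keys makes the
        # digest independent of insertion order (an assumption here).
        encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
        return hashlib.sha256(encoded).hexdigest()

    assert hash_dict({"a": 1, "b": 2}) == hash_dict({"b": 2, "a": 1})
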
@@ -3,7 +3,7 @@
 import abc
 import operator
 from abc import abstractmethod
-from typing import Iterable, List, Tuple
+from collections.abc import Iterable
 
 from torch import fx
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
@@ -56,7 +56,7 @@ class MultiOutputMatch(abc.ABC):
         raise NotImplementedError
 
     @property
-    def nodes(self) -> List[fx.Node]:
+    def nodes(self) -> list[fx.Node]:
         return self.match.nodes
 
     @property
@@ -87,7 +87,7 @@ class MultiOutputMatch(abc.ABC):
         return self.graph.inserting_after(last_node_in_match)
 
     def insert_getitems(self, tuple_node: fx.Node,
-                        indices: Iterable[int]) -> Tuple[fx.Node, ...]:
+                        indices: Iterable[int]) -> tuple[fx.Node, ...]:
         """
         Insert operator.getitem nodes to extract elements from a tuple node.
 
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Iterable, Union
+from collections.abc import Iterable
+from typing import Union
 
 import torch.fx
 from torch import SymInt
@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import List
-
 from torch import fx as fx
 
 from vllm.config import VllmConfig
@@ -34,7 +32,7 @@ class PostGradPassManager(CustomGraphPass):
     """
 
     def __init__(self):
-        self.passes: List[VllmInductorPass] = []
+        self.passes: list[VllmInductorPass] = []
 
     def __call__(self, graph: fx.Graph):
         shape = get_pass_context().runtime_shape
@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import Optional, Tuple
+from typing import Optional
 
 import torch
 import torch._inductor.pattern_matcher as pm
@@ -125,7 +125,7 @@ class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern):
         residual: torch.Tensor,
         mm_1: torch.Tensor,
         rms_norm_weights: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         all_reduce = tensor_model_parallel_all_reduce(mm_1)
 
         rmsnorm = torch.ops.higher_order.auto_functionalized(
@@ -142,7 +142,7 @@ class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern):
         residual: torch.Tensor,
         mm_1: torch.Tensor,
         rms_norm_weights: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         tp = get_tp_group()
         tp_size = get_tensor_model_parallel_world_size()
         reduce_scatter = torch.ops.vllm.reduce_scatter.default(
@@ -190,7 +190,7 @@ class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
         residual: torch.Tensor,
         mm_1: torch.Tensor,
         rms_norm_weights: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         all_reduce = tensor_model_parallel_all_reduce(mm_1)
 
         rmsnorm = torch.ops.higher_order.auto_functionalized(
@@ -207,7 +207,7 @@ class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
         residual: torch.Tensor,
         mm_1: torch.Tensor,
         rms_norm_weights: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         tp = get_tp_group()
         tp_size = get_tensor_model_parallel_world_size()
         reduce_scatter = torch.ops.vllm.reduce_scatter.default(
@@ -5,7 +5,7 @@ import sys
 from abc import abstractmethod
 from contextlib import contextmanager
 from types import CodeType
-from typing import Callable, List, Optional
+from typing import Callable, Optional
 
 import torch
 
@@ -48,7 +48,7 @@ class TorchCompileWrapperWithCustomDispatcher:
 
         self.compiled_callable = compiled_callable
         self.original_code_object = self.__class__.forward.__code__
-        self.compiled_codes: List[CodeType] = []
+        self.compiled_codes: list[CodeType] = []
         torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
 
         # read the env var to determine whether to use the custom dispatcher