mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 21:55:50 +08:00
Update deprecated type hinting in vllm/profiler (#18057)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
6223dd8114
commit
ff334ca1cd
@ -84,7 +84,6 @@ exclude = [
|
||||
"vllm/model_executor/models/**/*.py" = ["UP006", "UP035"]
|
||||
"vllm/platforms/**/*.py" = ["UP006", "UP035"]
|
||||
"vllm/plugins/**/*.py" = ["UP006", "UP035"]
|
||||
"vllm/profiler/**/*.py" = ["UP006", "UP035"]
|
||||
"vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"]
|
||||
"vllm/spec_decode/**/*.py" = ["UP006", "UP035"]
|
||||
"vllm/transformers_utils/**/*.py" = ["UP006", "UP035"]
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
import copy
|
||||
from collections import defaultdict
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeAlias, Union
|
||||
from typing import Any, Callable, Optional, TypeAlias, Union
|
||||
|
||||
import pandas as pd
|
||||
from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult
|
||||
@ -20,7 +20,7 @@ from vllm.profiler.utils import (TablePrinter, event_has_module,
|
||||
class _ModuleTreeNode:
|
||||
event: _ProfilerEvent
|
||||
parent: Optional['_ModuleTreeNode'] = None
|
||||
children: List['_ModuleTreeNode'] = field(default_factory=list)
|
||||
children: list['_ModuleTreeNode'] = field(default_factory=list)
|
||||
trace: str = ""
|
||||
|
||||
@property
|
||||
@ -60,19 +60,19 @@ StatsEntry: TypeAlias = Union[ModelStatsEntry, SummaryStatsEntry]
|
||||
@dataclass
|
||||
class _StatsTreeNode:
|
||||
entry: StatsEntry
|
||||
children: List[StatsEntry]
|
||||
children: list[StatsEntry]
|
||||
parent: Optional[StatsEntry]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerwiseProfileResults(profile):
|
||||
_kineto_results: _ProfilerResult
|
||||
_kineto_event_correlation_map: Dict[int,
|
||||
List[_KinetoEvent]] = field(init=False)
|
||||
_event_correlation_map: Dict[int, List[FunctionEvent]] = field(init=False)
|
||||
_module_tree: List[_ModuleTreeNode] = field(init=False)
|
||||
_model_stats_tree: List[_StatsTreeNode] = field(init=False)
|
||||
_summary_stats_tree: List[_StatsTreeNode] = field(init=False)
|
||||
_kineto_event_correlation_map: dict[int,
|
||||
list[_KinetoEvent]] = field(init=False)
|
||||
_event_correlation_map: dict[int, list[FunctionEvent]] = field(init=False)
|
||||
_module_tree: list[_ModuleTreeNode] = field(init=False)
|
||||
_model_stats_tree: list[_StatsTreeNode] = field(init=False)
|
||||
_summary_stats_tree: list[_StatsTreeNode] = field(init=False)
|
||||
|
||||
# profile metadata
|
||||
num_running_seqs: Optional[int] = None
|
||||
@ -82,7 +82,7 @@ class LayerwiseProfileResults(profile):
|
||||
self._build_module_tree()
|
||||
self._build_stats_trees()
|
||||
|
||||
def print_model_table(self, column_widths: Dict[str, int] = None):
|
||||
def print_model_table(self, column_widths: dict[str, int] = None):
|
||||
_column_widths = dict(name=60,
|
||||
cpu_time_us=12,
|
||||
cuda_time_us=12,
|
||||
@ -100,7 +100,7 @@ class LayerwiseProfileResults(profile):
|
||||
filtered_model_table,
|
||||
indent_style=lambda indent: "|" + "-" * indent + " "))
|
||||
|
||||
def print_summary_table(self, column_widths: Dict[str, int] = None):
|
||||
def print_summary_table(self, column_widths: dict[str, int] = None):
|
||||
_column_widths = dict(name=80,
|
||||
cuda_time_us=12,
|
||||
pct_cuda_time=12,
|
||||
@ -142,7 +142,7 @@ class LayerwiseProfileResults(profile):
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _indent_row_names_based_on_depth(depths_rows: List[Tuple[int,
|
||||
def _indent_row_names_based_on_depth(depths_rows: list[tuple[int,
|
||||
StatsEntry]],
|
||||
indent_style: Union[Callable[[int],
|
||||
str],
|
||||
@ -229,7 +229,7 @@ class LayerwiseProfileResults(profile):
|
||||
[self._cumulative_cuda_time(root) for root in self._module_tree])
|
||||
|
||||
def _build_stats_trees(self):
|
||||
summary_dict: Dict[str, _StatsTreeNode] = {}
|
||||
summary_dict: dict[str, _StatsTreeNode] = {}
|
||||
total_cuda_time = self._total_cuda_time()
|
||||
|
||||
def pct_cuda_time(cuda_time_us):
|
||||
@ -238,7 +238,7 @@ class LayerwiseProfileResults(profile):
|
||||
def build_summary_stats_tree_df(
|
||||
node: _ModuleTreeNode,
|
||||
parent: Optional[_StatsTreeNode] = None,
|
||||
summary_trace: Tuple[str] = ()):
|
||||
summary_trace: tuple[str] = ()):
|
||||
|
||||
if event_has_module(node.event):
|
||||
name = event_module_repr(node.event)
|
||||
@ -313,8 +313,8 @@ class LayerwiseProfileResults(profile):
|
||||
self._model_stats_tree.append(build_model_stats_tree_df(root))
|
||||
|
||||
def _flatten_stats_tree(
|
||||
self, tree: List[_StatsTreeNode]) -> List[Tuple[int, StatsEntry]]:
|
||||
entries: List[Tuple[int, StatsEntry]] = []
|
||||
self, tree: list[_StatsTreeNode]) -> list[tuple[int, StatsEntry]]:
|
||||
entries: list[tuple[int, StatsEntry]] = []
|
||||
|
||||
def df_traversal(node: _StatsTreeNode, depth=0):
|
||||
entries.append((depth, node.entry))
|
||||
@ -327,10 +327,10 @@ class LayerwiseProfileResults(profile):
|
||||
return entries
|
||||
|
||||
def _convert_stats_tree_to_dict(self,
|
||||
tree: List[_StatsTreeNode]) -> List[Dict]:
|
||||
root_dicts: List[Dict] = []
|
||||
tree: list[_StatsTreeNode]) -> list[dict]:
|
||||
root_dicts: list[dict] = []
|
||||
|
||||
def df_traversal(node: _StatsTreeNode, curr_json_list: List[Dict]):
|
||||
def df_traversal(node: _StatsTreeNode, curr_json_list: list[dict]):
|
||||
curr_json_list.append({
|
||||
"entry": asdict(node.entry),
|
||||
"children": []
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import dataclasses
|
||||
from typing import Callable, Dict, List, Type, Union
|
||||
from typing import Callable, Union
|
||||
|
||||
from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata
|
||||
|
||||
@ -30,14 +30,14 @@ def trim_string_back(string, width):
|
||||
|
||||
class TablePrinter:
|
||||
|
||||
def __init__(self, row_cls: Type[dataclasses.dataclass],
|
||||
column_widths: Dict[str, int]):
|
||||
def __init__(self, row_cls: type[dataclasses.dataclass],
|
||||
column_widths: dict[str, int]):
|
||||
self.row_cls = row_cls
|
||||
self.fieldnames = [x.name for x in dataclasses.fields(row_cls)]
|
||||
self.column_widths = column_widths
|
||||
assert set(self.column_widths.keys()) == set(self.fieldnames)
|
||||
|
||||
def print_table(self, rows: List[dataclasses.dataclass]):
|
||||
def print_table(self, rows: list[dataclasses.dataclass]):
|
||||
self._print_header()
|
||||
self._print_line()
|
||||
for row in rows:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user