Update deprecated type hinting in vllm/profiler (#18057)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-05-13 12:34:34 +01:00 committed by GitHub
parent 6223dd8114
commit ff334ca1cd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 23 additions and 24 deletions

View File

@ -84,7 +84,6 @@ exclude = [
"vllm/model_executor/models/**/*.py" = ["UP006", "UP035"] "vllm/model_executor/models/**/*.py" = ["UP006", "UP035"]
"vllm/platforms/**/*.py" = ["UP006", "UP035"] "vllm/platforms/**/*.py" = ["UP006", "UP035"]
"vllm/plugins/**/*.py" = ["UP006", "UP035"] "vllm/plugins/**/*.py" = ["UP006", "UP035"]
"vllm/profiler/**/*.py" = ["UP006", "UP035"]
"vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"] "vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"]
"vllm/spec_decode/**/*.py" = ["UP006", "UP035"] "vllm/spec_decode/**/*.py" = ["UP006", "UP035"]
"vllm/transformers_utils/**/*.py" = ["UP006", "UP035"] "vllm/transformers_utils/**/*.py" = ["UP006", "UP035"]

View File

@ -3,7 +3,7 @@
import copy import copy
from collections import defaultdict from collections import defaultdict
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeAlias, Union from typing import Any, Callable, Optional, TypeAlias, Union
import pandas as pd import pandas as pd
from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult
@ -20,7 +20,7 @@ from vllm.profiler.utils import (TablePrinter, event_has_module,
class _ModuleTreeNode: class _ModuleTreeNode:
event: _ProfilerEvent event: _ProfilerEvent
parent: Optional['_ModuleTreeNode'] = None parent: Optional['_ModuleTreeNode'] = None
children: List['_ModuleTreeNode'] = field(default_factory=list) children: list['_ModuleTreeNode'] = field(default_factory=list)
trace: str = "" trace: str = ""
@property @property
@ -60,19 +60,19 @@ StatsEntry: TypeAlias = Union[ModelStatsEntry, SummaryStatsEntry]
@dataclass @dataclass
class _StatsTreeNode: class _StatsTreeNode:
entry: StatsEntry entry: StatsEntry
children: List[StatsEntry] children: list[StatsEntry]
parent: Optional[StatsEntry] parent: Optional[StatsEntry]
@dataclass @dataclass
class LayerwiseProfileResults(profile): class LayerwiseProfileResults(profile):
_kineto_results: _ProfilerResult _kineto_results: _ProfilerResult
_kineto_event_correlation_map: Dict[int, _kineto_event_correlation_map: dict[int,
List[_KinetoEvent]] = field(init=False) list[_KinetoEvent]] = field(init=False)
_event_correlation_map: Dict[int, List[FunctionEvent]] = field(init=False) _event_correlation_map: dict[int, list[FunctionEvent]] = field(init=False)
_module_tree: List[_ModuleTreeNode] = field(init=False) _module_tree: list[_ModuleTreeNode] = field(init=False)
_model_stats_tree: List[_StatsTreeNode] = field(init=False) _model_stats_tree: list[_StatsTreeNode] = field(init=False)
_summary_stats_tree: List[_StatsTreeNode] = field(init=False) _summary_stats_tree: list[_StatsTreeNode] = field(init=False)
# profile metadata # profile metadata
num_running_seqs: Optional[int] = None num_running_seqs: Optional[int] = None
@ -82,7 +82,7 @@ class LayerwiseProfileResults(profile):
self._build_module_tree() self._build_module_tree()
self._build_stats_trees() self._build_stats_trees()
def print_model_table(self, column_widths: Dict[str, int] = None): def print_model_table(self, column_widths: dict[str, int] = None):
_column_widths = dict(name=60, _column_widths = dict(name=60,
cpu_time_us=12, cpu_time_us=12,
cuda_time_us=12, cuda_time_us=12,
@ -100,7 +100,7 @@ class LayerwiseProfileResults(profile):
filtered_model_table, filtered_model_table,
indent_style=lambda indent: "|" + "-" * indent + " ")) indent_style=lambda indent: "|" + "-" * indent + " "))
def print_summary_table(self, column_widths: Dict[str, int] = None): def print_summary_table(self, column_widths: dict[str, int] = None):
_column_widths = dict(name=80, _column_widths = dict(name=80,
cuda_time_us=12, cuda_time_us=12,
pct_cuda_time=12, pct_cuda_time=12,
@ -142,7 +142,7 @@ class LayerwiseProfileResults(profile):
} }
@staticmethod @staticmethod
def _indent_row_names_based_on_depth(depths_rows: List[Tuple[int, def _indent_row_names_based_on_depth(depths_rows: list[tuple[int,
StatsEntry]], StatsEntry]],
indent_style: Union[Callable[[int], indent_style: Union[Callable[[int],
str], str],
@ -229,7 +229,7 @@ class LayerwiseProfileResults(profile):
[self._cumulative_cuda_time(root) for root in self._module_tree]) [self._cumulative_cuda_time(root) for root in self._module_tree])
def _build_stats_trees(self): def _build_stats_trees(self):
summary_dict: Dict[str, _StatsTreeNode] = {} summary_dict: dict[str, _StatsTreeNode] = {}
total_cuda_time = self._total_cuda_time() total_cuda_time = self._total_cuda_time()
def pct_cuda_time(cuda_time_us): def pct_cuda_time(cuda_time_us):
@ -238,7 +238,7 @@ class LayerwiseProfileResults(profile):
def build_summary_stats_tree_df( def build_summary_stats_tree_df(
node: _ModuleTreeNode, node: _ModuleTreeNode,
parent: Optional[_StatsTreeNode] = None, parent: Optional[_StatsTreeNode] = None,
summary_trace: Tuple[str] = ()): summary_trace: tuple[str] = ()):
if event_has_module(node.event): if event_has_module(node.event):
name = event_module_repr(node.event) name = event_module_repr(node.event)
@ -313,8 +313,8 @@ class LayerwiseProfileResults(profile):
self._model_stats_tree.append(build_model_stats_tree_df(root)) self._model_stats_tree.append(build_model_stats_tree_df(root))
def _flatten_stats_tree( def _flatten_stats_tree(
self, tree: List[_StatsTreeNode]) -> List[Tuple[int, StatsEntry]]: self, tree: list[_StatsTreeNode]) -> list[tuple[int, StatsEntry]]:
entries: List[Tuple[int, StatsEntry]] = [] entries: list[tuple[int, StatsEntry]] = []
def df_traversal(node: _StatsTreeNode, depth=0): def df_traversal(node: _StatsTreeNode, depth=0):
entries.append((depth, node.entry)) entries.append((depth, node.entry))
@ -327,10 +327,10 @@ class LayerwiseProfileResults(profile):
return entries return entries
def _convert_stats_tree_to_dict(self, def _convert_stats_tree_to_dict(self,
tree: List[_StatsTreeNode]) -> List[Dict]: tree: list[_StatsTreeNode]) -> list[dict]:
root_dicts: List[Dict] = [] root_dicts: list[dict] = []
def df_traversal(node: _StatsTreeNode, curr_json_list: List[Dict]): def df_traversal(node: _StatsTreeNode, curr_json_list: list[dict]):
curr_json_list.append({ curr_json_list.append({
"entry": asdict(node.entry), "entry": asdict(node.entry),
"children": [] "children": []

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import dataclasses import dataclasses
from typing import Callable, Dict, List, Type, Union from typing import Callable, Union
from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata
@ -30,14 +30,14 @@ def trim_string_back(string, width):
class TablePrinter: class TablePrinter:
def __init__(self, row_cls: Type[dataclasses.dataclass], def __init__(self, row_cls: type[dataclasses.dataclass],
column_widths: Dict[str, int]): column_widths: dict[str, int]):
self.row_cls = row_cls self.row_cls = row_cls
self.fieldnames = [x.name for x in dataclasses.fields(row_cls)] self.fieldnames = [x.name for x in dataclasses.fields(row_cls)]
self.column_widths = column_widths self.column_widths = column_widths
assert set(self.column_widths.keys()) == set(self.fieldnames) assert set(self.column_widths.keys()) == set(self.fieldnames)
def print_table(self, rows: List[dataclasses.dataclass]): def print_table(self, rows: list[dataclasses.dataclass]):
self._print_header() self._print_header()
self._print_line() self._print_line()
for row in rows: for row in rows: