From ff334ca1cd92c41cc79e9dead91de40b87601daf Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 13 May 2025 12:34:34 +0100 Subject: [PATCH] Update deprecated type hinting in `vllm/profiler` (#18057) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- pyproject.toml | 1 - vllm/profiler/layerwise_profile.py | 38 +++++++++++++++--------------- vllm/profiler/utils.py | 8 +++---- 3 files changed, 23 insertions(+), 24 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 108fd7af9a3b..a3e75ec69d35 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,7 +84,6 @@ exclude = [ "vllm/model_executor/models/**/*.py" = ["UP006", "UP035"] "vllm/platforms/**/*.py" = ["UP006", "UP035"] "vllm/plugins/**/*.py" = ["UP006", "UP035"] -"vllm/profiler/**/*.py" = ["UP006", "UP035"] "vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"] "vllm/spec_decode/**/*.py" = ["UP006", "UP035"] "vllm/transformers_utils/**/*.py" = ["UP006", "UP035"] diff --git a/vllm/profiler/layerwise_profile.py b/vllm/profiler/layerwise_profile.py index 6351ef63da2b..6934d328a87e 100644 --- a/vllm/profiler/layerwise_profile.py +++ b/vllm/profiler/layerwise_profile.py @@ -3,7 +3,7 @@ import copy from collections import defaultdict from dataclasses import asdict, dataclass, field -from typing import Any, Callable, Dict, List, Optional, Tuple, TypeAlias, Union +from typing import Any, Callable, Optional, TypeAlias, Union import pandas as pd from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult @@ -20,7 +20,7 @@ from vllm.profiler.utils import (TablePrinter, event_has_module, class _ModuleTreeNode: event: _ProfilerEvent parent: Optional['_ModuleTreeNode'] = None - children: List['_ModuleTreeNode'] = field(default_factory=list) + children: list['_ModuleTreeNode'] = field(default_factory=list) trace: str = "" @property @@ -60,19 +60,19 @@ StatsEntry: TypeAlias = Union[ModelStatsEntry, SummaryStatsEntry] @dataclass class _StatsTreeNode: entry: StatsEntry - children: List[StatsEntry] + children: list[StatsEntry] parent: Optional[StatsEntry] @dataclass class LayerwiseProfileResults(profile): _kineto_results: _ProfilerResult - _kineto_event_correlation_map: Dict[int, - List[_KinetoEvent]] = field(init=False) - _event_correlation_map: Dict[int, List[FunctionEvent]] = field(init=False) - _module_tree: List[_ModuleTreeNode] = field(init=False) - _model_stats_tree: List[_StatsTreeNode] = field(init=False) - _summary_stats_tree: List[_StatsTreeNode] = field(init=False) + _kineto_event_correlation_map: dict[int, + list[_KinetoEvent]] = field(init=False) + _event_correlation_map: dict[int, list[FunctionEvent]] = field(init=False) + _module_tree: list[_ModuleTreeNode] = field(init=False) + _model_stats_tree: list[_StatsTreeNode] = field(init=False) + _summary_stats_tree: list[_StatsTreeNode] = field(init=False) # profile metadata num_running_seqs: Optional[int] = None @@ -82,7 +82,7 @@ class LayerwiseProfileResults(profile): self._build_module_tree() self._build_stats_trees() - def print_model_table(self, column_widths: Dict[str, int] = None): + def print_model_table(self, column_widths: dict[str, int] = None): _column_widths = dict(name=60, cpu_time_us=12, cuda_time_us=12, @@ -100,7 +100,7 @@ class LayerwiseProfileResults(profile): filtered_model_table, indent_style=lambda indent: "|" + "-" * indent + " ")) - def print_summary_table(self, column_widths: Dict[str, int] = None): + def print_summary_table(self, column_widths: dict[str, int] = None): _column_widths = dict(name=80, cuda_time_us=12, pct_cuda_time=12, @@ -142,7 +142,7 @@ class LayerwiseProfileResults(profile): } @staticmethod - def _indent_row_names_based_on_depth(depths_rows: List[Tuple[int, + def _indent_row_names_based_on_depth(depths_rows: list[tuple[int, StatsEntry]], indent_style: Union[Callable[[int], str], @@ -229,7 +229,7 @@ class LayerwiseProfileResults(profile): [self._cumulative_cuda_time(root) for root in self._module_tree]) def _build_stats_trees(self): - summary_dict: Dict[str, _StatsTreeNode] = {} + summary_dict: dict[str, _StatsTreeNode] = {} total_cuda_time = self._total_cuda_time() def pct_cuda_time(cuda_time_us): @@ -238,7 +238,7 @@ class LayerwiseProfileResults(profile): def build_summary_stats_tree_df( node: _ModuleTreeNode, parent: Optional[_StatsTreeNode] = None, - summary_trace: Tuple[str] = ()): + summary_trace: tuple[str] = ()): if event_has_module(node.event): name = event_module_repr(node.event) @@ -313,8 +313,8 @@ class LayerwiseProfileResults(profile): self._model_stats_tree.append(build_model_stats_tree_df(root)) def _flatten_stats_tree( - self, tree: List[_StatsTreeNode]) -> List[Tuple[int, StatsEntry]]: - entries: List[Tuple[int, StatsEntry]] = [] + self, tree: list[_StatsTreeNode]) -> list[tuple[int, StatsEntry]]: + entries: list[tuple[int, StatsEntry]] = [] def df_traversal(node: _StatsTreeNode, depth=0): entries.append((depth, node.entry)) @@ -327,10 +327,10 @@ class LayerwiseProfileResults(profile): return entries def _convert_stats_tree_to_dict(self, - tree: List[_StatsTreeNode]) -> List[Dict]: - root_dicts: List[Dict] = [] + tree: list[_StatsTreeNode]) -> list[dict]: + root_dicts: list[dict] = [] - def df_traversal(node: _StatsTreeNode, curr_json_list: List[Dict]): + def df_traversal(node: _StatsTreeNode, curr_json_list: list[dict]): curr_json_list.append({ "entry": asdict(node.entry), "children": [] diff --git a/vllm/profiler/utils.py b/vllm/profiler/utils.py index 62b39f510703..b26fd4dd8c07 100644 --- a/vllm/profiler/utils.py +++ b/vllm/profiler/utils.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import dataclasses -from typing import Callable, Dict, List, Type, Union +from typing import Callable, Union from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata @@ -30,14 +30,14 @@ def trim_string_back(string, width): class TablePrinter: - def __init__(self, row_cls: Type[dataclasses.dataclass], - column_widths: Dict[str, int]): + def __init__(self, row_cls: type[dataclasses.dataclass], + column_widths: dict[str, int]): self.row_cls = row_cls self.fieldnames = [x.name for x in dataclasses.fields(row_cls)] self.column_widths = column_widths assert set(self.column_widths.keys()) == set(self.fieldnames) - def print_table(self, rows: List[dataclasses.dataclass]): + def print_table(self, rows: list[dataclasses.dataclass]): self._print_header() self._print_line() for row in rows: