Update some more deprecated type hinting (#17998)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-05-13 00:49:33 +01:00 committed by GitHub
parent acee8f48aa
commit 9d7ea9dbbf
10 changed files with 73 additions and 73 deletions

View File

@@ -79,7 +79,9 @@ exclude = [
 "vllm/engine/**/*.py" = ["UP006", "UP035"]
 "vllm/executor/**/*.py" = ["UP006", "UP035"]
 "vllm/lora/**/*.py" = ["UP006", "UP035"]
-"vllm/model_executor/**/*.py" = ["UP006", "UP035"]
+"vllm/model_executor/layers/**/*.py" = ["UP006", "UP035"]
+"vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"]
+"vllm/model_executor/models/**/*.py" = ["UP006", "UP035"]
 "vllm/platforms/**/*.py" = ["UP006", "UP035"]
 "vllm/plugins/**/*.py" = ["UP006", "UP035"]
 "vllm/profiler/**/*.py" = ["UP006", "UP035"]

View File

@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Dict, Type
-
 import torch.nn as nn
 
 from vllm.config import get_current_vllm_config
@@ -138,7 +136,7 @@ class CustomOp(nn.Module):
     # Examples:
     # - MyOp.enabled()
     # - op_registry["my_op"].enabled()
-    op_registry: Dict[str, Type['CustomOp']] = {}
+    op_registry: dict[str, type['CustomOp']] = {}
 
     # Decorator to register custom ops.
     @classmethod
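The `op_registry` above is a name-to-class registry, now spelled with builtin generics. A self-contained sketch of the pattern (the decorator and the example op are illustrative, not vLLM's actual API):

import torch
import torch.nn as nn

op_registry: dict[str, type[nn.Module]] = {}

def register_op(name: str):
    """Store the decorated class itself (not an instance) under `name`."""
    def wrap(cls: type[nn.Module]) -> type[nn.Module]:
        op_registry[name] = cls
        return cls
    return wrap

@register_op("silu_and_mul")
class SiluAndMul(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return nn.functional.silu(x[..., :d]) * x[..., d:]

# Look up by name and instantiate, mirroring op_registry["my_op"].enabled():
op = op_registry["silu_and_mul"]()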

View File

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 import os
-from typing import Any, List
+from typing import Any
 
 import llguidance
 import llguidance.hf
@@ -62,7 +62,7 @@ class GuidanceLogitsProcessor:
 
     def __call__(
         self,
-        input_ids: List[int],
+        input_ids: list[int],
         scores: torch.Tensor,
     ) -> torch.Tensor:
         # we initialize the guidance model here
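For reference, a logits processor is just a callable with the signature shown above: it takes the token ids generated so far plus the raw logits and returns (possibly modified) logits. A toy stand-in with the same interface, not the llguidance-backed implementation:

import torch

class BanTokenProcessor:
    """Assign -inf to one banned token id so it can never be sampled."""

    def __init__(self, banned_id: int):
        self.banned_id = banned_id

    def __call__(
        self,
        input_ids: list[int],
        scores: torch.Tensor,
    ) -> torch.Tensor:
        scores[self.banned_id] = float("-inf")
        return scores

processor = BanTokenProcessor(banned_id=42)
biased = processor(input_ids=[1, 7, 3], scores=torch.randn(32_000))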

View File

@@ -1,16 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from dataclasses import dataclass
-from typing import Dict, List, Optional, TypedDict, Union
+from typing import Optional, TypedDict, Union
 
 from pydantic import BaseModel
 
 
 # These classes are deprecated, see SamplingParams
 class LLMGuidedOptions(TypedDict, total=False):
-    guided_json: Union[Dict, BaseModel, str]
+    guided_json: Union[dict, BaseModel, str]
     guided_regex: str
-    guided_choice: List[str]
+    guided_choice: list[str]
     guided_grammar: str
     guided_decoding_backend: str
     guided_whitespace_pattern: str
@@ -20,9 +20,9 @@ class LLMGuidedOptions(TypedDict, total=False):
 @dataclass
 class GuidedDecodingRequest:
     """One of the fields will be used to retrieve the logit processor."""
-    guided_json: Optional[Union[Dict, BaseModel, str]] = None
+    guided_json: Optional[Union[dict, BaseModel, str]] = None
     guided_regex: Optional[str] = None
-    guided_choice: Optional[List[str]] = None
+    guided_choice: Optional[list[str]] = None
     guided_grammar: Optional[str] = None
    guided_decoding_backend: Optional[str] = None
     guided_whitespace_pattern: Optional[str] = None
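Note that `guided_json` accepts a dict schema, a pydantic model, or a raw JSON-schema string (the `isinstance(schema, type(BaseModel))` check later in this commit shows model classes are what actually get passed). A usage sketch against a trimmed copy of the dataclass:

from dataclasses import dataclass
from typing import Optional, Union

from pydantic import BaseModel

@dataclass
class GuidedDecodingRequest:  # trimmed copy of the class above
    guided_json: Optional[Union[dict, BaseModel, str]] = None
    guided_choice: Optional[list[str]] = None

class Flight(BaseModel):
    origin: str
    destination: str

req_from_dict = GuidedDecodingRequest(guided_json={"type": "object"})
req_from_model = GuidedDecodingRequest(guided_json=Flight)
req_from_choice = GuidedDecodingRequest(guided_choice=["yes", "no"])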

View File

@@ -6,7 +6,7 @@ import os
 from enum import Enum
 from json import dumps as json_dumps
 from re import escape as regex_escape
-from typing import Optional, Tuple, Union
+from typing import Optional, Union
 
 from transformers import PreTrainedTokenizerBase
 
@@ -111,7 +111,7 @@ def get_local_outlines_guided_decoding_logits_processor(
 
 def _get_guide_and_mode(
     guided_params: GuidedDecodingParams
-) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
+) -> Union[tuple[str, GuidedDecodingMode], tuple[None, None]]:
     if guided_params.json:
         if isinstance(guided_params.json, dict):
             # turn dict into hashable string
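The "turn dict into hashable string" comment refers to serializing dict schemas to JSON so the resulting guide can be cached; the hunk above imports `json.dumps` as `json_dumps` for exactly this. A sketch of that step (`sort_keys=True` is an extra assumption here, added for canonical output):

import json
from typing import Union

def normalize_json_schema(schema: Union[dict, str]) -> str:
    """Dict schemas become a canonical JSON string; strings pass through."""
    if isinstance(schema, dict):
        return json.dumps(schema, sort_keys=True)
    return schema

assert normalize_json_schema({"b": 1, "a": 2}) == '{"a": 2, "b": 1}'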

View File

@@ -19,7 +19,7 @@ import copy
 import json
 from collections import defaultdict
 from functools import lru_cache
-from typing import Callable, DefaultDict, Dict, List, Optional, Union
+from typing import Callable, Optional, Union
 
 import numpy as np
 import torch
@@ -53,10 +53,10 @@ class BaseLogitsProcessor:
         self._guide: Guide = guide
         self._reasoner: Optional[ReasoningParser] = reasoner
         # CFGState is used for the FSM state for CFGGuide
-        self._fsm_state: DefaultDict[int, Union[int,
+        self._fsm_state: defaultdict[int, Union[int,
                                                 CFGState]] = defaultdict(int)
 
-    def __call__(self, input_ids: List[int],
+    def __call__(self, input_ids: list[int],
                  scores: torch.Tensor) -> torch.Tensor:
         """Use the FSM to bias the logits before sampling the next token."""
 
@@ -160,7 +160,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
 
 class JSONLogitsProcessor(RegexLogitsProcessor):
 
-    def __init__(self, schema: Union[str, Dict, BaseModel],
+    def __init__(self, schema: Union[str, dict, BaseModel],
                  tokenizer: PreTrainedTokenizerBase,
                  whitespace_pattern: Union[str, None],
                  reasoner: Optional[ReasoningParser]):
@@ -181,7 +181,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
         """
         if isinstance(schema, type(BaseModel)):
             schema_str = json.dumps(schema.model_json_schema())
-        elif isinstance(schema, Dict):
+        elif isinstance(schema, dict):
             schema_str = json.dumps(schema)
         elif isinstance(schema, str):
             schema_str = schema
@@ -252,11 +252,11 @@ def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
         return string
 
     def change_decoder(
-        decoder: Callable[[List[int]],
-                          str]) -> Callable[[List[int]], List[str]]:
+        decoder: Callable[[list[int]],
+                          str]) -> Callable[[list[int]], list[str]]:
         """Sync vLLM's decoder with the outlines by returning list."""
 
-        def new_decoder(inp_tokens: List[int]) -> List[str]:
+        def new_decoder(inp_tokens: list[int]) -> list[str]:
             if (isinstance(inp_tokens, list) and len(inp_tokens) == 1
                     and isinstance(inp_tokens[0], list)):
                 inp_tokens = inp_tokens[0]
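`change_decoder` adapts a tokenizer decoder that returns one string into one that returns a list of strings, which is the shape outlines expects. The hunk cuts off before the wrapper's return, so the final lines below are an assumption consistent with the docstring:

from typing import Callable

def change_decoder(
    decoder: Callable[[list[int]], str]
) -> Callable[[list[int]], list[str]]:
    """Sync a decoder with outlines by returning a list of strings."""

    def new_decoder(inp_tokens: list[int]) -> list[str]:
        # Unwrap a singleton batch: [[1, 2, 3]] -> [1, 2, 3].
        if (isinstance(inp_tokens, list) and len(inp_tokens) == 1
                and isinstance(inp_tokens[0], list)):
            inp_tokens = inp_tokens[0]
        return [decoder(inp_tokens)]  # assumed: wrap the decoded string

    return new_decoder

wrapped = change_decoder(lambda ids: "".join(chr(65 + i) for i in ids))
assert wrapped([[0, 1, 2]]) == ["ABC"]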

View File

@@ -6,7 +6,7 @@ from __future__ import annotations
 import json
 import re
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, List
+from typing import TYPE_CHECKING, Any
 
 import torch
 
@@ -273,7 +273,7 @@ class GrammarConfig:
         return re.sub(r'(["\\])', r'\\\1', s)
 
     @staticmethod
-    def choice_as_grammar(choice: List[str] | None) -> str:
+    def choice_as_grammar(choice: list[str] | None) -> str:
         if choice is None:
             raise ValueError("Choice is not set")
         escaped_choices = (GrammarConfig.escape_ebnf_string(c) for c in choice)
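`choice_as_grammar` builds a grammar whose root rule is an alternation over the escaped literals. The hunk ends at the escaping generator, so the `root ::= ...` shape below is an assumption about the elided tail:

from __future__ import annotations

import re

def escape_ebnf_string(s: str) -> str:
    # Same escaping as GrammarConfig.escape_ebnf_string above.
    return re.sub(r'(["\\])', r'\\\1', s)

def choice_as_grammar(choice: list[str] | None) -> str:
    if choice is None:
        raise ValueError("Choice is not set")
    escaped_choices = (escape_ebnf_string(c) for c in choice)
    return "root ::= " + " | ".join(f'"{c}"' for c in escaped_choices)

print(choice_as_grammar(["yes", "no"]))  # root ::= "yes" | "no"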

View File

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from dataclasses import dataclass
-from typing import Any, Dict, List, Tuple
+from typing import Any
 
 import torch
 
@@ -23,9 +23,9 @@ class PoolingMetadata:
 
     def __init__(
         self,
-        seq_groups: List[Tuple[List[int], PoolingParams]],
-        seq_data: Dict[int, Any],  # Specific data related to sequences
-        prompt_lens: List[int],
+        seq_groups: list[tuple[list[int], PoolingParams]],
+        seq_data: dict[int, Any],  # Specific data related to sequences
+        prompt_lens: list[int],
     ) -> None:
         self.seq_groups = seq_groups
         self.seq_data = seq_data
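Reading the nested hint: each entry of `seq_groups` pairs a group's sequence ids with its pooling parameters. A construction sketch with a stand-in for vLLM's `PoolingParams`:

class PoolingParams:  # stand-in for the real class
    pass

seq_groups: list[tuple[list[int], PoolingParams]] = [
    ([0, 1], PoolingParams()),  # sequences 0 and 1 share one group
    ([2], PoolingParams()),
]
prompt_lens: list[int] = [5, 5, 7]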

View File

@@ -2,7 +2,7 @@
 
 from array import array
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Optional
 
 import torch
 
@@ -25,10 +25,10 @@ class SequenceGroupToSample:
     # |-- query_len ---|
 
     # Sequence ids for the sequence group in a previous step.
-    seq_ids: List[int]
+    seq_ids: list[int]
     sampling_params: SamplingParams
     # seq_id -> sequence data.
-    seq_data: Dict[int, SequenceData]
+    seq_data: dict[int, SequenceData]
     # The length of the sequence (all tokens seen in the past + new token to
     # compute attention) of the sequence group. None if it is in a decode
     # stage.
@@ -44,9 +44,9 @@ class SequenceGroupToSample:
     is_prompt: bool
     # Query token indices from logits, used to compute prompt logprob. Empty
    # if prompt logprob is not required.
-    prompt_logprob_indices: List[int]
+    prompt_logprob_indices: list[int]
     # Sample token indices from logits. Empty if sampling is not required.
-    sample_indices: List[int]
+    sample_indices: list[int]
 
     @property
     def do_sample(self):
@@ -78,7 +78,7 @@ class SamplingMetadataCache:
     """Used to cache SamplingMetadata objects between scheduler iterations"""
 
     def __init__(self):
-        self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {}
+        self._seq_group_to_sample_cache: dict[int, PyObjectCache] = {}
 
     def get_cached_seq_group_to_sample(self, num_seqs):
         if num_seqs not in self._seq_group_to_sample_cache:
@@ -130,9 +130,9 @@ class SamplingMetadata:
 
     def __init__(
         self,
-        seq_groups: List[SequenceGroupToSample],
+        seq_groups: list[SequenceGroupToSample],
         selected_token_indices: torch.Tensor,
-        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
+        categorized_sample_indices: dict[SamplingType, torch.Tensor],
         num_prompts: int,
         skip_sampler_cpu_output: bool = False,
         reuse_sampling_tensors: bool = False,
@@ -146,12 +146,12 @@ class SamplingMetadata:
 
     @staticmethod
     def prepare(
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        seq_lens: List[int],
-        query_lens: List[int],
+        seq_group_metadata_list: list[SequenceGroupMetadata],
+        seq_lens: list[int],
+        query_lens: list[int],
         device: str,
         pin_memory: bool,
-        generators: Optional[Dict[str, torch.Generator]] = None,
+        generators: Optional[dict[str, torch.Generator]] = None,
         cache: Optional[SamplingMetadataCache] = None,
     ) -> "SamplingMetadata":
         (
@@ -195,16 +195,16 @@ def _prepare_seq_groups(
 
 
 def _prepare_seq_groups(
-    seq_group_metadata_list: List[SequenceGroupMetadata],
-    seq_lens: List[int],
-    query_lens: List[int],
+    seq_group_metadata_list: list[SequenceGroupMetadata],
+    seq_lens: list[int],
+    query_lens: list[int],
     device: str,
-    generators: Optional[Dict[str, torch.Generator]] = None,
+    generators: Optional[dict[str, torch.Generator]] = None,
     cache: Optional[SamplingMetadataCache] = None,
-) -> Tuple[
-    List[SequenceGroupToSample],
-    List[int],
-    Dict[SamplingType, List[int]],
+) -> tuple[
+    list[SequenceGroupToSample],
+    list[int],
+    dict[SamplingType, list[int]],
     int,
 ]:
     """Prepare sequence groups and indices for sampling.
@@ -227,17 +227,17 @@ def _prepare_seq_groups(
         num_prompts: Total number of prompts from `seq_group_metadata_list`.
     """
     # Batched sequence groups for the current model forward step.
-    seq_groups: List[SequenceGroupToSample] = []
+    seq_groups: list[SequenceGroupToSample] = []
     # A list of token indices to sample/compute logprob. It is used to
     # prune the outcome logits from the model for performance.
-    selected_token_indices: List[int] = []
+    selected_token_indices: list[int] = []
     # Used for selected_token_indices.
     model_output_idx = 0
 
     # Sampling type -> (
     # indices to sample/prompt logprob within pruned output logits,
     # indices to sample within pruned logits)
-    categorized_sample_indices: Dict[SamplingType, List[int]] = {
+    categorized_sample_indices: dict[SamplingType, list[int]] = {
         t: []
         for t in SamplingType
     }
@@ -265,9 +265,9 @@ def _prepare_seq_groups(
        # If the current seq group is in decode stage, it is None.
         seq_len: Optional[int] = None
         query_len: Optional[int] = None
-        prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices
+        prompt_logprob_indices: list[int] = (sample_obj.prompt_logprob_indices
                                              if cache is not None else [])
-        sample_indices: List[int] = (sample_obj.sample_indices
+        sample_indices: list[int] = (sample_obj.sample_indices
                                      if cache is not None else [])
 
         do_sample = seq_group_metadata.do_sample
@@ -389,16 +389,16 @@ class SamplingTensors:
         vocab_size: int,
         device: torch.device,
         dtype: torch.dtype,
-    ) -> Tuple["SamplingTensors", bool, bool, bool]:
-        prompt_tokens: List[array] = []
-        output_tokens: List[array] = []
-        top_ks: List[int] = []
-        temperatures: List[float] = []
-        top_ps: List[float] = []
-        min_ps: List[float] = []
-        presence_penalties: List[float] = []
-        frequency_penalties: List[float] = []
-        repetition_penalties: List[float] = []
+    ) -> tuple["SamplingTensors", bool, bool, bool]:
+        prompt_tokens: list[array] = []
+        output_tokens: list[array] = []
+        top_ks: list[int] = []
+        temperatures: list[float] = []
+        top_ps: list[float] = []
+        min_ps: list[float] = []
+        presence_penalties: list[float] = []
+        frequency_penalties: list[float] = []
+        repetition_penalties: list[float] = []
         do_penalties = False
         do_top_p_top_k = False
         do_min_p = False
@@ -496,15 +496,15 @@ class SamplingTensors:
     @classmethod
     def from_lists(
         cls,
-        temperatures: List[float],
-        top_ps: List[float],
-        top_ks: List[int],
-        min_ps: List[float],
-        presence_penalties: List[float],
-        frequency_penalties: List[float],
-        repetition_penalties: List[float],
-        prompt_tokens: List[array],
-        output_tokens: List[array],
+        temperatures: list[float],
+        top_ps: list[float],
+        top_ks: list[int],
+        min_ps: list[float],
+        presence_penalties: list[float],
+        frequency_penalties: list[float],
+        repetition_penalties: list[float],
+        prompt_tokens: list[array],
+        output_tokens: list[array],
         vocab_size: int,
         device: torch.device,
         dtype: torch.dtype,
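The `categorized_sample_indices` comprehension above keeps one index bucket per sampling type. A minimal sketch with a stand-in enum (vLLM's actual `SamplingType` members may differ):

from enum import IntEnum

class SamplingType(IntEnum):  # stand-in; member names are illustrative
    GREEDY = 0
    RANDOM = 1

categorized_sample_indices: dict[SamplingType, list[int]] = {
    t: []
    for t in SamplingType
}
categorized_sample_indices[SamplingType.GREEDY].append(0)  # logit row 0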

View File

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 """Utils for model executor."""
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 import torch
 
@@ -12,7 +12,7 @@ def set_random_seed(seed: int) -> None:
 
 
 def set_weight_attrs(
     weight: torch.Tensor,
-    weight_attrs: Optional[Dict[str, Any]],
+    weight_attrs: Optional[dict[str, Any]],
 ):
     """Set attributes on a weight tensor.