Update deprecated type hinting in vllm/transformers_utils (#18058)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor authored on 2025-05-13 12:34:37 +01:00; committed by GitHub
parent ff334ca1cd
commit 8c946cecca
17 changed files with 98 additions and 102 deletions
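
For orientation: every hunk below applies the same PEP 585 migration, replacing the deprecated typing aliases (Dict, List, Tuple, Type) with the builtin generics that are subscriptable on Python 3.9 and later. A minimal sketch of the pattern, assuming a 3.9+ interpreter; the load_registry function is illustrative and not part of vLLM:

    from typing import Any, Optional, Union

    # Before (aliases deprecated since Python 3.9):
    #     from typing import Dict, List, Type
    #     def load_registry(names: List[str]) -> Dict[str, Type[object]]: ...
    #
    # After (builtin containers used directly as generics):
    def load_registry(names: list[str]) -> dict[str, type[object]]:
        # Map each name to a placeholder class; real code would resolve configs.
        return {name: object for name in names}

    # Non-container helpers (Any, Optional, Union, Callable, Literal) still come
    # from typing, which is why those imports survive in the hunks below.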

View File

@ -6,7 +6,7 @@ import os
import time
from functools import cache
from pathlib import Path
from typing import Any, Callable, Dict, Literal, Optional, Type, Union
from typing import Any, Callable, Literal, Optional, Union
import huggingface_hub
from huggingface_hub import hf_hub_download
@ -55,11 +55,11 @@ HF_TOKEN = os.getenv('HF_TOKEN', None)
logger = init_logger(__name__)
_CONFIG_REGISTRY_OVERRIDE_HF: Dict[str, Type[PretrainedConfig]] = {
_CONFIG_REGISTRY_OVERRIDE_HF: dict[str, type[PretrainedConfig]] = {
"mllama": MllamaConfig
}
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
"chatglm": ChatGLMConfig,
"cohere2": Cohere2Config,
"dbrx": DbrxConfig,
@ -199,7 +199,7 @@ def patch_rope_scaling(config: PretrainedConfig) -> None:
patch_rope_scaling_dict(rope_scaling)
def patch_rope_scaling_dict(rope_scaling: Dict[str, Any]) -> None:
def patch_rope_scaling_dict(rope_scaling: dict[str, Any]) -> None:
if "rope_type" in rope_scaling and "type" in rope_scaling:
rope_type = rope_scaling["rope_type"]
rope_type_legacy = rope_scaling["type"]
@ -748,7 +748,7 @@ def get_hf_image_processor_config(
hf_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
**kwargs,
) -> Dict[str, Any]:
) -> dict[str, Any]:
# ModelScope does not provide an interface for image_processor
if VLLM_USE_MODELSCOPE:
return dict()
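
The registry hunks above now annotate the mapping as dict[str, type[PretrainedConfig]], i.e. model-type strings mapped to config classes rather than instances. A small self-contained sketch of how such a registry is typically read; the lookup helper is hypothetical and not the vLLM code path:

    from typing import Optional

    from transformers import PretrainedConfig

    _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {}

    def lookup_config_class(model_type: str) -> Optional[type[PretrainedConfig]]:
        # Return the registered config *class* (or None); instantiation happens later.
        return _CONFIG_REGISTRY.get(model_type)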

View File

@ -8,7 +8,7 @@
""" Arctic model configuration"""
from dataclasses import asdict, dataclass
from typing import Any, Dict
from typing import Any
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
@ -192,14 +192,14 @@ class ArcticConfig(PretrainedConfig):
)
@classmethod
def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "ArcticConfig":
def from_dict(cls, config_dict: dict[str, Any], **kwargs) -> "ArcticConfig":
result = super().from_dict(config_dict, **kwargs)
config = result[0] if isinstance(result, tuple) else result
if isinstance(config.quantization, dict):
config.quantization = ArcticQuantizationConfig(**config.quantization)
return result
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
ret = super().to_dict()
if isinstance(ret["quantization"], ArcticQuantizationConfig):
ret["quantization"] = asdict(ret["quantization"])

View File

@ -61,7 +61,7 @@ class Cohere2Config(PretrainedConfig):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
rope_scaling (`dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
@ -86,11 +86,11 @@ class Cohere2Config(PretrainedConfig):
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
`short_factor` (`list[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
`long_factor` (`list[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py#L115-L268
from typing import Tuple
from transformers.configuration_utils import PretrainedConfig
@ -191,12 +190,12 @@ class DeepseekVLV2Config(PretrainedConfig):
tile_tag: str = "2D"
global_view_pos: str = "head"
candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384), )
candidate_resolutions: tuple[tuple[int, int]] = ((384, 384), )
def __init__(self,
tile_tag: str = "tile_tag",
global_view_pos: str = "head",
candidate_resolutions: Tuple[Tuple[int,
candidate_resolutions: tuple[tuple[int,
int]] = ((384, 384), ),
**kwargs):
super().__init__(**kwargs)
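
One nuance in the candidate_resolutions hunk above: tuple[tuple[int, int]] describes a tuple containing exactly one (int, int) pair, and the diff preserves that arity unchanged from the original annotation. If a variable number of resolutions were intended, the conventional spelling uses an ellipsis; a minimal illustration, not vLLM code:

    # Exactly one pair:
    one_pair: tuple[tuple[int, int]] = ((384, 384),)
    # Any number of pairs:
    many_pairs: tuple[tuple[int, int], ...] = ((384, 384), (384, 768))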

View File

@ -17,14 +17,12 @@
# limitations under the License.
"""Exaone model configuration"""
from typing import Dict
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
EXAONE_PRETRAINED_CONFIG_ARCHIVE_MAP: Dict[str, str] = {}
EXAONE_PRETRAINED_CONFIG_ARCHIVE_MAP: dict[str, str] = {}
class ExaoneConfig(PretrainedConfig):

View File

@ -98,7 +98,7 @@ class JAISConfig(PretrainedConfig):
Scale attention weights by dividing by hidden_size instead of
sqrt(hidden_size). Need to set scale_attn_weights to `True` as
well.
alibi_scaling (`Dict`, *optional*):
alibi_scaling (`dict`, *optional*):
Dictionary containing the scaling configuration for ALiBi
embeddings. Currently only supports linear
scaling strategy. Can specify either the scaling `factor` (must be
@ -108,7 +108,7 @@ class JAISConfig(PretrainedConfig):
formats are `{"type": strategy name, "factor": scaling factor}` or
`{"type": strategy name,
"train_seq_len": training sequence length}`.
architectures (`List`, *optional*, defaults to ['JAISLMHeadModel']):
architectures (`list`, *optional*, defaults to ['JAISLMHeadModel']):
architecture names for Jais.
Example:

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional
from typing import Optional
from transformers import PretrainedConfig
@ -17,7 +17,7 @@ class MLPSpeculatorConfig(PretrainedConfig):
emb_dim: int = 4096,
inner_dim: int = 0,
n_predict: int = 3,
top_k_tokens_per_head: Optional[List[int]] = None,
top_k_tokens_per_head: Optional[list[int]] = None,
n_candidates: int = 5,
tie_weights: bool = False,
scale_input: bool = False,
@ -34,7 +34,7 @@ class MLPSpeculatorConfig(PretrainedConfig):
the inner dimension of the model. If 0, will be the emb_dim.
n_predict: int
the number of lookaheads for the speculator
top_k_tokens_per_head: List[int]
top_k_tokens_per_head: list[int]
Number of tokens to consider from each head when forming the
candidate tree.
For each candidate branch in the tree, head n produces topk[n]

View File

@ -4,11 +4,11 @@
# https://huggingface.co/mosaicml/mpt-7b/blob/main/configuration_mpt.py
"""A HuggingFace-style model configuration."""
import warnings
from typing import Any, Dict, Optional, Union
from typing import Any, Optional, Union
from transformers import PretrainedConfig
attn_config_defaults: Dict = {
attn_config_defaults: dict = {
'attn_type': 'multihead_attention',
'attn_pdrop': 0.0,
'attn_impl': 'triton',
@ -20,8 +20,8 @@ attn_config_defaults: Dict = {
'alibi': False,
'alibi_bias_max': 8
}
ffn_config_defaults: Dict = {'ffn_type': 'mptmlp'}
init_config_defaults: Dict = {
ffn_config_defaults: dict = {'ffn_type': 'mptmlp'}
init_config_defaults: dict = {
'name': 'kaiming_normal_',
'fan_mode': 'fan_in',
'init_nonlinearity': 'relu',
@ -52,15 +52,15 @@ class MPTConfig(PretrainedConfig):
resid_pdrop: float = 0.0,
emb_pdrop: float = 0.0,
learned_pos_emb: bool = True,
attn_config: Dict = attn_config_defaults,
ffn_config: Dict = ffn_config_defaults,
attn_config: dict = attn_config_defaults,
ffn_config: dict = ffn_config_defaults,
init_device: str = 'cpu',
logit_scale: Optional[Union[float, str]] = None,
no_bias: bool = False,
embedding_fraction: float = 1.0,
norm_type: str = 'low_precision_layernorm',
use_cache: bool = False,
init_config: Dict = init_config_defaults,
init_config: dict = init_config_defaults,
fc_type: str = 'torch',
verbose: Optional[int] = None,
**kwargs: Any):
@ -102,8 +102,8 @@ class MPTConfig(PretrainedConfig):
self._validate_config()
def _set_config_defaults(
self, config: Dict[str, Any],
config_defaults: Dict[str, Any]) -> Dict[str, Any]:
self, config: dict[str, Any],
config_defaults: dict[str, Any]) -> dict[str, Any]:
for (k, v) in config_defaults.items():
if k not in config:
config[k] = v
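
The _set_config_defaults hunk above is the usual fill-missing-keys pattern, now annotated with builtin dict generics. A standalone sketch of the same idea; this is an illustration, not the MPT implementation itself:

    from typing import Any

    def set_config_defaults(config: dict[str, Any],
                            config_defaults: dict[str, Any]) -> dict[str, Any]:
        # Copy in any default whose key the caller did not supply.
        for key, value in config_defaults.items():
            config.setdefault(key, value)
        return config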

View File

@ -108,7 +108,7 @@ class SolarConfig(PretrainedConfig):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
rope_scaling (`dict`, *optional*):
Dictionary containing the scaling configuration for
the RoPE embeddings.
Currently supports two scaling

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_config.py
from typing import Any, Dict, Optional
from typing import Any, Optional
import transformers
@ -48,8 +48,8 @@ class UltravoxConfig(transformers.PretrainedConfig):
def __init__(
self,
audio_config: Optional[Dict[str, Any]] = None,
text_config: Optional[Dict[str, Any]] = None,
audio_config: Optional[dict[str, Any]] = None,
text_config: Optional[dict[str, Any]] = None,
audio_model_id: Optional[str] = None,
text_model_id: Optional[str] = None,
ignore_index: int = -100,
@ -58,8 +58,8 @@ class UltravoxConfig(transformers.PretrainedConfig):
stack_factor: int = 8,
norm_init: float = 0.4,
projector_act: str = "swiglu",
text_model_lora_config: Optional[Dict[str, Any]] = None,
audio_model_lora_config: Optional[Dict[str, Any]] = None,
text_model_lora_config: Optional[dict[str, Any]] = None,
audio_model_lora_config: Optional[dict[str, Any]] = None,
projector_ln_mid: bool = False,
**kwargs,
):

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Optional
from typing import Optional
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams,
Sequence, SequenceGroup)
@ -22,7 +22,7 @@ class Detokenizer:
return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)
def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
prompt_logprobs: List[Optional[Dict[
prompt_logprobs: list[Optional[dict[
int, Logprob]]],
position_offset: int) -> None:
"""Decodes the logprobs for the prompt of a sequence group.
@ -49,7 +49,7 @@ class Detokenizer:
read_offset = 0
next_iter_prefix_offset = 0
next_iter_read_offset = 0
next_iter_tokens: List[str] = []
next_iter_tokens: list[str] = []
prev_tokens = None
for token_position_in_logprob, prompt_logprobs_for_token in enumerate(

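The prompt_logprobs parameter above nests typing.Optional inside the builtin generics, list[Optional[dict[int, Logprob]]], one optional mapping per prompt position. A small illustration of that mixed style, using a float stand-in for vLLM's Logprob type:

    from typing import Optional

    LogprobLike = float  # stand-in for vllm.sequence.Logprob in this sketch

    prompt_logprobs: list[Optional[dict[int, LogprobLike]]] = [
        None,                    # the first prompt token has no logprobs
        {42: -0.25, 7: -1.30},   # token id -> logprob for the next position
    ]
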
View File

@ -1,11 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple
from typing import Optional
from .tokenizer import AnyTokenizer
def _replace_none_with_empty(tokens: List[Optional[str]]):
def _replace_none_with_empty(tokens: list[Optional[str]]):
for i, token in enumerate(tokens):
if token is None:
tokens[i] = ""
@ -13,7 +13,7 @@ def _replace_none_with_empty(tokens: List[Optional[str]]):
def _convert_tokens_to_string_with_added_encoders(
tokenizer: AnyTokenizer,
output_tokens: List[str],
output_tokens: list[str],
skip_special_tokens: bool,
spaces_between_special_tokens: bool,
) -> str:
@ -22,8 +22,8 @@ def _convert_tokens_to_string_with_added_encoders(
# NOTE(woosuk): The following code is slow because it runs a for loop over
# the output_tokens. In Python, running a for loop over a list can be slow
# even when the loop body is very simple.
sub_texts: List[str] = []
current_sub_text: List[str] = []
sub_texts: list[str] = []
current_sub_text: list[str] = []
all_special_tokens = set(tokenizer.all_special_tokens)
for token in output_tokens:
if skip_special_tokens and token in all_special_tokens:
@ -52,9 +52,9 @@ INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5
def convert_prompt_ids_to_tokens(
tokenizer: AnyTokenizer,
prompt_ids: List[int],
prompt_ids: list[int],
skip_special_tokens: bool = False,
) -> Tuple[List[str], int, int]:
) -> tuple[list[str], int, int]:
"""Converts the prompt ids to tokens and returns the tokens and offsets
for incremental detokenization.
@ -76,8 +76,8 @@ def convert_prompt_ids_to_tokens(
def convert_ids_list_to_tokens(
tokenizer: AnyTokenizer,
token_ids: List[int],
) -> List[str]:
token_ids: list[int],
) -> list[str]:
"""Detokenize the input ids individually.
Args:
@ -98,13 +98,13 @@ def convert_ids_list_to_tokens(
# under Apache 2.0 license
def detokenize_incrementally(
tokenizer: AnyTokenizer,
all_input_ids: List[int],
prev_tokens: Optional[List[str]],
all_input_ids: list[int],
prev_tokens: Optional[list[str]],
prefix_offset: int,
read_offset: int,
skip_special_tokens: bool = False,
spaces_between_special_tokens: bool = True,
) -> Tuple[List[str], str, int, int]:
) -> tuple[list[str], str, int, int]:
"""Detokenizes the input ids incrementally and returns the new tokens
and the new text.

View File

@ -24,7 +24,6 @@
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import math
from typing import List, Tuple
import torch
import torchvision.transforms as T
@ -36,8 +35,8 @@ from transformers.processing_utils import ProcessorMixin
class ImageTransform:
def __init__(self,
mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
mean: tuple[float, float, float] = (0.5, 0.5, 0.5),
std: tuple[float, float, float] = (0.5, 0.5, 0.5),
normalize: bool = True):
self.mean = mean
self.std = std
@ -62,11 +61,11 @@ class DeepseekVLV2Processor(ProcessorMixin):
def __init__(
self,
tokenizer: LlamaTokenizerFast,
candidate_resolutions: Tuple[Tuple[int, int]],
candidate_resolutions: tuple[tuple[int, int]],
patch_size: int,
downsample_ratio: int,
image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
image_mean: tuple[float, float, float] = (0.5, 0.5, 0.5),
image_std: tuple[float, float, float] = (0.5, 0.5, 0.5),
normalize: bool = True,
image_token: str = "<image>",
pad_token: str = "<▁pad▁>",
@ -170,13 +169,13 @@ class DeepseekVLV2Processor(ProcessorMixin):
return t
def decode(self, t: List[int], **kwargs) -> str:
def decode(self, t: list[int], **kwargs) -> str:
return self.tokenizer.decode(t, **kwargs)
def process_one(
self,
prompt: str,
images: List[Image.Image],
images: list[Image.Image],
inference_mode: bool = True,
**kwargs,
):
@ -184,8 +183,8 @@ class DeepseekVLV2Processor(ProcessorMixin):
Args:
prompt (str): the formatted prompt;
conversations (List[Dict]): conversations with a list of messages;
images (List[ImageType]): the list of images;
conversations (list[dict]): conversations with a list of messages;
images (list[ImageType]): the list of images;
inference_mode (bool): if True, then remove the last eos token;
system_prompt (str): the system prompt;
**kwargs:
@ -196,7 +195,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
- target_ids (torch.LongTensor): [N + image tokens]
- pixel_values (torch.FloatTensor): [n_patches, 3, H, W]
- image_id (int): the id of the image token
- num_image_tokens (List[int]): the number of image tokens
- num_image_tokens (list[int]): the number of image tokens
"""
assert (prompt is not None and images is not None
@ -257,7 +256,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
self,
*,
prompt: str,
images: List[Image.Image],
images: list[Image.Image],
inference_mode: bool = True,
**kwargs,
):
@ -265,7 +264,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
Args:
prompt (str): the formatted prompt;
images (List[ImageType]): the list of images;
images (list[ImageType]): the list of images;
inference_mode (bool): if True, then remove the last eos token;
**kwargs:
@ -274,7 +273,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
- input_ids (torch.LongTensor): [N + image tokens]
- images (torch.FloatTensor): [n_images, 3, H, W]
- image_id (int): the id of the image token
- num_image_tokens (List[int]): the number of image tokens
- num_image_tokens (list[int]): the number of image tokens
"""
prepare = self.process_one(
@ -288,7 +287,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
def tokenize_with_images(
self,
conversation: str,
images: List[Image.Image],
images: list[Image.Image],
bos: bool = True,
eos: bool = True,
cropping: bool = True,

View File

@ -23,7 +23,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cached_property
from typing import List, Union
from typing import Union
import PIL
import torch
@ -102,7 +102,7 @@ class OvisProcessor(ProcessorMixin):
def __call__(
self,
images: ImageInput = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
**kwargs: Unpack[OvisProcessorKwargs],
) -> BatchFeature:
"""
@ -111,14 +111,14 @@ class OvisProcessor(ProcessorMixin):
the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to
Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`str`, `List[str]`, `List[List[str]]`):
text (`str`, `list[str]`, `list[list[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
@ -400,7 +400,7 @@ class OvisProcessor(ProcessorMixin):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.
Returns:
`List[str]`: The decoded text.
`list[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(
generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional
from typing import Optional
from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig
from vllm.lora.request import LoRARequest
@ -32,7 +32,7 @@ class TokenizerGroup:
return self.max_input_length
def _raise_if_input_too_long(self,
encoded_tokens: List[int],
encoded_tokens: list[int],
lora_request: Optional[LoRARequest] = None):
input_length = len(encoded_tokens)
if lora_request:
@ -48,7 +48,7 @@ class TokenizerGroup:
max_length: Optional[int] = None,
truncation: Optional[bool] = None,
lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]:
add_special_tokens: Optional[bool] = None) -> list[int]:
tokenizer = self.get_lora_tokenizer(lora_request)
ret = encode_tokens(tokenizer,
@ -65,7 +65,7 @@ class TokenizerGroup:
max_length: Optional[int] = None,
truncation: Optional[bool] = None,
lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]:
add_special_tokens: Optional[bool] = None) -> list[int]:
tokenizer = await self.get_lora_tokenizer_async(lora_request)
ret = encode_tokens(tokenizer,
prompt,

View File

@ -4,7 +4,7 @@ import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Optional, Union, cast
import huggingface_hub
from huggingface_hub import HfApi, hf_hub_download
@ -28,7 +28,7 @@ logger = init_logger(__name__)
@dataclass
class Encoding:
input_ids: Union[List[int], List[List[int]]]
input_ids: Union[list[int], list[list[int]]]
def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
@ -105,7 +105,7 @@ def validate_request_params(request: "ChatCompletionRequest"):
"for Mistral tokenizers.")
def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
def list_local_repo_files(repo_id: str, revision: Optional[str]) -> list[str]:
repo_cache = os.path.join(
huggingface_hub.constants.HF_HUB_CACHE,
huggingface_hub.constants.REPO_ID_SEPARATOR.join(
@ -125,7 +125,7 @@ def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
return []
def find_tokenizer_file(files: List[str]):
def find_tokenizer_file(files: list[str]):
file_pattern = re.compile(
r"^tokenizer\.model\.v.*$|^tekken\.json$|^tokenizer\.mm\.model\.v.*$")
@ -145,10 +145,10 @@ def find_tokenizer_file(files: List[str]):
def make_mistral_chat_completion_request(
messages: List["ChatCompletionMessageParam"],
tools: Optional[List[Dict[str,
messages: list["ChatCompletionMessageParam"],
tools: Optional[list[dict[str,
Any]]] = None) -> "ChatCompletionRequest":
last_message = cast(Dict[str, Any], messages[-1])
last_message = cast(dict[str, Any], messages[-1])
if last_message["role"] == "assistant":
last_message["prefix"] = True
@ -199,7 +199,7 @@ class MistralTokenizer(TokenizerBase):
raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
self._vocab = tokenizer_.vocab()
# Convert to a Dict[str, int] to match protocol, but this is a lossy
# Convert to a dict[str, int] to match protocol, but this is a lossy
# conversion. There may be multiple token ids that decode to the same
# string due to partial UTF-8 byte sequences being converted to <20>
self._vocab_dict = {
@ -314,21 +314,21 @@ class MistralTokenizer(TokenizerBase):
def __call__(
self,
text: Union[str, List[str], List[int]],
text: Union[str, list[str], list[int]],
text_pair: Optional[str] = None,
add_special_tokens: bool = False,
truncation: bool = False,
max_length: Optional[int] = None,
):
input_ids: Union[List[int], List[List[int]]]
# For List[str], original prompt text
input_ids: Union[list[int], list[list[int]]]
# For list[str], original prompt text
if is_list_of(text, str):
input_ids_: List[List[int]] = []
input_ids_: list[list[int]] = []
for p in text:
each_input_ids = self.encode_one(p, truncation, max_length)
input_ids_.append(each_input_ids)
input_ids = input_ids_
# For List[int], apply chat template output, already tokens.
# For list[int], apply chat template output, already tokens.
elif is_list_of(text, int):
input_ids = text
# For str, single prompt text
@ -350,7 +350,7 @@ class MistralTokenizer(TokenizerBase):
text: str,
truncation: bool = False,
max_length: Optional[int] = None,
) -> List[int]:
) -> list[int]:
# Mistral Tokenizers should not add special tokens
input_ids = self.encode(text)
@ -362,7 +362,7 @@ class MistralTokenizer(TokenizerBase):
text: str,
truncation: Optional[bool] = None,
max_length: Optional[int] = None,
add_special_tokens: Optional[bool] = None) -> List[int]:
add_special_tokens: Optional[bool] = None) -> list[int]:
# `encode` should only be used for prompt completion
# it should never be used for chat_completion.
# For chat completion use `apply_chat_template`
@ -374,9 +374,9 @@ class MistralTokenizer(TokenizerBase):
return self.tokenizer.encode(text, bos=True, eos=False)
def apply_chat_template(self,
messages: List["ChatCompletionMessageParam"],
tools: Optional[List[Dict[str, Any]]] = None,
**kwargs) -> List[int]:
messages: list["ChatCompletionMessageParam"],
tools: Optional[list[dict[str, Any]]] = None,
**kwargs) -> list[int]:
request = make_mistral_chat_completion_request(messages, tools)
encoded = self.mistral.encode_chat_completion(request)
@ -384,7 +384,7 @@ class MistralTokenizer(TokenizerBase):
# encode-decode to get clean prompt
return encoded.tokens
def convert_tokens_to_string(self, tokens: List[str]) -> str:
def convert_tokens_to_string(self, tokens: list[str]) -> str:
from mistral_common.tokens.tokenizers.base import SpecialTokens
if self.is_tekken:
tokens = [
@ -417,7 +417,7 @@ class MistralTokenizer(TokenizerBase):
# make sure certain special tokens like Tool calls are
# not decoded
special_tokens = {SpecialTokens.tool_calls}
regular_tokens: List[str] = []
regular_tokens: list[str] = []
decoded_list = []
for token in tokens:
@ -442,7 +442,7 @@ class MistralTokenizer(TokenizerBase):
# See: guided_decoding/outlines_logits_processors.py::_adapt_tokenizer
# for more.
def decode(self,
ids: Union[List[int], int],
ids: Union[list[int], int],
skip_special_tokens: bool = True) -> str:
assert (
skip_special_tokens
@ -454,9 +454,9 @@ class MistralTokenizer(TokenizerBase):
def convert_ids_to_tokens(
self,
ids: List[int],
ids: list[int],
skip_special_tokens: bool = True,
) -> List[str]:
) -> list[str]:
from mistral_common.tokens.tokenizers.base import SpecialTokens
# TODO(Patrick) - potentially allow special tokens to not be skipped
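
The MistralTokenizer.__call__ hunk earlier in this file dispatches on Union[str, list[str], list[int]]. A compact sketch of that branch structure with the builtin generics, using plain isinstance checks instead of vLLM's is_list_of helper; the character-level encoder is only a stand-in, not the Mistral tokenizer:

    from typing import Union

    def to_token_ids(
        text: Union[str, list[str], list[int]],
    ) -> Union[list[int], list[list[int]]]:
        def encode(s: str) -> list[int]:
            return [ord(ch) for ch in s]  # placeholder for real tokenization

        if isinstance(text, str):                        # single prompt string
            return encode(text)
        if all(isinstance(item, str) for item in text):  # batch of prompt strings
            return [encode(item) for item in text]
        return [int(item) for item in text]              # already token ids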

View File

@ -4,7 +4,7 @@ import json
from functools import cache
from os import PathLike
from pathlib import Path
from typing import List, Optional, Union
from typing import Optional, Union
from vllm.envs import VLLM_MODEL_REDIRECT_PATH
from vllm.logger import init_logger
@ -38,7 +38,7 @@ def modelscope_list_repo_files(
repo_id: str,
revision: Optional[str] = None,
token: Union[str, bool, None] = None,
) -> List[str]:
) -> list[str]:
"""List files in a modelscope repo."""
from modelscope.hub.api import HubApi
api = HubApi()