mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-03 06:07:02 +08:00
Update deprecated type hinting in model_loader (#18130)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
a9944aabfa
commit
07ad27121f
@ -71,15 +71,15 @@ exclude = [
|
||||
"vllm/third_party/**" = ["ALL"]
|
||||
"vllm/version.py" = ["F401"]
|
||||
"vllm/_version.py" = ["ALL"]
|
||||
# Python 3.8 typing. TODO: Remove these excludes after v1.0.0
|
||||
# Python 3.8 typing - skip V0 code
|
||||
"vllm/attention/**/*.py" = ["UP006", "UP035"]
|
||||
"vllm/core/**/*.py" = ["UP006", "UP035"]
|
||||
"vllm/engine/**/*.py" = ["UP006", "UP035"]
|
||||
"vllm/executor/**/*.py" = ["UP006", "UP035"]
|
||||
"vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"]
|
||||
"vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"]
|
||||
"vllm/spec_decode/**/*.py" = ["UP006", "UP035"]
|
||||
"vllm/worker/**/*.py" = ["UP006", "UP035"]
|
||||
# Python 3.8 typing - skip utils for ROCm
|
||||
"vllm/utils.py" = ["UP006", "UP035"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
|
||||
@ -6,7 +6,8 @@ import glob
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -49,21 +50,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
super().__init__(load_config)
|
||||
|
||||
# Save the module names without sharding.
|
||||
self.unsharded_weights_modules: List[str] = []
|
||||
self.unsharded_weights_modules: list[str] = []
|
||||
# Save the module names that are sharded by column.
|
||||
self.column_sharded_weights_modules: List[str] = []
|
||||
self.column_sharded_weights_modules: list[str] = []
|
||||
# Store all module names (from transformers) that support
|
||||
# BNB quantization.
|
||||
self.target_modules: List[str] = []
|
||||
self.target_modules: list[str] = []
|
||||
# mapping weight names from transformers to vllm.
|
||||
self.weight_mapper: Callable = lambda name: name
|
||||
|
||||
def _get_weight_files(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
allowed_patterns: List[str],
|
||||
allowed_patterns: list[str],
|
||||
revision: Optional[str] = None,
|
||||
) -> Tuple[str, List[str], str]:
|
||||
) -> tuple[str, list[str], str]:
|
||||
"""Retrieve weight files. Download the files if necessary.
|
||||
|
||||
Return the weight files and the file pattern."""
|
||||
@ -95,7 +96,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
f"No model weights found in: `{model_name_or_path}`")
|
||||
|
||||
def _prepare_weights(self, model_name_or_path: str,
|
||||
revision: Optional[str]) -> Tuple[List[str], bool]:
|
||||
revision: Optional[str]) -> tuple[list[str], bool]:
|
||||
"""Prepare weight files for the model."""
|
||||
|
||||
allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]
|
||||
@ -155,7 +156,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
revision: Optional[str],
|
||||
pre_quant: bool,
|
||||
load_8bit: bool,
|
||||
) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
|
||||
) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str,
|
||||
Any]]:
|
||||
"""Get an iterator to the model weights with bitsandbytes quantization,
|
||||
as well as the quantization state dictionary."""
|
||||
@ -175,7 +176,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
hf_weights_files, use_safetensors = self._prepare_weights(
|
||||
model_name_or_path, revision)
|
||||
|
||||
quant_state_dict: Dict[str, Any] = {}
|
||||
quant_state_dict: dict[str, Any] = {}
|
||||
|
||||
if pre_quant:
|
||||
if load_8bit:
|
||||
@ -257,7 +258,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
|
||||
# Closure to parse quant_state for each prequant weight
|
||||
def _parse_quant_state(param_name: str,
|
||||
temp_state_dict: Dict) -> QuantState:
|
||||
temp_state_dict: dict) -> QuantState:
|
||||
quant_state = {}
|
||||
for k in temp_state_dict:
|
||||
if param_name + "." in k:
|
||||
@ -415,7 +416,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
|
||||
# Modules whose weights might have fused on disk
|
||||
# we need their output_sizes to make shard in flight correctly with TP
|
||||
self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
|
||||
self.maybe_fused_weights_modules: dict[str, list[int]] = {}
|
||||
self._get_bnb_target_modules(model)
|
||||
for name, module in model.named_modules():
|
||||
# Some modules like `ReplicatedLinear` should not have their weights
|
||||
@ -480,7 +481,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
param_dict = dict(model.named_parameters())
|
||||
stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
|
||||
stacked_quant_state_dict: dict[str, dict[int, Any]] = {}
|
||||
# TODO: Change this lazy import to normal import
|
||||
# after the checks are updated to run on a new version
|
||||
from vllm.model_executor.models.utils import is_pp_missing_parameter
|
||||
|
||||
@ -3,7 +3,8 @@ import dataclasses
|
||||
import glob
|
||||
import os
|
||||
import time
|
||||
from typing import Generator, Iterable, List, Optional, Tuple, cast
|
||||
from collections.abc import Generator, Iterable
|
||||
from typing import Optional, cast
|
||||
|
||||
import huggingface_hub
|
||||
import torch
|
||||
@ -92,7 +93,7 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
revision: Optional[str],
|
||||
fall_back_to_pt: bool,
|
||||
allow_patterns_overrides: Optional[list[str]],
|
||||
) -> Tuple[str, List[str], bool]:
|
||||
) -> tuple[str, list[str], bool]:
|
||||
"""Prepare weights for the model.
|
||||
|
||||
If the model is not local, it will be downloaded."""
|
||||
@ -138,7 +139,7 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
else:
|
||||
hf_folder = model_name_or_path
|
||||
|
||||
hf_weights_files: List[str] = []
|
||||
hf_weights_files: list[str] = []
|
||||
for pattern in allow_patterns:
|
||||
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
|
||||
if len(hf_weights_files) > 0:
|
||||
@ -173,7 +174,7 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
|
||||
def _get_weights_iterator(
|
||||
self, source: "Source"
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Get an iterator for the model weights based on the load format."""
|
||||
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
|
||||
source.model_or_path, source.revision, source.fall_back_to_pt,
|
||||
@ -238,7 +239,7 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
model: nn.Module,
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
primary_weights = DefaultModelLoader.Source(
|
||||
model_config.model,
|
||||
model_config.revision,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import os
|
||||
from typing import Dict, Generator, Tuple
|
||||
from collections.abc import Generator
|
||||
|
||||
import gguf
|
||||
import torch
|
||||
@ -84,8 +84,8 @@ class GGUFModelLoader(BaseModelLoader):
|
||||
return gguf_to_hf_name_map
|
||||
|
||||
def _get_weights_iterator(
|
||||
self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str]
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
self, model_name_or_path: str, gguf_to_hf_name_map: dict[str, str]
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
return gguf_quant_weights_iterator(model_name_or_path,
|
||||
gguf_to_hf_name_map)
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ import ast
|
||||
import copy
|
||||
import importlib
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -33,7 +33,7 @@ TORCH_DTYPE_TO_NEURON_AMP = {
|
||||
}
|
||||
|
||||
# Models supported by Neuron.
|
||||
_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = {
|
||||
_NEURON_SUPPORTED_MODELS: dict[str, tuple[str, str, str]] = {
|
||||
"LlamaForCausalLM": ("transformers_neuronx.llama.model",
|
||||
"LlamaForSampling", "LlamaForCausalLM"),
|
||||
"MistralForCausalLM": ("transformers_neuronx.mistral.model",
|
||||
@ -146,7 +146,7 @@ class NeuronSpeculationCausalLM(nn.Module):
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
) -> Optional[list[SamplerOutput]]:
|
||||
batch_size, num_steps = logits.shape
|
||||
seq_ids = [
|
||||
seq_id for sg in sampling_metadata.seq_groups
|
||||
@ -188,7 +188,7 @@ def _get_model_architecture(config: PretrainedConfig) -> str:
|
||||
f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
|
||||
|
||||
|
||||
def _get_buckets(env: str, default_value: List[int]) -> List[int]:
|
||||
def _get_buckets(env: str, default_value: list[int]) -> list[int]:
|
||||
env_value = os.getenv(env)
|
||||
if env_value is None:
|
||||
return default_value
|
||||
@ -464,7 +464,7 @@ def get_neuron_eagle_speculation_model(model_config: ModelConfig,
|
||||
|
||||
draft_model.eval()
|
||||
|
||||
token_tree: Dict[int, List[int]] = ast.literal_eval(
|
||||
token_tree: dict[int, list[int]] = ast.literal_eval(
|
||||
speculation_config.speculative_token_tree)
|
||||
|
||||
speculation_model = EagleSpeculativeDecoder(draft_model.model,
|
||||
|
||||
@ -9,7 +9,7 @@ import importlib
|
||||
import multiprocessing
|
||||
import os
|
||||
import shutil
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -46,7 +46,7 @@ TORCH_DTYPE_TO_NEURON_AMP = {
|
||||
}
|
||||
|
||||
# Models supported by Neuronx distributed for inference.
|
||||
_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str]] = {
|
||||
_NEURON_SUPPORTED_MODELS: dict[str, tuple[str, str]] = {
|
||||
"LlamaForCausalLM":
|
||||
("neuronx_distributed_inference.models.llama.modeling_llama",
|
||||
"NeuronLlamaForCausalLM"),
|
||||
@ -365,7 +365,7 @@ class NeuronSpeculationCausalLM(nn.Module):
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
) -> Optional[list[SamplerOutput]]:
|
||||
batch_size, num_steps = logits.shape
|
||||
seq_ids = [
|
||||
seq_id for sg in sampling_metadata.seq_groups
|
||||
|
||||
@ -2,7 +2,8 @@
|
||||
# ruff: noqa: SIM117
|
||||
import glob
|
||||
import os
|
||||
from typing import Generator, List, Optional, Tuple
|
||||
from collections.abc import Generator
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -48,7 +49,7 @@ class RunaiModelStreamerLoader(BaseModelLoader):
|
||||
os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url
|
||||
|
||||
def _prepare_weights(self, model_name_or_path: str,
|
||||
revision: Optional[str]) -> List[str]:
|
||||
revision: Optional[str]) -> list[str]:
|
||||
"""Prepare weights for the model.
|
||||
|
||||
If the model is not local, it will be downloaded."""
|
||||
@ -87,7 +88,7 @@ class RunaiModelStreamerLoader(BaseModelLoader):
|
||||
|
||||
def _get_weights_iterator(
|
||||
self, model_or_path: str,
|
||||
revision: str) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
revision: str) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Get an iterator for the model weights based on the load format."""
|
||||
hf_weights_files = self._prepare_weights(model_or_path, revision)
|
||||
return runai_safetensors_weights_iterator(
|
||||
|
||||
@ -3,7 +3,8 @@
|
||||
import collections
|
||||
import glob
|
||||
import os
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -48,12 +49,12 @@ class ShardedStateLoader(BaseModelLoader):
|
||||
|
||||
@staticmethod
|
||||
def _filter_subtensors(
|
||||
tensors: Dict[str, torch.Tensor], ) -> Dict[str, torch.Tensor]:
|
||||
tensors: dict[str, torch.Tensor], ) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Filter out all tensors that share the same memory or a subset of the
|
||||
memory of another tensor.
|
||||
"""
|
||||
same_storage_groups: Dict[Any, List[Tuple[str, torch.Tensor]]] = (
|
||||
same_storage_groups: dict[Any, list[tuple[str, torch.Tensor]]] = (
|
||||
collections.defaultdict(list))
|
||||
for key, tensor in tensors.items():
|
||||
if tensor.numel():
|
||||
@ -63,7 +64,7 @@ class ShardedStateLoader(BaseModelLoader):
|
||||
def get_end_ptr(tensor: torch.Tensor) -> int:
|
||||
return tensor.view(-1)[-1].data_ptr() + tensor.element_size()
|
||||
|
||||
result: Dict[str, torch.Tensor] = {}
|
||||
result: dict[str, torch.Tensor] = {}
|
||||
for group in same_storage_groups.values():
|
||||
for k, t in group:
|
||||
a, b = t.data_ptr(), get_end_ptr(t)
|
||||
@ -160,7 +161,7 @@ class ShardedStateLoader(BaseModelLoader):
|
||||
return model.eval()
|
||||
|
||||
def iterate_over_files(
|
||||
self, paths) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
self, paths) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
if self.runai_model_streamer:
|
||||
yield from runai_safetensors_weights_iterator(paths, True)
|
||||
else:
|
||||
@ -188,7 +189,7 @@ class ShardedStateLoader(BaseModelLoader):
|
||||
part_idx = 0
|
||||
total_size = 0
|
||||
state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
|
||||
state_dict_part: Dict[str, torch.Tensor] = {}
|
||||
state_dict_part: dict[str, torch.Tensor] = {}
|
||||
for key, tensor in state_dict.items():
|
||||
param_size = tensor.nelement() * tensor.element_size()
|
||||
if max_size is not None and total_size + param_size > max_size:
|
||||
|
||||
@ -6,9 +6,10 @@ import io
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import BinaryIO, Generator, Optional, Tuple, Type, Union
|
||||
from typing import BinaryIO, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -67,7 +68,7 @@ class TensorizerConfig:
|
||||
s3_access_key_id: Optional[str] = None
|
||||
s3_secret_access_key: Optional[str] = None
|
||||
s3_endpoint: Optional[str] = None
|
||||
model_class: Optional[Type[torch.nn.Module]] = None
|
||||
model_class: Optional[type[torch.nn.Module]] = None
|
||||
hf_config: Optional[PretrainedConfig] = None
|
||||
dtype: Optional[Union[str, torch.dtype]] = None
|
||||
_is_sharded: bool = False
|
||||
@ -365,7 +366,7 @@ class TensorizerAgent:
|
||||
|
||||
def tensorizer_weights_iterator(
|
||||
tensorizer_args: "TensorizerArgs"
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
logger.warning("Deserializing HuggingFace models is not optimized for "
|
||||
"loading on vLLM, as tensorizer is forced to load to CPU. "
|
||||
"Consider deserializing a vLLM model instead for faster "
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# ruff: noqa: SIM117
|
||||
import copy
|
||||
from typing import Generator, Tuple
|
||||
from collections.abc import Generator
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -36,7 +36,7 @@ class TensorizerLoader(BaseModelLoader):
|
||||
self.tensorizer_config.verify_with_parallel_config(parallel_config)
|
||||
|
||||
def _get_weights_iterator(
|
||||
self, ) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
self, ) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
|
||||
return tensorizer_weights_iterator(tensorizer_args)
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ import inspect
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
@ -124,7 +124,7 @@ def device_loading_context(module: torch.nn.Module,
|
||||
yield module
|
||||
return
|
||||
|
||||
original_device_states: Dict[str, torch.device] = {}
|
||||
original_device_states: dict[str, torch.device] = {}
|
||||
|
||||
# Store original device states and move parameters to GPU if they're on CPU
|
||||
for name, p in module.named_parameters():
|
||||
@ -214,7 +214,7 @@ def resolve_transformers_arch(model_config: ModelConfig,
|
||||
|
||||
|
||||
def get_model_architecture(
|
||||
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
|
||||
model_config: ModelConfig) -> tuple[type[nn.Module], str]:
|
||||
architectures = getattr(model_config.hf_config, "architectures", [])
|
||||
|
||||
# Special handling for quantized Mixtral.
|
||||
@ -257,8 +257,8 @@ class ParamMapping:
|
||||
It creates a bidirectional mapping between packed parameters and their
|
||||
constituent parts.
|
||||
"""
|
||||
packed_mapping: Dict[str, List[str]]
|
||||
inverse_packed_mapping: Dict[str, Tuple[str,
|
||||
packed_mapping: dict[str, list[str]]
|
||||
inverse_packed_mapping: dict[str, tuple[str,
|
||||
int]] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
@ -273,7 +273,7 @@ class ParamMapping:
|
||||
)
|
||||
|
||||
def get_sub_modules(self,
|
||||
module_name: str) -> Optional[Tuple[str, List[str]]]:
|
||||
module_name: str) -> Optional[tuple[str, list[str]]]:
|
||||
for key, value in self.packed_mapping.items():
|
||||
if module_name.endswith(key):
|
||||
return key, value
|
||||
@ -281,7 +281,7 @@ class ParamMapping:
|
||||
|
||||
|
||||
def configure_quant_config(quant_config: QuantizationConfig,
|
||||
model_class: Type[nn.Module]):
|
||||
model_class: type[nn.Module]):
|
||||
"""
|
||||
Pass packed_modules_mapping by reference to quant_config so that
|
||||
quant_config can properly match fused modules
|
||||
|
||||
@ -8,8 +8,9 @@ import os
|
||||
import tempfile
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import filelock
|
||||
import gguf
|
||||
@ -221,7 +222,7 @@ def get_sparse_attention_config(
|
||||
model_config: ModelConfig,
|
||||
load_config: LoadConfig,
|
||||
sparse_attention_config_filename: str = "sparse_attention_config.json",
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
model_name_or_path = model_config.model
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
if not is_local:
|
||||
@ -253,9 +254,9 @@ def get_sparse_attention_config(
|
||||
def download_weights_from_hf(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str],
|
||||
allow_patterns: List[str],
|
||||
allow_patterns: list[str],
|
||||
revision: Optional[str] = None,
|
||||
ignore_patterns: Optional[Union[str, List[str]]] = None,
|
||||
ignore_patterns: Optional[Union[str, list[str]]] = None,
|
||||
) -> str:
|
||||
"""Download model weights from Hugging Face Hub.
|
||||
|
||||
@ -263,11 +264,11 @@ def download_weights_from_hf(
|
||||
model_name_or_path (str): The model name or path.
|
||||
cache_dir (Optional[str]): The cache directory to store the model
|
||||
weights. If None, will use HF defaults.
|
||||
allow_patterns (List[str]): The allowed patterns for the
|
||||
allow_patterns (list[str]): The allowed patterns for the
|
||||
weight files. Files matched by any of the patterns will be
|
||||
downloaded.
|
||||
revision (Optional[str]): The revision of the model.
|
||||
ignore_patterns (Optional[Union[str, List[str]]]): The patterns to
|
||||
ignore_patterns (Optional[Union[str, list[str]]]): The patterns to
|
||||
filter out the weight files. Files matched by any of the patterns
|
||||
will be ignored.
|
||||
|
||||
@ -347,9 +348,9 @@ def download_safetensors_index_file_from_hf(
|
||||
# Passing both of these to the weight loader functionality breaks.
|
||||
# So, we use the index_file to
|
||||
# look up which safetensors files should be used.
|
||||
def filter_duplicate_safetensors_files(hf_weights_files: List[str],
|
||||
def filter_duplicate_safetensors_files(hf_weights_files: list[str],
|
||||
hf_folder: str,
|
||||
index_file: str) -> List[str]:
|
||||
index_file: str) -> list[str]:
|
||||
# model.safetensors.index.json is a mapping from keys in the
|
||||
# torch state_dict to safetensors file holding that weight.
|
||||
index_file_name = os.path.join(hf_folder, index_file)
|
||||
@ -372,7 +373,7 @@ def filter_duplicate_safetensors_files(hf_weights_files: List[str],
|
||||
|
||||
|
||||
def filter_files_not_needed_for_inference(
|
||||
hf_weights_files: List[str]) -> List[str]:
|
||||
hf_weights_files: list[str]) -> list[str]:
|
||||
"""
|
||||
Exclude files that are not needed for inference.
|
||||
|
||||
@ -408,9 +409,9 @@ def np_cache_weights_iterator(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str],
|
||||
hf_folder: str,
|
||||
hf_weights_files: List[str],
|
||||
hf_weights_files: list[str],
|
||||
use_tqdm_on_load: bool,
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model np files.
|
||||
|
||||
Will dump the model weights to numpy files if they are not already dumped.
|
||||
@ -424,7 +425,7 @@ def np_cache_weights_iterator(
|
||||
# dumping the same model weights to numpy at the same time.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
if not os.path.exists(weight_names_file):
|
||||
weight_names: List[str] = []
|
||||
weight_names: list[str] = []
|
||||
for bin_file in tqdm(
|
||||
hf_weights_files,
|
||||
desc="Loading np_cache checkpoint shards",
|
||||
@ -453,9 +454,9 @@ def np_cache_weights_iterator(
|
||||
|
||||
|
||||
def safetensors_weights_iterator(
|
||||
hf_weights_files: List[str],
|
||||
hf_weights_files: list[str],
|
||||
use_tqdm_on_load: bool,
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model safetensor files."""
|
||||
for st_file in tqdm(
|
||||
hf_weights_files,
|
||||
@ -470,9 +471,9 @@ def safetensors_weights_iterator(
|
||||
|
||||
|
||||
def runai_safetensors_weights_iterator(
|
||||
hf_weights_files: List[str],
|
||||
hf_weights_files: list[str],
|
||||
use_tqdm_on_load: bool,
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model safetensor files."""
|
||||
with SafetensorsStreamer() as streamer:
|
||||
for st_file in tqdm(
|
||||
@ -486,9 +487,9 @@ def runai_safetensors_weights_iterator(
|
||||
|
||||
|
||||
def fastsafetensors_weights_iterator(
|
||||
hf_weights_files: List[str],
|
||||
hf_weights_files: list[str],
|
||||
use_tqdm_on_load: bool,
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model safetensor files
|
||||
using fastsafetensor library."""
|
||||
if torch.distributed.is_initialized():
|
||||
@ -525,10 +526,10 @@ def fastsafetensors_weights_iterator(
|
||||
|
||||
|
||||
def pt_weights_iterator(
|
||||
hf_weights_files: List[str],
|
||||
hf_weights_files: list[str],
|
||||
use_tqdm_on_load: bool,
|
||||
pt_load_map_location: Union[str, dict[str, str]] = "cpu",
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model bin/pt files."""
|
||||
for bin_file in tqdm(
|
||||
hf_weights_files,
|
||||
@ -544,7 +545,7 @@ def pt_weights_iterator(
|
||||
|
||||
|
||||
def get_gguf_extra_tensor_names(
|
||||
gguf_file: str, gguf_to_hf_name_map: Dict[str, str]) -> List[str]:
|
||||
gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> list[str]:
|
||||
reader = gguf.GGUFReader(gguf_file)
|
||||
expected_gguf_keys = set(gguf_to_hf_name_map.keys())
|
||||
exact_gguf_keys = set([tensor.name for tensor in reader.tensors])
|
||||
@ -553,8 +554,8 @@ def get_gguf_extra_tensor_names(
|
||||
|
||||
|
||||
def gguf_quant_weights_iterator(
|
||||
gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
gguf_file: str, gguf_to_hf_name_map: dict[str, str]
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""
|
||||
Iterate over the quant weights in the model gguf files and convert
|
||||
them to torch tensors
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user