[Quant] [Bugfix] Fix quantization config matching with hf_to_vllm_mapper (#20046)

This commit is contained in: parent c05596f1a3, commit 9025a9a705
@@ -53,6 +53,7 @@ class CustomQuantConfig(QuantizationConfig):
 
     def __init__(self, num_bits: int = 8) -> None:
         """Initialize the quantization config."""
+        super().__init__()
         self.num_bits = num_bits
 
     def get_name(self) -> QuantizationMethods:
@@ -805,7 +805,7 @@ def create_lora_manager(
         lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
         **kwargs) -> LoRAModelManager:
     """Create a LoRA adapter for a given model."""
-    if not hasattr(model, "packed_modules_mapping"):
+    if not isinstance(model, SupportsLoRA):
         raise ValueError(f"Model {type(model)} is not supported for LoRA.")
     lora_manager = lora_manager_cls(
         model=model,
@@ -111,10 +111,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
             # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
             # to ensure correct loading of lora weights.
             model = self._adapter_manager.model
-            hf_to_vllm_mapper = None
-            if (hasattr(model, "hf_to_vllm_mapper")
-                    and model.hf_to_vllm_mapper is not None):
-                hf_to_vllm_mapper = model.hf_to_vllm_mapper
+            hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None)
 
             lora = self._lora_model_cls.from_local_checkpoint(
                 lora_path,
@@ -10,6 +10,7 @@ from torch import nn
 
 if TYPE_CHECKING:
     from vllm.model_executor.layers.quantization import QuantizationMethods
+    from vllm.model_executor.models.utils import WeightsMapper
 else:
     QuantizationMethods = str
 
@@ -149,3 +150,15 @@ class QuantizationConfig(ABC):
 
     def get_cache_scale(self, name: str) -> Optional[str]:
         return None
+
+    def apply_vllm_mapper(  # noqa: B027
+            self, hf_to_vllm_mapper: "WeightsMapper"):
+        """
+        Interface for models to update module names referenced in
+        quantization configs in order to reflect the vllm model structure
+
+        :param hf_to_vllm_mapper: maps from hf model structure (the assumed
+            structure of the qconfig) to vllm model structure
+        """
+        # TODO (@kylesayrs): add implementations for all subclasses
+        pass
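For context on what a concrete override of this hook does, here is a minimal standalone sketch (not vLLM's actual code): it rewrites the module names stored inside a quantization config so that they match the renamed/fused modules of the serving model. The class ToyQuantConfig, the helper map_name, and the prefix map are hypothetical and only illustrate the renaming step.

# Standalone illustration; ToyQuantConfig, map_name and PREFIX-style maps
# are hypothetical and are not part of vLLM.
from typing import Optional


def map_name(name: str, prefix_map: dict[str, str]) -> Optional[str]:
    # Rewrite a checkpoint-style module name to the serving-engine layout,
    # mirroring what a WeightsMapper-style helper would do.
    for old, new in prefix_map.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name


class ToyQuantConfig:
    def __init__(self, ignore: list[str]) -> None:
        # Module names as they appear in the HF checkpoint's quant config.
        self.ignore = ignore

    def apply_vllm_mapper(self, prefix_map: dict[str, str]) -> None:
        # Remap every referenced module name so that layer matching works
        # against the restructured model.
        self.ignore = [
            out for name in self.ignore
            if (out := map_name(name, prefix_map)) is not None
        ]


if __name__ == "__main__":
    cfg = ToyQuantConfig(ignore=["model.visual.patch_embed"])
    cfg.apply_vllm_mapper({"model.visual.": "visual."})
    print(cfg.ignore)  # ['visual.patch_embed']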
@@ -63,6 +63,7 @@ class BitBLASConfig(QuantizationConfig):
         # (since we have only one group per output channel)
         desc_act = False
 
+        super().__init__()
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.desc_act = desc_act
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from contextlib import suppress
-from typing import Any, Literal, Optional, cast
+from typing import TYPE_CHECKING, Any, Literal, Optional, cast
 
 import torch
 from compressed_tensors.config import (CompressionFormat,
@@ -37,6 +37,9 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import
     cutlass_fp4_supported)
 from vllm.platforms import current_platform
 
+if TYPE_CHECKING:
+    from vllm.model_executor.models.utils import WeightsMapper
+
 logger = init_logger(__name__)
 
 __all__ = ["CompressedTensorsLinearMethod"]
@@ -80,6 +83,18 @@ class CompressedTensorsConfig(QuantizationConfig):
     def get_name(self) -> QuantizationMethods:
         return "compressed-tensors"
 
+    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+        self.target_scheme_map = hf_to_vllm_mapper.apply_dict(
+            self.target_scheme_map)
+        self.ignore = hf_to_vllm_mapper.apply_list(self.ignore)
+        self.sparsity_scheme_map = hf_to_vllm_mapper.apply_dict(
+            self.sparsity_scheme_map)
+        self.sparsity_ignore_list = hf_to_vllm_mapper.apply_list(
+            self.sparsity_ignore_list)
+        if self.kv_cache_scheme is not None:
+            self.kv_cache_scheme = hf_to_vllm_mapper.apply_dict(
+                self.kv_cache_scheme)
+
     def get_quant_method(
         self,
         layer: torch.nn.Module,
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import functools
-from typing import Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -39,6 +39,9 @@ from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils import has_deep_gemm
 
+if TYPE_CHECKING:
+    from vllm.model_executor.models.utils import WeightsMapper
+
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 logger = init_logger(__name__)
@@ -100,6 +103,11 @@ class Fp8Config(QuantizationConfig):
     def get_config_filenames(cls) -> list[str]:
         return []
 
+    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+        if self.ignored_layers is not None:
+            self.ignored_layers = hf_to_vllm_mapper.apply_list(
+                self.ignored_layers)
+
     @classmethod
     def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
         quant_method = cls.get_from_keys(config, ["quant_method"])
@@ -81,6 +81,7 @@ class GPTQBitBLASConfig(QuantizationConfig):
         # (since we have only one group per output channel)
         desc_act = False
 
+        super().__init__()
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.desc_act = desc_act
@@ -32,6 +32,8 @@ class MarlinConfig(QuantizationConfig):
         group_size: int,
         lm_head_quantized: bool,
     ) -> None:
+        super().__init__()
+
         # Group size for the quantization.
         self.group_size = group_size
         self.lm_head_quantized = lm_head_quantized
@@ -181,6 +181,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
         exclude_modules: list[str],
         group_size: int = 16,
     ) -> None:
+        super().__init__()
         self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
         if is_checkpoint_nvfp4_serialized:
             logger.warning(
@@ -55,6 +55,7 @@ class TorchAOConfig(QuantizationConfig):
             os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
             logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
         """
+        super().__init__()
         self.torchao_config = torchao_config
         self.skip_modules = skip_modules or []
 
@@ -24,6 +24,7 @@ from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.models.adapters import (as_classification_model,
                                                  as_embedding_model,
                                                  as_reward_model)
+from vllm.model_executor.models.interfaces import SupportsQuant
 from vllm.utils import is_pin_memory_available
 
 logger = init_logger(__name__)
@@ -294,13 +295,16 @@ def configure_quant_config(quant_config: QuantizationConfig,
 
     Note that model attributes are passed by reference to quant_config,
    enabling them to be updated by model_class.__new__ (ex. chatglm, qwen)
+
+    Once the `SupportsQuant` mixin has been added to all models, this
+    function can be removed
     """
-    packed_mapping = getattr(model_class, "packed_modules_mapping", None)
-    if packed_mapping is not None:
-        # pass packed_modules_mapping by reference to quant_config
-        quant_config.packed_modules_mapping = packed_mapping
-    else:
-        logger.warning(
-            "The model class %s has not defined `packed_modules_mapping`, "
-            "this may lead to incorrect mapping of quantized or ignored "
-            "modules", model_class.__name__)
+    if not issubclass(model_class, SupportsQuant):
+        hf_to_vllm_mapper = getattr(model_class, "hf_to_vllm_mapper", None)
+        packed_mapping = getattr(model_class, "packed_modules_mapping", None)
+
+        # pass mappings by reference to quant_config
+        if hf_to_vllm_mapper is not None:
+            quant_config.apply_vllm_mapper(hf_to_vllm_mapper)
+        if packed_mapping is not None:
+            quant_config.packed_modules_mapping = packed_mapping
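A condensed standalone sketch of the dispatch above, for readers outside the diff: models that have not yet adopted the mixin still get their mappings pushed into the quantization config by the loader. SupportsQuantMarker, ToyQuantConfig and LegacyModel are illustrative stand-ins, not vLLM classes.

# Illustrative only: stand-ins for vLLM's model classes and quant config.
class SupportsQuantMarker:  # hypothetical stand-in for vLLM's SupportsQuant
    pass


class ToyQuantConfig:
    def __init__(self) -> None:
        self.packed_modules_mapping: dict[str, list[str]] = {}
        self.applied_mapper = None

    def apply_vllm_mapper(self, mapper) -> None:
        self.applied_mapper = mapper


def configure_quant_config(quant_config: ToyQuantConfig,
                           model_class: type) -> None:
    # Legacy path: only models that have NOT adopted the SupportsQuant mixin
    # need the loader to push their mappings into the config by reference.
    if not issubclass(model_class, SupportsQuantMarker):
        hf_to_vllm_mapper = getattr(model_class, "hf_to_vllm_mapper", None)
        packed_mapping = getattr(model_class, "packed_modules_mapping", None)

        if hf_to_vllm_mapper is not None:
            quant_config.apply_vllm_mapper(hf_to_vllm_mapper)
        if packed_mapping is not None:
            quant_config.packed_modules_mapping = packed_mapping


class LegacyModel:  # no mixin, so the loader configures the quant config
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}


cfg = ToyQuantConfig()
configure_quant_config(cfg, LegacyModel)
print(cfg.packed_modules_mapping)  # {'qkv_proj': ['q_proj', 'k_proj', 'v_proj']}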
@@ -18,6 +18,7 @@ from .interfaces_base import is_pooling_model
 
 if TYPE_CHECKING:
     from vllm.attention import AttentionMetadata
+    from vllm.model_executor.models.utils import WeightsMapper
     from vllm.sequence import IntermediateTensors
 
 logger = init_logger(__name__)
@@ -566,20 +567,36 @@ def has_step_pooler(model: Union[type[object], object]) -> bool:
 class SupportsQuant:
     """The interface required for all models that support quantization."""
 
-    packed_modules_mapping: ClassVar[dict[str, list[str]]] = {}
+    hf_to_vllm_mapper: ClassVar[Optional["WeightsMapper"]] = None
+    packed_modules_mapping: ClassVar[Optional[dict[str, list[str]]]] = None
     quant_config: Optional[QuantizationConfig] = None
 
     def __new__(cls, *args, **kwargs) -> Self:
         instance = super().__new__(cls)
+
+        # find config passed in arguments
         quant_config = cls._find_quant_config(*args, **kwargs)
         if quant_config is not None:
+
+            # attach config to model for general use
             instance.quant_config = quant_config
-            instance.quant_config.packed_modules_mapping.update(
-                cls.packed_modules_mapping)
+
+            # apply model mappings to config for proper config-model matching
+            # NOTE: `TransformersForCausalLM` is not supported due to how this
+            # class defines `hf_to_vllm_mapper` as a post-init `@property`.
+            # After this is fixed, get `instance.hf_to_vllm_mapper` directly
+            if getattr(instance, "hf_to_vllm_mapper", None) is not None:
+                instance.quant_config.apply_vllm_mapper(
+                    instance.hf_to_vllm_mapper)
+            if getattr(instance, "packed_modules_mapping", None) is not None:
+                instance.quant_config.packed_modules_mapping.update(
+                    instance.packed_modules_mapping)
+
         return instance
 
     @staticmethod
     def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]:
+        """Find quant config passed through model constructor args"""
        from vllm.config import VllmConfig  # avoid circular import
 
         args_values = list(args) + list(kwargs.values())
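The mixin wires the quantization config to the model at construction time via __new__, before __init__ runs. A minimal self-contained sketch of that pattern follows; SupportsQuantSketch, ToyQuantConfig and ToyModel are illustrative names, not vLLM's classes, and the config lookup is simplified to an isinstance scan.

# Minimal sketch of the __new__-based mixin pattern; names are illustrative.
from typing import Optional


class ToyQuantConfig:
    def __init__(self) -> None:
        self.packed_modules_mapping: dict[str, list[str]] = {}


class SupportsQuantSketch:
    hf_to_vllm_mapper = None
    packed_modules_mapping: Optional[dict[str, list[str]]] = None
    quant_config: Optional[ToyQuantConfig] = None

    def __new__(cls, *args, **kwargs):
        instance = super().__new__(cls)
        # Pull the config out of the constructor arguments, if present.
        quant_config = next(
            (a for a in list(args) + list(kwargs.values())
             if isinstance(a, ToyQuantConfig)), None)
        if quant_config is not None:
            # Attach the config and let the model's mappings update it
            # before any layer matching happens.
            instance.quant_config = quant_config
            if instance.packed_modules_mapping is not None:
                instance.quant_config.packed_modules_mapping.update(
                    instance.packed_modules_mapping)
        return instance


class ToyModel(SupportsQuantSketch):
    packed_modules_mapping = {"gate_up_proj": ["gate_proj", "up_proj"]}

    def __init__(self, quant_config: ToyQuantConfig) -> None:
        # self.quant_config was already attached by __new__.
        assert self.quant_config is quant_config


m = ToyModel(ToyQuantConfig())
print(m.quant_config.packed_modules_mapping)  # {'gate_up_proj': ['gate_proj', 'up_proj']}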
@@ -61,7 +61,7 @@ from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
-                         SupportsMultiModal, SupportsPP)
+                         SupportsMultiModal, SupportsPP, SupportsQuant)
 from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
 from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo,
                        apply_rotary_pos_emb_vision)
@@ -821,7 +821,8 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
     info=Qwen2_5_VLProcessingInfo,
     dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
 class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                         SupportsLoRA, SupportsPP):
+                                         SupportsLoRA, SupportsPP,
+                                         SupportsQuant):
 
     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(
@@ -837,7 +838,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
         multimodal_config = vllm_config.model_config.multimodal_config
 
         self.config = config
@@ -846,7 +846,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.visual = Qwen2_5_VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-            quant_config=self._maybe_ignore_quant_config(quant_config),
+            quant_config=self._maybe_ignore_quant_config(self.quant_config),
             prefix=maybe_prefix(prefix, "visual"),
         )
 
@@ -859,12 +859,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
-    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
+    def _maybe_ignore_quant_config(self, config: Optional[QuantizationConfig]):
         # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
         # seems to avoid vision encoder sections for some models.
-        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
+        if isinstance(config, (GPTQConfig, GPTQMarlinConfig)):
             return None
-        return quant_config
+        return config
 
     def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                         name: str) -> torch.Tensor:
@@ -467,6 +467,7 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
     # FIXME(Isotr0py): Don't use any weights mapper for Transformers backend,
     # this makes thing complicated. We need to remove this mapper after refactor
     # `TransformersModel` in the future.
+    # NOTE: `SupportsQuant` can be updated after property decorator is removed
     @property
     def hf_to_vllm_mapper(self):
         prefix_mapper = {
@@ -4,7 +4,7 @@
 import itertools
 from collections.abc import Iterable, Mapping
 from dataclasses import dataclass, field
-from typing import Callable, Literal, Optional, Protocol, Union, overload
+from typing import Any, Callable, Literal, Optional, Protocol, Union, overload
 
 import torch
 import torch.nn as nn
@@ -64,6 +64,19 @@ class WeightsMapper:
         return ((out_name, data) for name, data in weights
                 if (out_name := self._map_name(name)) is not None)
 
+    def apply_list(self, values: list[str]) -> list[str]:
+        return [
+            out_name for name in values
+            if (out_name := self._map_name(name)) is not None
+        ]
+
+    def apply_dict(self, values: dict[str, Any]) -> dict[str, Any]:
+        return {
+            out_name: value
+            for name, value in values.items()
+            if (out_name := self._map_name(name)) is not None
+        }
+
 
 class AutoWeightsLoader:
     """
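The two new helpers reuse the same name-mapping logic the class already applies to weight tensors, but over plain lists and dicts from a quantization config. Below is a self-contained toy re-implementation showing the intended behavior; the prefix-only _map_name is a simplification of the real mapper's matching rules, and ToyWeightsMapper is an illustrative name.

# Toy re-implementation to show how apply_list/apply_dict behave;
# not vLLM's actual WeightsMapper.
from typing import Any, Optional


class ToyWeightsMapper:
    def __init__(self, orig_to_new_prefix: dict[str, Optional[str]]) -> None:
        self.orig_to_new_prefix = orig_to_new_prefix

    def _map_name(self, name: str) -> Optional[str]:
        for old, new in self.orig_to_new_prefix.items():
            if name.startswith(old):
                # A None target drops the entry entirely.
                return None if new is None else new + name[len(old):]
        return name

    def apply_list(self, values: list[str]) -> list[str]:
        return [
            out_name for name in values
            if (out_name := self._map_name(name)) is not None
        ]

    def apply_dict(self, values: dict[str, Any]) -> dict[str, Any]:
        return {
            out_name: value
            for name, value in values.items()
            if (out_name := self._map_name(name)) is not None
        }


mapper = ToyWeightsMapper({"model.": "language_model.model."})
print(mapper.apply_list(["model.layers.0.mlp", "lm_head"]))
print(mapper.apply_dict({"model.layers.0.self_attn": {"bits": 8}}))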
@@ -58,7 +58,8 @@ def _make_synced_weight_loader(original_weight_loader):
 
 
 def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
-    parent_map = copy.deepcopy(getattr(model, "packed_modules_mapping", {}))
+    parent_map = getattr(model, "packed_modules_mapping", None)
+    parent_map = copy.deepcopy(parent_map) if parent_map is not None else {}
 
     # don't infer mapping if the model has defined it explicitly.
     if parent_map:
@@ -66,7 +67,9 @@ def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
 
     # We only check main components instead of whole model submodules
     for child in model.children():
-        child_map = getattr(child, "packed_modules_mapping", {})
+        child_map = getattr(child, "packed_modules_mapping", None)
+        child_map = copy.deepcopy(child_map) if child_map is not None else {}
+
         if any((k in parent_map and parent_map[k] != v)
                for k, v in child_map.items()):
             raise ValueError(
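The deepcopy matters because packed_modules_mapping is typically declared as a class-level dict, so mutating the returned mapping in place would leak changes into every model sharing that class attribute. A small sketch of the difference, assuming the illustrative class SharedMappingModel:

import copy


class SharedMappingModel:  # illustrative; mimics a class-level mapping attribute
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}


def get_mapping_unsafe(model_cls) -> dict:
    # Returning the class attribute directly lets callers mutate shared state.
    return getattr(model_cls, "packed_modules_mapping", None) or {}


def get_mapping_safe(model_cls) -> dict:
    # Copying first keeps the class attribute untouched, as in the diff above.
    mapping = getattr(model_cls, "packed_modules_mapping", None)
    return copy.deepcopy(mapping) if mapping is not None else {}


m = get_mapping_safe(SharedMappingModel)
m["qkv_proj"].append("extra_proj")
print(SharedMappingModel.packed_modules_mapping["qkv_proj"])  # unchanged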