[Quantization] Add field to skip unquantized modules for GPTQ config (#25455)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent 129a643b4c
commit d70c154975
@@ -270,6 +270,7 @@ class VllmConfig:
                     f"{model_config.dtype} is not supported for quantization "
                     f"method {model_config.quantization}. Supported dtypes: "
                     f"{supported_dtypes}")
+            quant_config.maybe_update_config(model_config.model)
             return quant_config
         return None
@@ -162,3 +162,9 @@ class QuantizationConfig(ABC):
         """
         # TODO (@kylesayrs): add implementations for all subclasses
         pass
+
+    def maybe_update_config(self, model_name: str):  # noqa: B027
+        """
+        Interface to update values after config initialization.
+        """
+        pass
@@ -7,6 +7,7 @@ from fractions import Fraction
 from typing import Any, Optional, Union
 
 import torch
+from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
 from torch.nn.parameter import Parameter
 
 from vllm import _custom_ops as ops
@@ -22,6 +23,8 @@ from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            PackedColumnParameter,
                                            PackedvLLMParameter,
                                            RowvLLMParameter)
+from vllm.transformers_utils.config import get_safetensors_params_metadata
+from vllm.utils import is_list_of
 
 
 class GPTQConfig(QuantizationConfig):
@@ -38,6 +41,7 @@ class GPTQConfig(QuantizationConfig):
         lm_head_quantized: bool,
         dynamic: dict[str, dict[str, Union[int, bool]]],
         autoround_version: str = "",
+        modules_in_block_to_quantize: Optional[list[str]] = None,
     ) -> None:
         # GPTQModel uses the `dynamic` config property to allow a per-module
         # quantization config, so each module can be individually optimized.
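For context (not part of this diff): the `dynamic` field follows GPTQModel's per-module override convention, where each key is a regex over layer names ("+:" or no prefix for a positive match, "-:" to skip a module) and the value overrides fields of the base quantization config. A minimal sketch under that assumption; the exact patterns and keys come from the checkpoint's quantize_config.json:

# Illustrative only; real entries depend on how the checkpoint was produced.
dynamic = {
    # positive match: quantize layer-1 projections at 8 bits
    r"+:model\.layers\.1\..*": {"bits": 8},
    # negative match: skip quantization for lm_head entirely
    r"-:lm_head": {},
}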
@@ -75,15 +79,20 @@ class GPTQConfig(QuantizationConfig):
                 "Currently, only 2/3/4/8-bit weight quantization is "
                 f"supported for GPTQ, but got {self.weight_bits} bits.")
 
+        self.modules_in_block_to_quantize = modules_in_block_to_quantize or []
+
         # used to identify GPTQ model quantized by autoround
         self.autoround_version = autoround_version
 
     def __repr__(self) -> str:
-        return (f"GPTQConfig(weight_bits={self.weight_bits}, "
-                f"group_size={self.group_size}, "
-                f"desc_act={self.desc_act}), "
-                f"lm_head_quantized={self.lm_head_quantized}), "
-                f"dynamic={self.dynamic}")
+        return (
+            f"GPTQConfig(weight_bits={self.weight_bits}, "
+            f"group_size={self.group_size}, "
+            f"desc_act={self.desc_act}), "
+            f"lm_head_quantized={self.lm_head_quantized}, "
+            f"dynamic={self.dynamic}, "
+            f"modules_in_block_to_quantize={self.modules_in_block_to_quantize})"
+        )
 
     @classmethod
     def get_name(cls) -> QuantizationMethods:
@@ -114,8 +123,10 @@ class GPTQConfig(QuantizationConfig):
                                                  default=False)
         autoround_version = cls.get_from_keys_or(config, ["autoround_version"],
                                                  default="")
+        modules_in_block_to_quantize = cls.get_from_keys_or(
+            config, ["modules_in_block_to_quantize"], default=None)
         return cls(weight_bits, group_size, desc_act, lm_head_quantized,
-                   dynamic, autoround_version)
+                   dynamic, autoround_version, modules_in_block_to_quantize)
 
     def get_quant_method(
             self, layer: torch.nn.Module, prefix: str
@@ -136,6 +147,35 @@ class GPTQConfig(QuantizationConfig):
 
         return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
 
+    def apply_vllm_mapper(self, hf_to_vllm_mapper):
+        if self.modules_in_block_to_quantize is not None:
+            self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list(
+                self.modules_in_block_to_quantize)
+
+    def maybe_update_config(self,
+                            model_name: str,
+                            revision: Optional[str] = None):
+        if self.modules_in_block_to_quantize:
+            if is_list_of(self.modules_in_block_to_quantize, list):
+                # original modules_in_block_to_quantize: list[list[str]]
+                # flatten original modules_in_block_to_quantize
+                self.modules_in_block_to_quantize = [
+                    item for sublist in self.modules_in_block_to_quantize
+                    for item in sublist
+                ]
+            return
+
+        unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+        metadata = get_safetensors_params_metadata(model_name,
+                                                   revision=revision)
+        quant_layers: set[str] = {
+            param_name.rsplit(".", 1)[0]
+            for param_name, info in metadata.items()
+            if (dtype := info.get('dtype', None))
+            and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
+        }
+        self.modules_in_block_to_quantize = list(quant_layers)
+
 
 class ExllamaState(Enum):
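A minimal self-contained sketch (not from the diff) of what `maybe_update_config` derives when the checkpoint config does not provide `modules_in_block_to_quantize`: parameter names whose safetensors dtype is not a plain float dtype are treated as belonging to quantized modules. The metadata dict below is hypothetical; real entries come from `get_safetensors_params_metadata`.

import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE

# Hypothetical safetensors metadata: packed int32 qweights indicate quantized
# modules, while float16 weights indicate unquantized ones.
metadata = {
    "model.layers.0.self_attn.q_proj.qweight": {"dtype": "I32"},
    "model.layers.0.self_attn.q_proj.scales": {"dtype": "F16"},
    "model.layers.0.mlp.gate.weight": {"dtype": "F16"},
}

unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
quant_layers = {
    name.rsplit(".", 1)[0]
    for name, info in metadata.items()
    if (dtype := info.get("dtype"))
    and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
}
# quant_layers == {"model.layers.0.self_attn.q_proj"}; the float-only gate
# module is excluded, so it will later fall back to an unquantized method.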
@@ -5,6 +5,7 @@ from copy import deepcopy
 from typing import Any, Callable, Optional, Union
 
 import torch
+from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
 
 import vllm.model_executor.layers.fused_moe  # noqa
 from vllm import _custom_ops as ops
@@ -35,6 +36,8 @@ from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            RowvLLMParameter)
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
+from vllm.transformers_utils.config import get_safetensors_params_metadata
+from vllm.utils import is_list_of
 
 logger = init_logger(__name__)
 
@@ -71,10 +74,16 @@ class GPTQMarlinConfig(QuantizationConfig):
         (8, True): scalar_types.uint8b128,
     }
 
-    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
-                 is_sym: bool, lm_head_quantized: bool,
-                 dynamic: dict[str, dict[str, Union[int, bool]]],
-                 full_config: dict[str, Any]) -> None:
+    def __init__(
+            self,
+            weight_bits: int,
+            group_size: int,
+            desc_act: bool,
+            is_sym: bool,
+            lm_head_quantized: bool,
+            dynamic: dict[str, dict[str, Union[int, bool]]],
+            full_config: dict[str, Any],
+            modules_in_block_to_quantize: Optional[list[str]] = None) -> None:
         super().__init__()
         if desc_act and group_size == -1:
             # In this case, act_order == True is the same as act_order == False
@@ -121,15 +130,19 @@ class GPTQMarlinConfig(QuantizationConfig):
 
         self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
 
+        self.modules_in_block_to_quantize = modules_in_block_to_quantize or []
         # used to identify GPTQ model quantized by autoround
         self.autoround_version = full_config.get("autoround_version", "")
 
     def __repr__(self) -> str:
-        return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
-                f"group_size={self.group_size}, "
-                f"desc_act={self.desc_act}, "
-                f"lm_head_quantized={self.lm_head_quantized}), "
-                f"dynamic={self.dynamic}")
+        return (
+            f"GPTQMarlinConfig(quant_type={self.quant_type}, "
+            f"group_size={self.group_size}, "
+            f"desc_act={self.desc_act}, "
+            f"lm_head_quantized={self.lm_head_quantized}, "
+            f"dynamic={self.dynamic}, "
+            f"modules_in_block_to_quantize={self.modules_in_block_to_quantize})"
+        )
 
     @classmethod
     def get_name(cls) -> QuantizationMethods:
@@ -158,8 +171,11 @@ class GPTQMarlinConfig(QuantizationConfig):
         is_sym = cls.get_from_keys(config, ["sym"])
         lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                  default=False)
+        modules_in_block_to_quantize = cls.get_from_keys_or(
+            config, ["modules_in_block_to_quantize"], default=None)
         return cls(weight_bits, group_size, desc_act, is_sym,
-                   lm_head_quantized, dynamic, config)
+                   lm_head_quantized, dynamic, config,
+                   modules_in_block_to_quantize)
 
     @classmethod
     def override_quantization_method(
@@ -223,6 +239,35 @@ class GPTQMarlinConfig(QuantizationConfig):
         return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
                                       group_size=group_size)
 
+    def apply_vllm_mapper(self, hf_to_vllm_mapper):
+        if self.modules_in_block_to_quantize is not None:
+            self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list(
+                self.modules_in_block_to_quantize)
+
+    def maybe_update_config(self,
+                            model_name: str,
+                            revision: Optional[str] = None):
+        if self.modules_in_block_to_quantize:
+            if is_list_of(self.modules_in_block_to_quantize, list):
+                # original modules_in_block_to_quantize: list[list[str]]
+                # flatten original modules_in_block_to_quantize
+                self.modules_in_block_to_quantize = [
+                    item for sublist in self.modules_in_block_to_quantize
+                    for item in sublist
+                ]
+            return
+
+        unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+        metadata = get_safetensors_params_metadata(model_name,
+                                                   revision=revision)
+        quant_layers: set[str] = {
+            param_name.rsplit(".", 1)[0]
+            for param_name, info in metadata.items()
+            if (dtype := info.get('dtype', None))
+            and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
+        }
+        self.modules_in_block_to_quantize = list(quant_layers)
+
 
 class GPTQMarlinLinearMethod(LinearMethodBase):
     """Linear method for GPTQ Marlin.
@@ -1,7 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Mapping
 from copy import deepcopy
 from fractions import Fraction
+from types import MappingProxyType
 from typing import Optional, Union
 
 import regex as re
@@ -70,6 +72,49 @@ def get_dynamic_override(
     return default_value
 
 
+def is_layer_gptq_quantized(
+        prefix: str,
+        quantized_layers: list[str],
+        fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
+) -> bool:
+    # prefix: model.layers.0.self_attn.q_proj
+    # proj_name: q_proj
+
+    # GPTQ's `modules_in_block_to_quantize`:
+    # Substr: ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"]
+    # Full prefix: ["model.layers.0.self_attn.q_proj"]
+
+    proj_name = prefix.split(".")[-1]
+
+    # Fused layers like gate_up_proj or qkv_proj will not be fused
+    # in the safetensors checkpoint. So, we convert the name
+    # from the fused version to unfused + check to make sure that
+    # each shard of the fused layer has the same scheme.
+    if proj_name in fused_mapping:
+        shard_prefixes = [
+            prefix.replace(proj_name, shard_proj_name)
+            for shard_proj_name in fused_mapping[proj_name]
+        ]
+
+        is_quantized = None
+        for shard_prefix in shard_prefixes:
+            is_shard_quantized = any(layer in shard_prefix
+                                     for layer in quantized_layers)
+
+            if is_quantized is None:
+                is_quantized = is_shard_quantized
+            elif is_shard_quantized != is_quantized:
+                raise ValueError(
+                    f"Detected some but not all shards of {prefix} "
+                    "are quantized. All shards of fused layers "
+                    "must have the same precision.")
+    else:
+        is_quantized = any(layer in prefix for layer in quantized_layers)
+
+    assert is_quantized is not None
+    return is_quantized
+
+
 def get_linear_quant_method(
     config: QuantizationConfig,
     layer: torch.nn.Module,
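A usage sketch (not part of the diff) showing how the fused-layer handling above resolves a packed module name; it assumes the is_layer_gptq_quantized function added above is in scope, and the prefixes and mapping are illustrative:

quantized_layers = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
fused_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

# qkv_proj never appears in the checkpoint; it is expanded to its shards
# (q_proj/k_proj/v_proj), and every shard must match the same way.
is_layer_gptq_quantized("model.layers.0.self_attn.qkv_proj",
                        quantized_layers, fused_mapping)   # -> True

# A plain module is matched by substring against the quantized list.
is_layer_gptq_quantized("model.layers.0.mlp.gate",
                        quantized_layers, fused_mapping)   # -> False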
@@ -80,10 +125,15 @@ def get_linear_quant_method(
     parallel_lm_head_quantized = isinstance(
         layer, ParallelLMHead) and cloned_config.lm_head_quantized
     if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
+        is_layer_quantized = is_layer_gptq_quantized(
+            prefix=prefix,
+            quantized_layers=cloned_config.modules_in_block_to_quantize,
+            fused_mapping=cloned_config.packed_modules_mapping)
         # False = skip module, None = no override, else = Positive match
         if get_dynamic_override(  # noqa: E712
                 cloned_config,  # noqa: E712
-                layer_name=prefix) == False:  # noqa: E712
+                layer_name=prefix) == False or (
+                    not is_layer_quantized):  # noqa: E712
             if parallel_lm_head_quantized:
                 return UnquantizedEmbeddingMethod()
             return UnquantizedLinearMethod()
@@ -25,9 +25,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.gptq_marlin import (
-    GPTQMarlinConfig)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -1281,11 +1278,6 @@ class BaseKeyeModule(nn.Module):
 
         raise ValueError("Only image or video modality is supported")
 
-    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
-        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
-            return None
-        return quant_config
-
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config: PretrainedConfig = vllm_config.model_config.hf_config
@@ -1297,14 +1289,14 @@ class BaseKeyeModule(nn.Module):
 
         self.visual = KeyeSiglipVisionModel(
             config.vision_config,
-            quant_config=self._maybe_ignore_quant_config(quant_config),
+            quant_config=quant_config,
             prefix=maybe_prefix(prefix, "visual"),
         )
 
         self.mlp_AR = self._build_projector(
             config,
             config.vision_config,
-            quant_config=self._maybe_ignore_quant_config(quant_config),
+            quant_config=quant_config,
             prefix=maybe_prefix(prefix, "mlp_AR"),
         )
 
@@ -28,7 +28,7 @@ from typing import Annotated, Any, Callable, Literal, Optional, Union
 
 import torch
 from torch import nn
-from transformers import BatchFeature, PretrainedConfig
+from transformers import BatchFeature
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.whisper.modeling_whisper import (ACT2FN,
                                                           WhisperAttention,
@@ -36,10 +36,6 @@ from transformers.models.whisper.modeling_whisper import (ACT2FN,
                                                           WhisperEncoder)
 
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.gptq_marlin import (
-    GPTQMarlinConfig)
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     NestedTensors)
@@ -548,36 +544,6 @@ class MiniCPMO(MiniCPMV2_6):
 
         self.audio_token_id = None
 
-    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
-        # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
-        # seems to avoid vision encoder sections for some models.
-        # See: https://huggingface.co/openbmb/MiniCPM-o-2_6-int4
-        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
-            return None
-        return quant_config
-
-    def init_vision_module(
-        self,
-        config: PretrainedConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> nn.Module:
-        # MiniCPMO GPTQ model leave vpm unquantized.
-        quant_config = self._maybe_ignore_quant_config(quant_config)
-        return super().init_vision_module(config, quant_config, prefix)
-
-    def init_resampler(
-        self,
-        embed_dim: int,
-        vision_dim: int,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> nn.Module:
-        # MiniCPMO GPTQ model leave resampler unquantized.
-        quant_config = self._maybe_ignore_quant_config(quant_config)
-        return super().init_resampler(embed_dim, vision_dim, quant_config,
-                                      prefix)
-
     def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # Do not use parameters temporarily
         audio_config = self.config.audio_config
@@ -31,9 +31,6 @@ from vllm.config import VllmConfig
 from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.gptq_marlin import (
-    GPTQMarlinConfig)
 from vllm.model_executor.models.aimv2 import AIMv2Model
 from vllm.model_executor.models.siglip import SiglipVisionModel
 from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn,
@@ -416,7 +413,7 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
 
         self.visual_tokenizer = VisualTokenizer(
             config=config.visual_tokenizer_config,
-            quant_config=self._maybe_ignore_quant_config(quant_config),
+            quant_config=quant_config,
             prefix=f"{prefix}.visual_tokenizer",
         )
 
@@ -430,14 +427,6 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
         self.make_empty_intermediate_tensors = (
             self.get_language_model().make_empty_intermediate_tensors)
 
-    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
-        # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
-        # seems to avoid vision encoder sections for some models.
-        # See: https://huggingface.co/AIDC-AI/Ovis2-2B-GPTQ-Int4
-        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
-            return None
-        return quant_config
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[OvisImagePatchInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -52,9 +52,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 # yapf: enable
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.gptq_marlin import (
-    GPTQMarlinConfig)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -1015,8 +1012,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.visual = Qwen2_5_VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-            quant_config=self._maybe_ignore_quant_config(
-                self.quant_config),
+            quant_config=self.quant_config,
             prefix=maybe_prefix(prefix, "visual"),
             use_data_parallel=self.use_data_parallel,
         )
@@ -1032,13 +1028,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
-    def _maybe_ignore_quant_config(self, config: Optional[QuantizationConfig]):
-        # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
-        # seems to avoid vision encoder sections for some models.
-        if isinstance(config, (GPTQConfig, GPTQMarlinConfig)):
-            return None
-        return config
-
     def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                         name: str) -> torch.Tensor:
         if not isinstance(mm_input, (torch.Tensor, list)):
@@ -50,9 +50,6 @@ from vllm.model_executor.layers.activation import QuickGELU
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.gptq_marlin import (
-    GPTQMarlinConfig)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -1270,7 +1267,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.visual = Qwen2VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-            quant_config=self._maybe_ignore_quant_config(quant_config),
+            quant_config=quant_config,
             prefix=maybe_prefix(prefix, "visual"),
             use_data_parallel=self.use_data_parallel,
         )
@@ -1286,14 +1283,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
-    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
-        # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
-        # seems to avoid vision encoder sections for some models.
-        # See: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4
-        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
-            return None
-        return quant_config
-
     def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                         name: str) -> torch.Tensor:
         if not isinstance(mm_input, (torch.Tensor, list)):
@@ -46,9 +46,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.gptq_marlin import (
-    GPTQMarlinConfig)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
@@ -149,24 +146,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.n_redundant_experts)
 
-        self.gate = ReplicatedLinear(
-            config.hidden_size,
-            config.num_experts,
-            bias=False,
-            quant_config=self._maybe_ignore_quant_config(quant_config),
-            prefix=f"{prefix}.gate")
-
-    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
-        # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
-        # seems to avoid gate quantization while AutoRound does.
-        # See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4,
-        # and https://huggingface.co/jart25/Qwen3-Coder-30B-A3B-Instruct-Int4-gptq
-        if isinstance(
-                quant_config,
-                (GPTQConfig,
-                 GPTQMarlinConfig)) and not quant_config.autoround_version:
-            return None
-        return quant_config
+        self.gate = ReplicatedLinear(config.hidden_size,
+                                     config.num_experts,
+                                     bias=False,
+                                     quant_config=quant_config,
+                                     prefix=f"{prefix}.gate")
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         assert hidden_states.dim(
@@ -699,4 +683,4 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
         return loader.load_weights(weights)
 
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
-        return self.model.get_expert_mapping()
+        return self.model.get_expert_mapping()
@@ -41,9 +41,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.gptq_marlin import (
-    GPTQMarlinConfig)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@@ -119,12 +116,11 @@ class Qwen3NextSparseMoeBlock(nn.Module):
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.n_redundant_experts)
 
-        self.gate = ReplicatedLinear(
-            config.hidden_size,
-            config.num_experts,
-            bias=False,
-            quant_config=self._maybe_ignore_quant_config(quant_config),
-            prefix=f"{prefix}.gate")
+        self.gate = ReplicatedLinear(config.hidden_size,
+                                     config.num_experts,
+                                     bias=False,
+                                     quant_config=quant_config,
+                                     prefix=f"{prefix}.gate")
 
         if config.shared_expert_intermediate_size > 0:
             self.shared_expert = Qwen3NextMLP(
@@ -142,16 +138,6 @@ class Qwen3NextSparseMoeBlock(nn.Module):
                 1,
                 bias=False)
 
-    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
-        # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
-        # seems to avoid gate quantization while AutoRound does.
-        if isinstance(
-                quant_config,
-                (GPTQConfig,
-                 GPTQMarlinConfig)) and not quant_config.autoround_version:
-            return None
-        return quant_config
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # NOTE: hidden_states can have either 1D or 2D shape.
         orig_shape = hidden_states.shape
@@ -50,9 +50,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.gptq_marlin import (
-    GPTQMarlinConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -1058,7 +1055,7 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.visual = Qwen3_VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-            quant_config=self._maybe_ignore_quant_config(quant_config),
+            quant_config=quant_config,
             prefix=maybe_prefix(prefix, "visual"),
             use_data_parallel=self.use_data_parallel,
         )
@@ -1116,13 +1113,6 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         for idx in range(self.deepstack_num_level):
             self.deepstack_input_embeds[idx][:num_tokens].zero_()
 
-    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
-        # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
-        # seems to avoid vision encoder sections for some models.
-        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
-            return None
-        return quant_config
-
     def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                         name: str) -> torch.Tensor:
         if not isinstance(mm_input, (torch.Tensor, list)):
@@ -322,7 +322,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
         self.visual = Qwen3_VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-            quant_config=self._maybe_ignore_quant_config(quant_config),
+            quant_config=quant_config,
             prefix=maybe_prefix(prefix, "visual"),
             use_data_parallel=self.use_data_parallel,
         )
@@ -4,6 +4,7 @@
 import json
 import os
 import time
+from dataclasses import asdict
 from functools import cache, partial
 from pathlib import Path
 from typing import Any, Callable, Literal, Optional, TypeVar, Union
@@ -27,7 +28,8 @@ from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME
 from vllm import envs
 from vllm.logger import init_logger
 from vllm.transformers_utils.config_parser_base import ConfigParserBase
-from vllm.transformers_utils.utils import check_gguf_file
+from vllm.transformers_utils.utils import (check_gguf_file,
+                                           parse_safetensors_file_metadata)
 
 if envs.VLLM_USE_MODELSCOPE:
     from modelscope import AutoConfig
@@ -999,6 +1001,34 @@ def try_get_tokenizer_config(
         return None
 
 
+def get_safetensors_params_metadata(
+    model: str,
+    *,
+    revision: Optional[str] = None,
+) -> dict[str, Any]:
+    """
+    Get the safetensors parameter metadata for a local or remote model
+    repository.
+    """
+    full_metadata = {}
+    if (model_path := Path(model)).exists():
+        safetensors_to_check = model_path.glob("*.safetensors")
+        full_metadata = {
+            param_name: info
+            for file_path in safetensors_to_check if file_path.is_file()
+            for param_name, info in parse_safetensors_file_metadata(
+                file_path).items()
+        }
+    else:
+        repo_mt = try_get_safetensors_metadata(model, revision=revision)
+        if repo_mt and (files_mt := repo_mt.files_metadata):
+            full_metadata = {
+                param_name: asdict(info)
+                for file_mt in files_mt.values()
+                for param_name, info in file_mt.tensors.items()
+            }
+    return full_metadata
+
+
 def _download_mistral_config_file(model, revision) -> dict:
     config_file_name = "params.json"
     config_dict = get_hf_file_to_dict(config_file_name, model, revision)
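For reference (an assumption based on the safetensors header format, not shown in this diff), each entry of the returned dict describes one tensor; the model name, shapes, and offsets below are hypothetical:

metadata = get_safetensors_params_metadata("my-org/my-gptq-model")
# metadata["model.layers.0.self_attn.q_proj.qweight"] might look like
#     {"dtype": "I32", "shape": [512, 4096], "data_offsets": [0, 8388608]}
# The GPTQ configs above only read the "dtype" field to decide whether a
# module counts as quantized.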
@@ -2,10 +2,11 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import json
+import struct
 from functools import cache
 from os import PathLike
 from pathlib import Path
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 from vllm.envs import VLLM_MODEL_REDIRECT_PATH
 from vllm.logger import init_logger
@@ -97,3 +98,11 @@ def maybe_model_redirect(model: str) -> str:
         return redirect_model
 
     return model
+
+
+def parse_safetensors_file_metadata(
+        path: Union[str, PathLike]) -> dict[str, Any]:
+    with open(path, "rb") as f:
+        length_of_metadata = struct.unpack('<Q', f.read(8))[0]
+        metadata = json.loads(f.read(length_of_metadata).decode('utf-8'))
+        return metadata
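parse_safetensors_file_metadata relies on the safetensors file layout: the first 8 bytes are a little-endian unsigned 64-bit header length, followed by that many bytes of JSON mapping tensor names to their dtype, shape, and data offsets (plus an optional "__metadata__" key). A quick local usage sketch, with a hypothetical file name:

info = parse_safetensors_file_metadata("model-00001-of-00002.safetensors")
# e.g. info["lm_head.weight"] ->
#     {"dtype": "F16", "shape": [...], "data_offsets": [...]}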