[Model][Quant] Fix GLM, Fix fused module mappings for quantization (#12634)

Signed-off-by: mgoin <michael@neuralmagic.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
Kyle Sayers 2025-02-05 00:32:06 -05:00 committed by GitHub
parent 686006a220
commit 7ff7a638b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 194 additions and 150 deletions

View File

@ -2,7 +2,7 @@
import inspect import inspect
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type from typing import Any, Dict, List, Mapping, Optional, Type
import torch import torch
from torch import nn from torch import nn
@ -59,6 +59,7 @@ def method_has_implemented_embedding(
class QuantizationConfig(ABC): class QuantizationConfig(ABC):
"""Base class for quantization configs.""" """Base class for quantization configs."""
packed_modules_mapping: Mapping[str, List[str]] = dict()
@abstractmethod @abstractmethod
def get_name(self) -> str: def get_name(self) -> str:

View File

@ -83,7 +83,9 @@ class CompressedTensorsConfig(QuantizationConfig):
# Check if the layer is skipped for quantization. # Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names # TODO (@robertgshaw2): support module names
if should_ignore_layer(prefix, ignore=self.ignore): if should_ignore_layer(prefix,
ignore=self.ignore,
fused_mapping=self.packed_modules_mapping):
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix) scheme = self.get_scheme(layer=layer, layer_name=prefix)
@ -379,34 +381,29 @@ class CompressedTensorsConfig(QuantizationConfig):
# Will be empty for models with only sparsity # Will be empty for models with only sparsity
weight_quant = input_quant = None weight_quant = input_quant = None
sparsity_scheme: Optional[SparsityCompressionConfig] = None
if self.target_scheme_map: if self.target_scheme_map:
matched_target = find_matched_target( matched_target = find_matched_target(
layer_name=layer_name, layer_name=layer_name,
module=layer, module=layer,
targets=self.target_scheme_map.keys()) targets=self.target_scheme_map.keys(),
fused_mapping=self.packed_modules_mapping)
scheme_dict = self.target_scheme_map[matched_target] scheme_dict = self.target_scheme_map[matched_target]
weight_quant = scheme_dict.get("weights") weight_quant = scheme_dict.get("weights")
input_quant = scheme_dict.get("input_activations") input_quant = scheme_dict.get("input_activations")
if self.sparsity_scheme_map: # Find the sparsity scheme of the layer
is_ignored = False # assume that fused layers inerhit first component's sparsity scheme
with suppress(ValueError): sparsity_targets = (self.sparsity_scheme_map.keys() -
is_ignored = find_matched_target( set(self.sparsity_ignore_list))
layer_name=layer_name, sparsity_scheme: Optional[SparsityCompressionConfig] = None
module=layer, with suppress(ValueError):
targets=self.sparsity_ignore_list) matched_target = find_matched_target(
layer_name=layer_name,
# if the layer is in the sparsity ignore list, module=layer,
# we should not apply any sparsity scheme targets=sparsity_targets,
fused_mapping=self.packed_modules_mapping)
if not is_ignored: sparsity_scheme = self.sparsity_scheme_map[matched_target]
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.sparsity_scheme_map.keys())
sparsity_scheme = self.sparsity_scheme_map.get(matched_target)
if self.supports_cutlass_24(weight_quant=weight_quant, if self.supports_cutlass_24(weight_quant=weight_quant,
input_quant=input_quant, input_quant=input_quant,

View File

@ -1,14 +1,12 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import re import re
from typing import Iterable, Optional from types import MappingProxyType
from typing import Iterable, List, Mapping, Optional
from compressed_tensors import CompressionFormat from compressed_tensors import CompressionFormat
from torch.nn import Module from torch.nn import Module
from vllm.model_executor.layers.quantization.utils.quant_utils import (
FUSED_LAYER_NAME_MAPPING)
def is_activation_quantization_format(format: str) -> bool: def is_activation_quantization_format(format: str) -> bool:
_ACTIVATION_QUANTIZATION_FORMATS = [ _ACTIVATION_QUANTIZATION_FORMATS = [
@ -19,8 +17,11 @@ def is_activation_quantization_format(format: str) -> bool:
return format in _ACTIVATION_QUANTIZATION_FORMATS return format in _ACTIVATION_QUANTIZATION_FORMATS
def should_ignore_layer(layer_name: Optional[str], def should_ignore_layer(
ignore: Iterable[str]) -> bool: layer_name: Optional[str],
ignore: Iterable[str] = tuple(),
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
) -> bool:
if layer_name is None: if layer_name is None:
return False return False
@ -32,8 +33,8 @@ def should_ignore_layer(layer_name: Optional[str],
# in the safetensors checkpoint. So, we convert the name # in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that # from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme. # each shard of the fused layer has the same scheme.
if proj_name in FUSED_LAYER_NAME_MAPPING and layer_name not in ignore: if proj_name in fused_mapping and layer_name not in ignore:
shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name] shard_proj_names = fused_mapping[proj_name]
# Convert fused_name --> [shard_names] # Convert fused_name --> [shard_names]
shard_names = [ shard_names = [
@ -79,55 +80,12 @@ def check_equal_or_regex_match(layer_name: str,
return False return False
def _handle_fused_layers(func): def find_matched_target(
""" layer_name: Optional[str],
Decorator to handle fused layers by mapping vllm fused layer names module: Module,
to their corresponding unfused layer names for quantization/pruning schemes. targets: Iterable[str],
""" fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
# fused_layer_name -> unfused_layer_name ) -> str:
fused_layer_map = {
"qkv_proj": "q_proj",
"gate_up_proj": "up_proj",
}
def fused_layer_handler(layer_name: Optional[str], module: Module,
targets: Iterable[str]) -> Optional[str]:
"""
Wrapper function specifically designed to support the
find_matched_target function.
It handles cases where the provided layer name corresponds to a
fused layer in vllm, mapping it to its equivalent unfused layer name
based on the predefined fused_layer_map. If the original layer name
raises a ValueError in the wrapped function, this handler
will attempt to resolve the issue by substituting with unfused
layer name.
:param layer_name: Name of the layer, which may be fused.
:param module: An instance of torch.nn.Module.
:param targets: A list of target names or patterns to match.
:return: The result of the wrapped find_matched_target function with
the resolved layer name.
:raises ValueError: If the layer name cannot be resolved to a
valid target.
"""
try:
return func(layer_name, module, targets)
except ValueError:
if layer_name is None:
layer_name = ""
parent_name, fused_proj_name = layer_name.rsplit(".", 1)
unfused_proj_name = fused_layer_map.get(fused_proj_name,
fused_proj_name)
new_layer_name = f"{parent_name}.{unfused_proj_name}"
return func(new_layer_name, module, targets)
return fused_layer_handler
@_handle_fused_layers
def find_matched_target(layer_name: Optional[str], module: Module,
targets: Iterable[str]) -> str:
""" """
Helper function to look up which "target" in the compressed-tensors Helper function to look up which "target" in the compressed-tensors
config that a layer corresponds to. config that a layer corresponds to.
@ -141,19 +99,25 @@ def find_matched_target(layer_name: Optional[str], module: Module,
First, we try to match the layer_name with a target First, we try to match the layer_name with a target
Second, we try to match the module's name with a target Second, we try to match the module's name with a target
Third, we try to map the layer_name to a list of fused module names.
*All* component module names must match in order for a match to be
successful. A successful match returns the first component target
:param layer_name: layer name :param layer_name: layer name
:param module: torch.nn.Module :param module: torch.nn.Module
:param targets: list of targets to match the layer against :param targets: list of targets to match the layer against
:param fused_mapping: map from fused layer names to its components
:param fused_strategy: either "all" or "any". If using "all", fused
layers match if "all" of its components match
""" """
if layer_name is None: if layer_name is None:
layer_name = "" layer_name = ""
matched_target = (_find_first_match(layer_name, targets) matched_target = (
or _find_first_match(module.__class__.__name__, targets, _find_first_match(layer_name, targets)
True) or _find_first_match(module.__class__.__name__, targets, True)
or _match_fused_layer(layer_name, targets)) or _match_fused_layer(layer_name, targets, fused_mapping))
if matched_target is None: if matched_target is None:
raise ValueError( raise ValueError(
@ -205,11 +169,19 @@ def _is_equal_or_regex_match(value: str,
return False return False
def _match_fused_layer(layer_name: str, def _match_fused_layer(
target_layers: Iterable[str]) -> Optional[str]: layer_name: str, target_layers: Iterable[str],
fused_mapping: Mapping[str, List[str]]) -> Optional[str]:
""" """
Match a fused layer name to its corresponding individual layer in Match a fused layer name to its corresponding individual layer in
target_layers. target_layers. Returns first value in fused_mapping which matches targets
Implements an "all" matching strategy where a fused layer matches iff
"all" of its components match
:param layer_name: layer name
:param target_layers: list of targets to match the layer against
:param fused_mapping: map from fused layer names to its components
Examples: Examples:
layer_name = "model.layers.0.self_attn.qkv_proj" layer_name = "model.layers.0.self_attn.qkv_proj"
@ -217,27 +189,25 @@ def _match_fused_layer(layer_name: str,
"model.layers.0.self_attn.k_proj", "model.layers.0.self_attn.k_proj",
"model.layers.0.self_attn.v_proj"] "model.layers.0.self_attn.v_proj"]
""" """
# Split into parent path and layer type # find layer_name in mapping
# e.g., "model.layers.0.self_attn" and "qkv_proj" fused = next((key for key in fused_mapping if layer_name.endswith(key)),
parent_path = ".".join(layer_name.split(".")[:-1]) None)
layer_type = layer_name.split(".")[-1] if fused is None:
if layer_type not in FUSED_LAYER_NAME_MAPPING:
return None return None
possible_layer_types = FUSED_LAYER_NAME_MAPPING[layer_type] # expand path of unfused components
unfused_paths = [
layer_name.replace(fused, unfused) for unfused in fused_mapping[fused]
]
# Look for a target layer that: # for each unfused component, find a match in targets
# 1. Has the same parent path unfused_matches: List[Optional[str]] = []
# 2. Ends with one of the possible individual layer types for unfused in unfused_paths:
for target in target_layers: for target in target_layers:
is_same_parent = parent_path in target if _is_equal_or_regex_match(unfused, target):
is_matching_type = any(type_suffix in target unfused_matches.append(target)
for type_suffix in possible_layer_types) break
else:
unfused_matches.append(None)
if is_same_parent and is_matching_type and all( return unfused_matches[0] if all(unfused_matches) else None
(f"{parent_path}.{type_suffix}" in target_layers)
for type_suffix in possible_layer_types):
return target
return None

View File

@ -18,8 +18,6 @@ from vllm.model_executor.layers.quantization.quark.schemes import (
QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8) QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8)
from vllm.model_executor.layers.quantization.quark.utils import ( from vllm.model_executor.layers.quantization.quark.utils import (
deep_compare, should_ignore_layer) deep_compare, should_ignore_layer)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
FUSED_LAYER_NAME_MAPPING)
from vllm.platforms import current_platform from vllm.platforms import current_platform
__all__ = ["QuarkLinearMethod"] __all__ = ["QuarkLinearMethod"]
@ -58,7 +56,9 @@ class QuarkConfig(QuantizationConfig):
# Check if the layer is skipped for quantization. # Check if the layer is skipped for quantization.
exclude_layers = cast(List[str], self.quant_config.get("exclude")) exclude_layers = cast(List[str], self.quant_config.get("exclude"))
if should_ignore_layer(prefix, ignore=exclude_layers): if should_ignore_layer(prefix,
ignore=exclude_layers,
fused_mapping=self.packed_modules_mapping):
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix) scheme = self.get_scheme(layer=layer, layer_name=prefix)
@ -201,8 +201,8 @@ class QuarkConfig(QuantizationConfig):
module: torch.nn.Module) -> Dict[str, Any]: module: torch.nn.Module) -> Dict[str, Any]:
proj_name = layer_name.split(".")[-1] proj_name = layer_name.split(".")[-1]
if proj_name in FUSED_LAYER_NAME_MAPPING: if proj_name in self.packed_modules_mapping:
shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name] shard_proj_names = self.packed_modules_mapping[proj_name]
# Convert fused_name --> [shard_names] # Convert fused_name --> [shard_names]
shard_names = [ shard_names = [

View File

@ -1,10 +1,8 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import re import re
from typing import Any, Iterable, Optional from types import MappingProxyType
from typing import Any, Iterable, List, Mapping, Optional
from vllm.model_executor.layers.quantization.utils.quant_utils import (
FUSED_LAYER_NAME_MAPPING)
def deep_compare(dict1: Any, dict2: Any) -> bool: def deep_compare(dict1: Any, dict2: Any) -> bool:
@ -20,8 +18,11 @@ def deep_compare(dict1: Any, dict2: Any) -> bool:
return dict1 == dict2 return dict1 == dict2
def should_ignore_layer(layer_name: Optional[str], def should_ignore_layer(
ignore: Iterable[str]) -> bool: layer_name: Optional[str],
ignore: Iterable[str],
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
) -> bool:
if layer_name is None: if layer_name is None:
return False return False
@ -33,8 +34,8 @@ def should_ignore_layer(layer_name: Optional[str],
# in the safetensors checkpoint. So, we convert the name # in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that # from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme. # each shard of the fused layer has the same scheme.
if proj_name in FUSED_LAYER_NAME_MAPPING: if proj_name in fused_mapping:
shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name] shard_proj_names = fused_mapping[proj_name]
# Convert fused_name --> [shard_names] # Convert fused_name --> [shard_names]
shard_names = [ shard_names = [

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""This file is used for /tests and /benchmarks""" """This file is used for /tests and /benchmarks"""
from typing import List, Optional, Tuple from types import MappingProxyType
from typing import List, Mapping, Optional, Tuple
import numpy import numpy
import torch import torch
@ -12,14 +13,6 @@ from vllm.scalar_type import ScalarType, scalar_types
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
# fused_name: List[shard_name]
FUSED_LAYER_NAME_MAPPING = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}
# Normalize the group_shape to the full extent for any dims that are -1 # Normalize the group_shape to the full extent for any dims that are -1
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: Tuple[int, def _normalize_quant_group_shape(x: torch.Tensor, group_shape: Tuple[int,
@ -178,14 +171,23 @@ def unpack_quantized_values_into_int32(w_q: torch.Tensor,
return res.permute(inv_perm) return res.permute(inv_perm)
def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: def is_layer_skipped(
prefix: str,
ignored_layers: List[str],
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
) -> bool:
# prefix: model.layers.0.self_attn.q_proj # prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj # proj_name: q_proj
proj_name = prefix.split(".")[-1] proj_name = prefix.split(".")[-1]
if proj_name in FUSED_LAYER_NAME_MAPPING:
# Fused layers like gate_up_proj or qkv_proj will not be fused
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
if proj_name in fused_mapping:
shard_prefixes = [ shard_prefixes = [
prefix.replace(proj_name, shard_proj_name) prefix.replace(proj_name, shard_proj_name)
for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name] for shard_proj_name in fused_mapping[proj_name]
] ]
is_skipped = None is_skipped = None

View File

@ -43,6 +43,7 @@ from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, is_vllm_tensorized, load_with_tensorizer, TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
serialize_vllm_model, tensorizer_weights_iterator) serialize_vllm_model, tensorizer_weights_iterator)
from vllm.model_executor.model_loader.utils import (ParamMapping, from vllm.model_executor.model_loader.utils import (ParamMapping,
configure_quant_config,
get_model_architecture, get_model_architecture,
set_default_torch_dtype) set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
@ -113,6 +114,9 @@ def _initialize_model(
model_config = vllm_config.model_config model_config = vllm_config.model_config
model_class, _ = get_model_architecture(model_config) model_class, _ = get_model_architecture(model_config)
if vllm_config.quant_config is not None:
configure_quant_config(vllm_config.quant_config, model_class)
signatures = inspect.signature(model_class.__init__) signatures = inspect.signature(model_class.__init__)
all_params = [param.name for param in signatures.parameters.values()] all_params = [param.name for param in signatures.parameters.values()]
if "vllm_config" in all_params and "prefix" in all_params: if "vllm_config" in all_params and "prefix" in all_params:

View File

@ -11,6 +11,8 @@ from transformers.dynamic_module_utils import get_class_from_dynamic_module
from vllm.config import ModelConfig, ModelImpl from vllm.config import ModelConfig, ModelImpl
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.adapters import (as_classification_model, from vllm.model_executor.models.adapters import (as_classification_model,
as_embedding_model, as_embedding_model,
@ -138,3 +140,23 @@ class ParamMapping:
if module_name.endswith(key): if module_name.endswith(key):
return key, value return key, value
return None return None
def configure_quant_config(quant_config: QuantizationConfig,
model_class: Type[nn.Module]):
"""
Pass packed_modules_mapping by reference to quant_config so that
quant_config can properly match fused modules
Note that model attributes are passed by reference to quant_config,
enabling them to be updated by model_class.__new__ (ex. chatglm, qwen)
"""
packed_mapping = getattr(model_class, "packed_modules_mapping", None)
if packed_mapping is not None:
# pass packed_modules_mapping by reference to quant_config
quant_config.packed_modules_mapping = packed_mapping
else:
logger.warning(
"The model class %s has not defined `packed_modules_mapping`, "
"this may lead to incorrect mapping of quantized or ignored "
"modules", model_class.__name__)

View File

@ -265,12 +265,14 @@ class GLMAttention(nn.Module):
self.total_num_kv_heads, self.total_num_kv_heads,
bias=config.add_bias_linear or config.add_qkv_bias, bias=config.add_bias_linear or config.add_qkv_bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.query_key_value",
) )
self.dense = RowParallelLinear( self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
config.hidden_size, config.hidden_size,
bias=config.add_bias_linear, bias=config.add_bias_linear,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.dense",
) )
# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141 # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
@ -327,6 +329,7 @@ class GLMMLP(nn.Module):
self, self,
config: ChatGLMConfig, config: ChatGLMConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
@ -338,6 +341,7 @@ class GLMMLP(nn.Module):
[config.ffn_hidden_size] * 2, [config.ffn_hidden_size] * 2,
bias=config.add_bias_linear, bias=config.add_bias_linear,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.dense_h_to_4h",
) )
self.activation_func = SiluAndMul() self.activation_func = SiluAndMul()
@ -348,6 +352,7 @@ class GLMMLP(nn.Module):
config.hidden_size, config.hidden_size,
bias=config.add_bias_linear, bias=config.add_bias_linear,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.dense_4h_to_h",
) )
def forward(self, hidden_states): def forward(self, hidden_states):
@ -396,7 +401,7 @@ class GLMBlock(nn.Module):
config.hidden_size, eps=config.layernorm_epsilon) config.hidden_size, eps=config.layernorm_epsilon)
# MLP # MLP
self.mlp = GLMMLP(config, quant_config) self.mlp = GLMMLP(config, quant_config, prefix=f"{prefix}.mlp")
def forward( def forward(
self, self,
@ -507,7 +512,8 @@ class ChatGLMModel(nn.Module):
self.embedding = VocabParallelEmbedding(config.padded_vocab_size, self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.embedding")
self.num_layers = config.num_layers self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num self.multi_query_group_num = config.multi_query_group_num
@ -766,6 +772,7 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
SupportsMultiModal): SupportsMultiModal):
# Ensure that the LoRA support check passes when the class is not # Ensure that the LoRA support check passes when the class is not
# initialized, but set all these attributes to empty. # initialized, but set all these attributes to empty.
# These will be updated when an instance class is selected
packed_modules_mapping = {} packed_modules_mapping = {}
supported_lora_modules = [] supported_lora_modules = []
embedding_modules = {} embedding_modules = {}
@ -777,9 +784,18 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
# Initialize VL # Initialize VL
if hasattr(config, "vision_config"): if hasattr(config, "vision_config"): # noqa: SIM108
return ChatGLMV(vllm_config=vllm_config, prefix=prefix) instance_cls = ChatGLMV
# Initialize LLM # Initialize LLM
else: else:
return ChatGLM(vllm_config=vllm_config, prefix=prefix) instance_cls = ChatGLM
# quant_config references base class members,
# so update values before init is called
cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
cls.supported_lora_modules += instance_cls.supported_lora_modules
cls.embedding_modules.update(instance_cls.embedding_modules)
cls.embedding_padding_modules += instance_cls.embedding_padding_modules
return instance_cls(vllm_config=vllm_config, prefix=prefix)

View File

@ -74,11 +74,13 @@ class Attention(nn.Module):
self.head_dim, self.head_dim,
config.num_heads, config.num_heads,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.query_key_value",
) )
self.dense = RowParallelLinear( self.dense = RowParallelLinear(
config.hidden_size, config.hidden_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.dense",
) )
self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim, self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim,
@ -101,6 +103,7 @@ class MLP(nn.Module):
self, self,
config, config,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
): ):
super().__init__() super().__init__()
self.config = config self.config = config
@ -109,11 +112,13 @@ class MLP(nn.Module):
config.hidden_size, config.hidden_size,
config.intermediate_size, config.intermediate_size,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.fc1",
) )
self.fc2 = RowParallelLinear( self.fc2 = RowParallelLinear(
config.intermediate_size, config.intermediate_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.fc2",
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -137,7 +142,9 @@ class TransformerLayer(nn.Module):
self.attention = Attention(config, self.attention = Attention(config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.attention") prefix=f"{prefix}.attention")
self.mlp = MLP(config, quant_config=quant_config) self.mlp = MLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.post_attention_layernorm = LayerNorm(config.hidden_size, self.post_attention_layernorm = LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
@ -164,7 +171,7 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
TransformerLayer(config, TransformerLayer(config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.layer.{layer_idx}") prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(config.num_hidden_layers) for layer_idx in range(config.num_hidden_layers)
]) ])
@ -181,6 +188,7 @@ class GLU(nn.Module):
config, config,
in_features, in_features,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
): ):
""" """
The original implementation is the same as: The original implementation is the same as:
@ -222,7 +230,8 @@ class GLU(nn.Module):
self.linear_proj = ReplicatedLinear(in_features, self.linear_proj = ReplicatedLinear(in_features,
config.hidden_size, config.hidden_size,
bias=False, bias=False,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.linear_proj")
self.norm1 = nn.LayerNorm(config.hidden_size) self.norm1 = nn.LayerNorm(config.hidden_size)
self.act1 = nn.GELU() self.act1 = nn.GELU()
self.act2 = SiluAndMul() self.act2 = SiluAndMul()
@ -230,12 +239,15 @@ class GLU(nn.Module):
self.merged_proj = MergedColumnParallelLinear( self.merged_proj = MergedColumnParallelLinear(
config.hidden_size, [config.ffn_hidden_size] * 2, config.hidden_size, [config.ffn_hidden_size] * 2,
bias=False, bias=False,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.merged_proj")
self.dense_4h_to_h = RowParallelLinear(config.ffn_hidden_size, self.dense_4h_to_h = RowParallelLinear(
config.hidden_size, config.ffn_hidden_size,
bias=False, config.hidden_size,
quant_config=quant_config) bias=False,
quant_config=quant_config,
prefix=f"{prefix}.dense_4h_to_h")
def forward(self, x): def forward(self, x):
x, _ = self.linear_proj(x) x, _ = self.linear_proj(x)
@ -262,7 +274,8 @@ class EVA2CLIPModel(nn.Module):
prefix=f"{prefix}.transformer") prefix=f"{prefix}.transformer")
self.linear_proj = GLU(config, self.linear_proj = GLU(config,
in_features=config.hidden_size, in_features=config.hidden_size,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.linear_proj")
self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, self.conv = nn.Conv2d(in_channels=vision_config.hidden_size,
out_channels=config.hidden_size, out_channels=config.hidden_size,
kernel_size=2, kernel_size=2,

View File

@ -1473,6 +1473,7 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA):
""" """
# Ensure that the LoRA support check passes when the class is not # Ensure that the LoRA support check passes when the class is not
# initialized, but set all these attributes to empty. # initialized, but set all these attributes to empty.
# These will be updated when an instance class is selected
packed_modules_mapping = {} packed_modules_mapping = {}
supported_lora_modules = [] supported_lora_modules = []
embedding_modules = {} embedding_modules = {}
@ -1489,8 +1490,15 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA):
version = str(config.version).split(".") version = str(config.version).split(".")
version = tuple([int(x) for x in version]) version = tuple([int(x) for x in version])
# Dispatch class based on version # Dispatch class based on version
instance_class = _SUPPORT_VERSION.get(version) instance_cls = _SUPPORT_VERSION.get(version)
if instance_class is None: if instance_cls is None:
raise ValueError( raise ValueError(
"Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6") "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
return instance_class(vllm_config=vllm_config, prefix=prefix)
# quant_config references base class members,
# so update values before init is called
cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
cls.supported_lora_modules += instance_cls.supported_lora_modules
cls.embedding_modules.update(instance_cls.embedding_modules)
cls.embedding_padding_modules += instance_cls.embedding_padding_modules
return instance_cls(vllm_config=vllm_config, prefix=prefix)

View File

@ -1135,6 +1135,7 @@ class QWenLMHeadModel(QWenBaseModel, SupportsMultiModal, SupportsLoRA):
""" """
# Ensure that the LoRA support check passes when the class is not # Ensure that the LoRA support check passes when the class is not
# initialized, but set all these attributes to empty. # initialized, but set all these attributes to empty.
# These will be updated when an instance class is selected
packed_modules_mapping = {} packed_modules_mapping = {}
supported_lora_modules = [] supported_lora_modules = []
embedding_modules = {} embedding_modules = {}
@ -1146,9 +1147,18 @@ class QWenLMHeadModel(QWenBaseModel, SupportsMultiModal, SupportsLoRA):
prefix: str = "", prefix: str = "",
) -> QWenBaseModel: ) -> QWenBaseModel:
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
# Initialize VL # Initialize VL
if hasattr(config, "visual"): if hasattr(config, "visual"): # noqa: SIM108
return QWenVL(vllm_config=vllm_config, prefix=prefix) instance_cls = QWenVL
# Initialize LLM # Initialize LLM
else: else:
return QWenLLM(vllm_config=vllm_config, prefix=prefix) instance_cls = QWenLLM
# quant_config references base class members,
# so update values before init is called
cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
cls.supported_lora_modules += instance_cls.supported_lora_modules
cls.embedding_modules.update(instance_cls.embedding_modules)
cls.embedding_padding_modules += instance_cls.embedding_padding_modules
return instance_cls(vllm_config=vllm_config, prefix=prefix)