[Model][Quant] Fix GLM, Fix fused module mappings for quantization (#12634)
Signed-off-by: mgoin <michael@neuralmagic.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
parent 686006a220
commit 7ff7a638b6
@@ -2,7 +2,7 @@
 import inspect
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Type
+from typing import Any, Dict, List, Mapping, Optional, Type
 
 import torch
 from torch import nn
@@ -59,6 +59,7 @@ def method_has_implemented_embedding(
 
 class QuantizationConfig(ABC):
     """Base class for quantization configs."""
+    packed_modules_mapping: Mapping[str, List[str]] = dict()
 
     @abstractmethod
    def get_name(self) -> str:
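For orientation: packed_modules_mapping is the per-model dictionary that maps a fused module name to the checkpoint shard names it packs together, and the base QuantizationConfig now carries this attribute (empty by default) so quantization methods can expand fused names. A minimal sketch of the shape this mapping takes and how it gets used; the concrete layer names below are illustrative, not taken from this diff:

from types import MappingProxyType
from typing import List, Mapping

# Illustrative mapping; real values come from the model class, e.g.
# {"qkv_proj": [...], "gate_up_proj": [...]} for Llama-style models.
packed_modules_mapping: Mapping[str, List[str]] = MappingProxyType({
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
})


def expand_fused(prefix: str) -> List[str]:
    """Expand a fused module name into the shard names stored on disk."""
    proj_name = prefix.split(".")[-1]
    shards = packed_modules_mapping.get(proj_name, [proj_name])
    return [prefix.replace(proj_name, shard) for shard in shards]


assert expand_fused("model.layers.0.self_attn.qkv_proj") == [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
]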
@@ -83,7 +83,9 @@ class CompressedTensorsConfig(QuantizationConfig):
 
         # Check if the layer is skipped for quantization.
         # TODO (@robertgshaw2): support module names
-        if should_ignore_layer(prefix, ignore=self.ignore):
+        if should_ignore_layer(prefix,
+                               ignore=self.ignore,
+                               fused_mapping=self.packed_modules_mapping):
             return UnquantizedLinearMethod()
         if isinstance(layer, LinearBase):
             scheme = self.get_scheme(layer=layer, layer_name=prefix)
@@ -379,34 +381,29 @@ class CompressedTensorsConfig(QuantizationConfig):
 
         # Will be empty for models with only sparsity
         weight_quant = input_quant = None
-        sparsity_scheme: Optional[SparsityCompressionConfig] = None
         if self.target_scheme_map:
             matched_target = find_matched_target(
                 layer_name=layer_name,
                 module=layer,
-                targets=self.target_scheme_map.keys())
+                targets=self.target_scheme_map.keys(),
+                fused_mapping=self.packed_modules_mapping)
 
             scheme_dict = self.target_scheme_map[matched_target]
             weight_quant = scheme_dict.get("weights")
             input_quant = scheme_dict.get("input_activations")
 
-        if self.sparsity_scheme_map:
-            is_ignored = False
-            with suppress(ValueError):
-                is_ignored = find_matched_target(
-                    layer_name=layer_name,
-                    module=layer,
-                    targets=self.sparsity_ignore_list)
-
-            # if the layer is in the sparsity ignore list,
-            # we should not apply any sparsity scheme
-
-            if not is_ignored:
-                matched_target = find_matched_target(
-                    layer_name=layer_name,
-                    module=layer,
-                    targets=self.sparsity_scheme_map.keys())
-                sparsity_scheme = self.sparsity_scheme_map.get(matched_target)
+        # Find the sparsity scheme of the layer
+        # assume that fused layers inherit first component's sparsity scheme
+        sparsity_targets = (self.sparsity_scheme_map.keys() -
+                            set(self.sparsity_ignore_list))
+        sparsity_scheme: Optional[SparsityCompressionConfig] = None
+        with suppress(ValueError):
+            matched_target = find_matched_target(
+                layer_name=layer_name,
+                module=layer,
+                targets=sparsity_targets,
+                fused_mapping=self.packed_modules_mapping)
+            sparsity_scheme = self.sparsity_scheme_map[matched_target]
 
         if self.supports_cutlass_24(weight_quant=weight_quant,
                                     input_quant=input_quant,
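The reworked sparsity lookup drops the explicit is_ignored branch: ignored targets are simply removed from the candidate set, and a lookup that matches nothing is swallowed by suppress(ValueError). A rough standalone sketch of that selection logic, with invented scheme names and a toy matcher standing in for find_matched_target:

import re
from contextlib import suppress

# Hypothetical config contents, not taken from this diff.
sparsity_scheme_map = {
    "re:.*qkv_proj$": "sparse_24",
    "re:.*o_proj$": "sparse_24",
}
sparsity_ignore_list = ["re:.*o_proj$"]

# Ignored targets never become candidates.
sparsity_targets = sparsity_scheme_map.keys() - set(sparsity_ignore_list)


def toy_find_matched_target(layer_name, targets):
    """Stand-in for find_matched_target: raises when nothing matches."""
    for target in targets:
        if re.search(target.removeprefix("re:"), layer_name):
            return target
    raise ValueError(layer_name)


sparsity_scheme = None
with suppress(ValueError):
    matched = toy_find_matched_target("model.layers.0.self_attn.o_proj",
                                      sparsity_targets)
    sparsity_scheme = sparsity_scheme_map[matched]

assert sparsity_scheme is None  # o_proj was ignored, so the lookup fell through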
@@ -1,14 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import re
-from typing import Iterable, Optional
+from types import MappingProxyType
+from typing import Iterable, List, Mapping, Optional
 
 from compressed_tensors import CompressionFormat
 from torch.nn import Module
 
-from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    FUSED_LAYER_NAME_MAPPING)
-
 
 def is_activation_quantization_format(format: str) -> bool:
     _ACTIVATION_QUANTIZATION_FORMATS = [
@@ -19,8 +17,11 @@ def is_activation_quantization_format(format: str) -> bool:
     return format in _ACTIVATION_QUANTIZATION_FORMATS
 
 
-def should_ignore_layer(layer_name: Optional[str],
-                        ignore: Iterable[str]) -> bool:
+def should_ignore_layer(
+        layer_name: Optional[str],
+        ignore: Iterable[str] = tuple(),
+        fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
+) -> bool:
     if layer_name is None:
         return False
 
@@ -32,8 +33,8 @@ def should_ignore_layer(layer_name: Optional[str],
     # in the safetensors checkpoint. So, we convert the name
     # from the fused version to unfused + check to make sure that
     # each shard of the fused layer has the same scheme.
-    if proj_name in FUSED_LAYER_NAME_MAPPING and layer_name not in ignore:
-        shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
+    if proj_name in fused_mapping and layer_name not in ignore:
+        shard_proj_names = fused_mapping[proj_name]
 
         # Convert fused_name --> [shard_names]
         shard_names = [
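Because fused layers such as qkv_proj never appear in the safetensors checkpoint, should_ignore_layer expands them through the caller-supplied fused_mapping and requires all shards to agree. A self-contained sketch of that check, simplified to exact-name ignores (no regex handling):

from types import MappingProxyType
from typing import Iterable, List, Mapping, Optional


def toy_should_ignore_layer(
        layer_name: Optional[str],
        ignore: Iterable[str] = tuple(),
        fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
) -> bool:
    if layer_name is None:
        return False
    proj_name = layer_name.split(".")[-1]
    if proj_name in fused_mapping and layer_name not in ignore:
        shard_names = [
            layer_name.replace(proj_name, shard)
            for shard in fused_mapping[proj_name]
        ]
        matches = [name in ignore for name in shard_names]
        if any(matches) and not all(matches):
            raise ValueError("All shards of a fused layer must share a scheme")
        return all(matches)
    return layer_name in ignore


fused = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
ignore = ["model.layers.0.self_attn.q_proj",
          "model.layers.0.self_attn.k_proj",
          "model.layers.0.self_attn.v_proj"]
assert toy_should_ignore_layer("model.layers.0.self_attn.qkv_proj",
                               ignore, fused)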
@@ -79,55 +80,12 @@ def check_equal_or_regex_match(layer_name: str,
     return False
 
 
-def _handle_fused_layers(func):
-    """
-    Decorator to handle fused layers by mapping vllm fused layer names
-    to their corresponding unfused layer names for quantization/pruning schemes.
-    """
-    # fused_layer_name -> unfused_layer_name
-    fused_layer_map = {
-        "qkv_proj": "q_proj",
-        "gate_up_proj": "up_proj",
-    }
-
-    def fused_layer_handler(layer_name: Optional[str], module: Module,
-                            targets: Iterable[str]) -> Optional[str]:
-        """
-        Wrapper function specifically designed to support the
-        find_matched_target function.
-
-        It handles cases where the provided layer name corresponds to a
-        fused layer in vllm, mapping it to its equivalent unfused layer name
-        based on the predefined fused_layer_map. If the original layer name
-        raises a ValueError in the wrapped function, this handler
-        will attempt to resolve the issue by substituting with unfused
-        layer name.
-
-        :param layer_name: Name of the layer, which may be fused.
-        :param module: An instance of torch.nn.Module.
-        :param targets: A list of target names or patterns to match.
-        :return: The result of the wrapped find_matched_target function with
-            the resolved layer name.
-        :raises ValueError: If the layer name cannot be resolved to a
-            valid target.
-        """
-        try:
-            return func(layer_name, module, targets)
-        except ValueError:
-            if layer_name is None:
-                layer_name = ""
-            parent_name, fused_proj_name = layer_name.rsplit(".", 1)
-            unfused_proj_name = fused_layer_map.get(fused_proj_name,
-                                                    fused_proj_name)
-            new_layer_name = f"{parent_name}.{unfused_proj_name}"
-            return func(new_layer_name, module, targets)
-
-    return fused_layer_handler
-
-
-@_handle_fused_layers
-def find_matched_target(layer_name: Optional[str], module: Module,
-                        targets: Iterable[str]) -> str:
+def find_matched_target(
+        layer_name: Optional[str],
+        module: Module,
+        targets: Iterable[str],
+        fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
+) -> str:
     """
     Helper function to look up which "target" in the compressed-tensors
     config that a layer corresponds to.
@@ -141,19 +99,25 @@ def find_matched_target(layer_name: Optional[str], module: Module,
 
     First, we try to match the layer_name with a target
     Second, we try to match the module's name with a target
+    Third, we try to map the layer_name to a list of fused module names.
+        *All* component module names must match in order for a match to be
+        successful. A successful match returns the first component target
 
     :param layer_name: layer name
     :param module: torch.nn.Module
     :param targets: list of targets to match the layer against
+    :param fused_mapping: map from fused layer names to its components
+    :param fused_strategy: either "all" or "any". If using "all", fused
+        layers match if "all" of its components match
     """
 
     if layer_name is None:
         layer_name = ""
 
-    matched_target = (_find_first_match(layer_name, targets)
-                      or _find_first_match(module.__class__.__name__, targets,
-                                           True)
-                      or _match_fused_layer(layer_name, targets))
+    matched_target = (
+        _find_first_match(layer_name, targets)
+        or _find_first_match(module.__class__.__name__, targets, True)
+        or _match_fused_layer(layer_name, targets, fused_mapping))
 
     if matched_target is None:
         raise ValueError(
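find_matched_target now tries three lookups in order: the layer name itself, the module's class name, and finally the fused components supplied through fused_mapping. A compact toy version of that ordering, using exact string matching only (the real helper also understands "re:" targets); the layer and class names are illustrative:

from typing import Iterable, List, Optional


def first_match(value: str, targets: Iterable[str]) -> Optional[str]:
    """Toy stand-in for _find_first_match: exact matches only."""
    return next((t for t in targets if t == value), None)


def toy_find_matched_target(layer_name: str, module_cls_name: str,
                            targets: List[str],
                            fused_components: List[str]) -> str:
    hits = [first_match(c, targets) for c in fused_components]
    fused_hit = hits[0] if hits and all(hits) else None
    matched = (first_match(layer_name, targets)
               or first_match(module_cls_name, targets)
               or fused_hit)
    if matched is None:
        raise ValueError(f"no target for {layer_name}")
    return matched


components = ["model.layers.0.self_attn.q_proj",
              "model.layers.0.self_attn.k_proj",
              "model.layers.0.self_attn.v_proj"]
# The fused name is absent from targets, but every component is present,
# so the first component's target is returned.
assert toy_find_matched_target("model.layers.0.self_attn.qkv_proj",
                               "QKVParallelLinear", components,
                               components) == components[0]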
@@ -205,11 +169,19 @@ def _is_equal_or_regex_match(value: str,
     return False
 
 
-def _match_fused_layer(layer_name: str,
-                       target_layers: Iterable[str]) -> Optional[str]:
+def _match_fused_layer(
+        layer_name: str, target_layers: Iterable[str],
+        fused_mapping: Mapping[str, List[str]]) -> Optional[str]:
     """
     Match a fused layer name to its corresponding individual layer in
-    target_layers.
+    target_layers. Returns first value in fused_mapping which matches targets
+
+    Implements an "all" matching strategy where a fused layer matches iff
+    "all" of its components match
+
+    :param layer_name: layer name
+    :param target_layers: list of targets to match the layer against
+    :param fused_mapping: map from fused layer names to its components
 
     Examples:
         layer_name = "model.layers.0.self_attn.qkv_proj"
@@ -217,27 +189,25 @@ def _match_fused_layer(layer_name: str,
                          "model.layers.0.self_attn.k_proj",
                          "model.layers.0.self_attn.v_proj"]
     """
-    # Split into parent path and layer type
-    # e.g., "model.layers.0.self_attn" and "qkv_proj"
-    parent_path = ".".join(layer_name.split(".")[:-1])
-    layer_type = layer_name.split(".")[-1]
-
-    if layer_type not in FUSED_LAYER_NAME_MAPPING:
+    # find layer_name in mapping
+    fused = next((key for key in fused_mapping if layer_name.endswith(key)),
+                 None)
+    if fused is None:
         return None
 
-    possible_layer_types = FUSED_LAYER_NAME_MAPPING[layer_type]
+    # expand path of unfused components
+    unfused_paths = [
+        layer_name.replace(fused, unfused) for unfused in fused_mapping[fused]
+    ]
 
-    # Look for a target layer that:
-    # 1. Has the same parent path
-    # 2. Ends with one of the possible individual layer types
-    for target in target_layers:
-        is_same_parent = parent_path in target
-        is_matching_type = any(type_suffix in target
-                               for type_suffix in possible_layer_types)
+    # for each unfused component, find a match in targets
+    unfused_matches: List[Optional[str]] = []
+    for unfused in unfused_paths:
+        for target in target_layers:
+            if _is_equal_or_regex_match(unfused, target):
+                unfused_matches.append(target)
+                break
+        else:
+            unfused_matches.append(None)
 
-        if is_same_parent and is_matching_type and all(
-            (f"{parent_path}.{type_suffix}" in target_layers)
-                for type_suffix in possible_layer_types):
-            return target
-
-    return None
+    return unfused_matches[0] if all(unfused_matches) else None
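The rewritten _match_fused_layer expands the fused name through fused_mapping, matches each component against the targets independently, and reports a match only when every component found one, returning the first component's target. A standalone sketch with a simplified matcher (exact names or "re:" patterns); the targets below are made up for the example:

import re
from typing import Iterable, List, Mapping, Optional


def is_equal_or_regex_match(value: str, target: str) -> bool:
    """Simplified stand-in for _is_equal_or_regex_match."""
    if target.startswith("re:"):
        return re.search(target[3:], value) is not None
    return value == target


def toy_match_fused_layer(
        layer_name: str, target_layers: Iterable[str],
        fused_mapping: Mapping[str, List[str]]) -> Optional[str]:
    fused = next((k for k in fused_mapping if layer_name.endswith(k)), None)
    if fused is None:
        return None
    unfused_paths = [layer_name.replace(fused, u) for u in fused_mapping[fused]]
    matches: List[Optional[str]] = []
    for unfused in unfused_paths:
        for target in target_layers:
            if is_equal_or_regex_match(unfused, target):
                matches.append(target)
                break
        else:
            matches.append(None)
    return matches[0] if all(matches) else None


fused_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
# All three components match a target, so qkv_proj resolves to the first one.
assert toy_match_fused_layer(
    "model.layers.0.self_attn.qkv_proj",
    ["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"],
    fused_mapping) == "re:.*q_proj$"
# Drop the v_proj target and the fused layer no longer matches at all.
assert toy_match_fused_layer(
    "model.layers.0.self_attn.qkv_proj",
    ["re:.*q_proj$", "re:.*k_proj$"],
    fused_mapping) is None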
@@ -18,8 +18,6 @@ from vllm.model_executor.layers.quantization.quark.schemes import (
     QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8)
 from vllm.model_executor.layers.quantization.quark.utils import (
     deep_compare, should_ignore_layer)
-from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    FUSED_LAYER_NAME_MAPPING)
 from vllm.platforms import current_platform
 
 __all__ = ["QuarkLinearMethod"]
@@ -58,7 +56,9 @@ class QuarkConfig(QuantizationConfig):
 
         # Check if the layer is skipped for quantization.
         exclude_layers = cast(List[str], self.quant_config.get("exclude"))
-        if should_ignore_layer(prefix, ignore=exclude_layers):
+        if should_ignore_layer(prefix,
+                               ignore=exclude_layers,
+                               fused_mapping=self.packed_modules_mapping):
             return UnquantizedLinearMethod()
         if isinstance(layer, LinearBase):
             scheme = self.get_scheme(layer=layer, layer_name=prefix)
@@ -201,8 +201,8 @@ class QuarkConfig(QuantizationConfig):
                              module: torch.nn.Module) -> Dict[str, Any]:
 
         proj_name = layer_name.split(".")[-1]
-        if proj_name in FUSED_LAYER_NAME_MAPPING:
-            shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
+        if proj_name in self.packed_modules_mapping:
+            shard_proj_names = self.packed_modules_mapping[proj_name]
 
             # Convert fused_name --> [shard_names]
             shard_names = [
@@ -1,10 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import re
-from typing import Any, Iterable, Optional
-
-from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    FUSED_LAYER_NAME_MAPPING)
+from types import MappingProxyType
+from typing import Any, Iterable, List, Mapping, Optional
 
 
 def deep_compare(dict1: Any, dict2: Any) -> bool:
@@ -20,8 +18,11 @@ def deep_compare(dict1: Any, dict2: Any) -> bool:
     return dict1 == dict2
 
 
-def should_ignore_layer(layer_name: Optional[str],
-                        ignore: Iterable[str]) -> bool:
+def should_ignore_layer(
+        layer_name: Optional[str],
+        ignore: Iterable[str],
+        fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
+) -> bool:
     if layer_name is None:
         return False
 
@@ -33,8 +34,8 @@ def should_ignore_layer(layer_name: Optional[str],
     # in the safetensors checkpoint. So, we convert the name
     # from the fused version to unfused + check to make sure that
     # each shard of the fused layer has the same scheme.
-    if proj_name in FUSED_LAYER_NAME_MAPPING:
-        shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
+    if proj_name in fused_mapping:
+        shard_proj_names = fused_mapping[proj_name]
 
         # Convert fused_name --> [shard_names]
         shard_names = [
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 """This file is used for /tests and /benchmarks"""
-from typing import List, Optional, Tuple
+from types import MappingProxyType
+from typing import List, Mapping, Optional, Tuple
 
 import numpy
 import torch
@@ -12,14 +13,6 @@ from vllm.scalar_type import ScalarType, scalar_types
 SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
 SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
 
-# Note: this is a hack. We should update each model to register the
-# stacked params and get it from there instead in a future PR.
-# fused_name: List[shard_name]
-FUSED_LAYER_NAME_MAPPING = {
-    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
-    "gate_up_proj": ["gate_proj", "up_proj"]
-}
-
 
 # Normalize the group_shape to the full extent for any dims that are -1
 def _normalize_quant_group_shape(x: torch.Tensor, group_shape: Tuple[int,
@@ -178,14 +171,23 @@ def unpack_quantized_values_into_int32(w_q: torch.Tensor,
     return res.permute(inv_perm)
 
 
-def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
+def is_layer_skipped(
+    prefix: str,
+    ignored_layers: List[str],
+    fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
+) -> bool:
     # prefix: model.layers.0.self_attn.q_proj
     # proj_name: q_proj
     proj_name = prefix.split(".")[-1]
-    if proj_name in FUSED_LAYER_NAME_MAPPING:
+
+    # Fused layers like gate_up_proj or qkv_proj will not be fused
+    # in the safetensors checkpoint. So, we convert the name
+    # from the fused version to unfused + check to make sure that
+    # each shard of the fused layer has the same scheme.
+    if proj_name in fused_mapping:
         shard_prefixes = [
             prefix.replace(proj_name, shard_proj_name)
-            for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
+            for shard_proj_name in fused_mapping[proj_name]
         ]
 
         is_skipped = None
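Callers of is_layer_skipped now pass the model's own mapping instead of relying on the removed module-level constant. A hedged usage sketch, assuming the unchanged remainder of the function still compares each expanded shard prefix against ignored_layers exactly as before; the layer names are invented:

from types import MappingProxyType

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)

fused_mapping = MappingProxyType({"gate_up_proj": ["gate_proj", "up_proj"]})
ignored_layers = ["model.layers.0.mlp.gate_proj",
                  "model.layers.0.mlp.up_proj"]

# Both shards of the fused layer are ignored, so the fused prefix is skipped.
assert is_layer_skipped("model.layers.0.mlp.gate_up_proj", ignored_layers,
                        fused_mapping=fused_mapping)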
@@ -43,6 +43,7 @@ from vllm.model_executor.model_loader.tensorizer import (
     TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
     serialize_vllm_model, tensorizer_weights_iterator)
 from vllm.model_executor.model_loader.utils import (ParamMapping,
+                                                    configure_quant_config,
                                                     get_model_architecture,
                                                     set_default_torch_dtype)
 from vllm.model_executor.model_loader.weight_utils import (
@@ -113,6 +114,9 @@ def _initialize_model(
     model_config = vllm_config.model_config
     model_class, _ = get_model_architecture(model_config)
 
+    if vllm_config.quant_config is not None:
+        configure_quant_config(vllm_config.quant_config, model_class)
+
     signatures = inspect.signature(model_class.__init__)
     all_params = [param.name for param in signatures.parameters.values()]
     if "vllm_config" in all_params and "prefix" in all_params:
@@ -11,6 +11,8 @@ from transformers.dynamic_module_utils import get_class_from_dynamic_module
 
 from vllm.config import ModelConfig, ModelImpl
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.models.adapters import (as_classification_model,
                                                  as_embedding_model,
@@ -138,3 +140,23 @@ class ParamMapping:
             if module_name.endswith(key):
                 return key, value
         return None
+
+
+def configure_quant_config(quant_config: QuantizationConfig,
+                           model_class: Type[nn.Module]):
+    """
+    Pass packed_modules_mapping by reference to quant_config so that
+    quant_config can properly match fused modules
+
+    Note that model attributes are passed by reference to quant_config,
+    enabling them to be updated by model_class.__new__ (ex. chatglm, qwen)
+    """
+    packed_mapping = getattr(model_class, "packed_modules_mapping", None)
+    if packed_mapping is not None:
+        # pass packed_modules_mapping by reference to quant_config
+        quant_config.packed_modules_mapping = packed_mapping
+    else:
+        logger.warning(
+            "The model class %s has not defined `packed_modules_mapping`, "
+            "this may lead to incorrect mapping of quantized or ignored "
+            "modules", model_class.__name__)
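configure_quant_config hands the model class's packed_modules_mapping to the quant config by reference rather than copying it. A minimal illustration of why that matters, using stand-in classes rather than the vLLM types:

# Stand-in classes, not the vLLM implementations.
class ToyQuantConfig:
    packed_modules_mapping = {}


class ToyModel:
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}


quant_config = ToyQuantConfig()
# What configure_quant_config does: share the dict, do not copy it.
quant_config.packed_modules_mapping = ToyModel.packed_modules_mapping

# Because the dict is shared by reference, anything the model class adds
# later (e.g. in __new__, as the chatglm/qwen dispatchers below do) is
# visible to the quantization config without another hand-off.
ToyModel.packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]
assert "gate_up_proj" in quant_config.packed_modules_mapping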
@@ -265,12 +265,14 @@ class GLMAttention(nn.Module):
             self.total_num_kv_heads,
             bias=config.add_bias_linear or config.add_qkv_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.query_key_value",
         )
         self.dense = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             config.hidden_size,
             bias=config.add_bias_linear,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense",
         )
 
         # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
@@ -327,6 +329,7 @@ class GLMMLP(nn.Module):
         self,
         config: ChatGLMConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
 
@@ -338,6 +341,7 @@ class GLMMLP(nn.Module):
             [config.ffn_hidden_size] * 2,
             bias=config.add_bias_linear,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense_h_to_4h",
         )
 
         self.activation_func = SiluAndMul()
@@ -348,6 +352,7 @@ class GLMMLP(nn.Module):
             config.hidden_size,
             bias=config.add_bias_linear,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense_4h_to_h",
         )
 
     def forward(self, hidden_states):
@@ -396,7 +401,7 @@ class GLMBlock(nn.Module):
             config.hidden_size, eps=config.layernorm_epsilon)
 
         # MLP
-        self.mlp = GLMMLP(config, quant_config)
+        self.mlp = GLMMLP(config, quant_config, prefix=f"{prefix}.mlp")
 
     def forward(
         self,
@@ -507,7 +512,8 @@ class ChatGLMModel(nn.Module):
 
         self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
                                                 config.hidden_size,
-                                                quant_config=quant_config)
+                                                quant_config=quant_config,
+                                                prefix=f"{prefix}.embedding")
 
         self.num_layers = config.num_layers
         self.multi_query_group_num = config.multi_query_group_num
@@ -766,6 +772,7 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
                          SupportsMultiModal):
     # Ensure that the LoRA support check passes when the class is not
     # initialized, but set all these attributes to empty.
+    # These will be updated when an instance class is selected
     packed_modules_mapping = {}
     supported_lora_modules = []
     embedding_modules = {}
@@ -777,9 +784,18 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
         prefix: str = "",
     ) -> None:
         config = vllm_config.model_config.hf_config
+
         # Initialize VL
-        if hasattr(config, "vision_config"):
-            return ChatGLMV(vllm_config=vllm_config, prefix=prefix)
+        if hasattr(config, "vision_config"):  # noqa: SIM108
+            instance_cls = ChatGLMV
         # Initialize LLM
         else:
-            return ChatGLM(vllm_config=vllm_config, prefix=prefix)
+            instance_cls = ChatGLM
+
+        # quant_config references base class members,
+        # so update values before init is called
+        cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
+        cls.supported_lora_modules += instance_cls.supported_lora_modules
+        cls.embedding_modules.update(instance_cls.embedding_modules)
+        cls.embedding_padding_modules += instance_cls.embedding_padding_modules
+        return instance_cls(vllm_config=vllm_config, prefix=prefix)
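ChatGLM (and, below, MiniCPMV and QWen) dispatch to a concrete subclass inside __new__ and merge that subclass's mappings onto the dispatcher class before construction, because the quant config already holds a reference to the dispatcher's dicts. A stripped-down sketch of that pattern with toy classes, not the vLLM ones:

class ToyBase:
    packed_modules_mapping = {}


class ToyLLM(ToyBase):
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}


class ToyDispatcher(ToyBase):
    # Empty at class-definition time; filled in when an instance is created.
    packed_modules_mapping = {}

    def __new__(cls):
        instance_cls = ToyLLM
        # quant_config already references cls.packed_modules_mapping, so
        # update it in place before the real model is constructed.
        cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
        return instance_cls()


quant_config_view = ToyDispatcher.packed_modules_mapping  # shared by reference
model = ToyDispatcher()
assert isinstance(model, ToyLLM)
assert "qkv_proj" in quant_config_view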
@@ -74,11 +74,13 @@ class Attention(nn.Module):
             self.head_dim,
             config.num_heads,
             quant_config=quant_config,
+            prefix=f"{prefix}.query_key_value",
         )
         self.dense = RowParallelLinear(
             config.hidden_size,
             config.hidden_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense",
         )
 
         self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim,
@@ -101,6 +103,7 @@ class MLP(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = '',
     ):
         super().__init__()
         self.config = config
@@ -109,11 +112,13 @@ class MLP(nn.Module):
             config.hidden_size,
             config.intermediate_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.fc1",
         )
         self.fc2 = RowParallelLinear(
             config.intermediate_size,
             config.hidden_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.fc2",
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -137,7 +142,9 @@ class TransformerLayer(nn.Module):
         self.attention = Attention(config,
                                    quant_config=quant_config,
                                    prefix=f"{prefix}.attention")
-        self.mlp = MLP(config, quant_config=quant_config)
+        self.mlp = MLP(config,
+                       quant_config=quant_config,
+                       prefix=f"{prefix}.mlp")
         self.post_attention_layernorm = LayerNorm(config.hidden_size,
                                                   eps=config.layer_norm_eps)
 
@@ -164,7 +171,7 @@ class Transformer(nn.Module):
         self.layers = nn.ModuleList([
             TransformerLayer(config,
                              quant_config=quant_config,
-                             prefix=f"{prefix}.layer.{layer_idx}")
+                             prefix=f"{prefix}.layers.{layer_idx}")
             for layer_idx in range(config.num_hidden_layers)
         ])
 
@@ -181,6 +188,7 @@ class GLU(nn.Module):
         config,
         in_features,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = '',
     ):
         """
         The original implementation is the same as:
@@ -222,7 +230,8 @@ class GLU(nn.Module):
         self.linear_proj = ReplicatedLinear(in_features,
                                             config.hidden_size,
                                             bias=False,
-                                            quant_config=quant_config)
+                                            quant_config=quant_config,
+                                            prefix=f"{prefix}.linear_proj")
         self.norm1 = nn.LayerNorm(config.hidden_size)
         self.act1 = nn.GELU()
         self.act2 = SiluAndMul()
@@ -230,12 +239,15 @@ class GLU(nn.Module):
         self.merged_proj = MergedColumnParallelLinear(
             config.hidden_size, [config.ffn_hidden_size] * 2,
             bias=False,
-            quant_config=quant_config)
+            quant_config=quant_config,
+            prefix=f"{prefix}.merged_proj")
 
-        self.dense_4h_to_h = RowParallelLinear(config.ffn_hidden_size,
-                                               config.hidden_size,
-                                               bias=False,
-                                               quant_config=quant_config)
+        self.dense_4h_to_h = RowParallelLinear(
+            config.ffn_hidden_size,
+            config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.dense_4h_to_h")
 
     def forward(self, x):
         x, _ = self.linear_proj(x)
@@ -262,7 +274,8 @@ class EVA2CLIPModel(nn.Module):
                                        prefix=f"{prefix}.transformer")
         self.linear_proj = GLU(config,
                                in_features=config.hidden_size,
-                               quant_config=quant_config)
+                               quant_config=quant_config,
+                               prefix=f"{prefix}.linear_proj")
         self.conv = nn.Conv2d(in_channels=vision_config.hidden_size,
                               out_channels=config.hidden_size,
                               kernel_size=2,
@@ -1473,6 +1473,7 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA):
     """
     # Ensure that the LoRA support check passes when the class is not
     # initialized, but set all these attributes to empty.
+    # These will be updated when an instance class is selected
     packed_modules_mapping = {}
     supported_lora_modules = []
     embedding_modules = {}
@@ -1489,8 +1490,15 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA):
         version = str(config.version).split(".")
         version = tuple([int(x) for x in version])
         # Dispatch class based on version
-        instance_class = _SUPPORT_VERSION.get(version)
-        if instance_class is None:
+        instance_cls = _SUPPORT_VERSION.get(version)
+        if instance_cls is None:
             raise ValueError(
                 "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
-        return instance_class(vllm_config=vllm_config, prefix=prefix)
+
+        # quant_config references base class members,
+        # so update values before init is called
+        cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
+        cls.supported_lora_modules += instance_cls.supported_lora_modules
+        cls.embedding_modules.update(instance_cls.embedding_modules)
+        cls.embedding_padding_modules += instance_cls.embedding_padding_modules
+        return instance_cls(vllm_config=vllm_config, prefix=prefix)
@@ -1135,6 +1135,7 @@ class QWenLMHeadModel(QWenBaseModel, SupportsMultiModal, SupportsLoRA):
     """
     # Ensure that the LoRA support check passes when the class is not
     # initialized, but set all these attributes to empty.
+    # These will be updated when an instance class is selected
     packed_modules_mapping = {}
     supported_lora_modules = []
     embedding_modules = {}
@@ -1146,9 +1147,18 @@ class QWenLMHeadModel(QWenBaseModel, SupportsMultiModal, SupportsLoRA):
         prefix: str = "",
     ) -> QWenBaseModel:
         config = vllm_config.model_config.hf_config
+
         # Initialize VL
-        if hasattr(config, "visual"):
-            return QWenVL(vllm_config=vllm_config, prefix=prefix)
+        if hasattr(config, "visual"):  # noqa: SIM108
+            instance_cls = QWenVL
         # Initialize LLM
         else:
-            return QWenLLM(vllm_config=vllm_config, prefix=prefix)
+            instance_cls = QWenLLM
+
+        # quant_config references base class members,
+        # so update values before init is called
+        cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
+        cls.supported_lora_modules += instance_cls.supported_lora_modules
+        cls.embedding_modules.update(instance_cls.embedding_modules)
+        cls.embedding_padding_modules += instance_cls.embedding_padding_modules
+        return instance_cls(vllm_config=vllm_config, prefix=prefix)