[Model][Quant] Fix GLM, Fix fused module mappings for quantization (#12634)

Signed-off-by: mgoin <michael@neuralmagic.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
Kyle Sayers 2025-02-05 00:32:06 -05:00 committed by GitHub
parent 686006a220
commit 7ff7a638b6
12 changed files with 194 additions and 150 deletions
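For context, a minimal sketch of the fused-module mapping this commit threads through the quantization configs. The dict mirrors the FUSED_LAYER_NAME_MAPPING constant that is removed further down in this diff; the helper function is illustrative only, not vLLM code:

```python
from typing import Dict, List

# Hard-coded mapping previously used by the quantization utilities.
FUSED_LAYER_NAME_MAPPING: Dict[str, List[str]] = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

def expand_fused_prefix(prefix: str, mapping: Dict[str, List[str]]) -> List[str]:
    """Expand a fused module prefix into the per-shard names stored in the checkpoint."""
    proj_name = prefix.split(".")[-1]
    if proj_name not in mapping:
        return [prefix]
    return [prefix.replace(proj_name, shard) for shard in mapping[proj_name]]

print(expand_fused_prefix("model.layers.0.self_attn.qkv_proj", FUSED_LAYER_NAME_MAPPING))
# ['model.layers.0.self_attn.q_proj', 'model.layers.0.self_attn.k_proj',
#  'model.layers.0.self_attn.v_proj']
```

Because GLM-style models name their projections differently (e.g. query_key_value and dense_h_to_4h in the chatglm changes below), a hard-coded constant cannot cover them; this commit instead lets each model class supply its own packed_modules_mapping, which the quantization helpers receive as fused_mapping.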

View File

@ -2,7 +2,7 @@
import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type
from typing import Any, Dict, List, Mapping, Optional, Type
import torch
from torch import nn
@ -59,6 +59,7 @@ def method_has_implemented_embedding(
class QuantizationConfig(ABC):
"""Base class for quantization configs."""
packed_modules_mapping: Mapping[str, List[str]] = dict()
@abstractmethod
def get_name(self) -> str:

View File

@ -83,7 +83,9 @@ class CompressedTensorsConfig(QuantizationConfig):
# Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names
if should_ignore_layer(prefix, ignore=self.ignore):
if should_ignore_layer(prefix,
ignore=self.ignore,
fused_mapping=self.packed_modules_mapping):
return UnquantizedLinearMethod()
if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix)
@ -379,34 +381,29 @@ class CompressedTensorsConfig(QuantizationConfig):
# Will be empty for models with only sparsity
weight_quant = input_quant = None
sparsity_scheme: Optional[SparsityCompressionConfig] = None
if self.target_scheme_map:
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.target_scheme_map.keys())
targets=self.target_scheme_map.keys(),
fused_mapping=self.packed_modules_mapping)
scheme_dict = self.target_scheme_map[matched_target]
weight_quant = scheme_dict.get("weights")
input_quant = scheme_dict.get("input_activations")
if self.sparsity_scheme_map:
is_ignored = False
with suppress(ValueError):
is_ignored = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.sparsity_ignore_list)
# if the layer is in the sparsity ignore list,
# we should not apply any sparsity scheme
if not is_ignored:
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.sparsity_scheme_map.keys())
sparsity_scheme = self.sparsity_scheme_map.get(matched_target)
# Find the sparsity scheme of the layer
# assume that fused layers inherit first component's sparsity scheme
sparsity_targets = (self.sparsity_scheme_map.keys() -
set(self.sparsity_ignore_list))
sparsity_scheme: Optional[SparsityCompressionConfig] = None
with suppress(ValueError):
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=sparsity_targets,
fused_mapping=self.packed_modules_mapping)
sparsity_scheme = self.sparsity_scheme_map[matched_target]
if self.supports_cutlass_24(weight_quant=weight_quant,
input_quant=input_quant,

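The rewritten sparsity lookup above can be summarized with the following self-contained sketch (stand-in scheme names and a toy matcher, not the vLLM helpers): ignored targets are subtracted from the scheme map up front, and contextlib.suppress(ValueError) leaves sparsity_scheme as None when nothing matches.

```python
from contextlib import suppress

# Stand-ins for self.sparsity_scheme_map and self.sparsity_ignore_list.
sparsity_scheme_map = {"Linear": "sparse-24", "model.layers.0.mlp.down_proj": "dense"}
sparsity_ignore_list = ["model.layers.0.mlp.down_proj"]

def find_matched_target(layer_name, targets):
    # Toy matcher: "Linear" matches any layer, otherwise require an exact name.
    # Raises ValueError when nothing matches, like the real helper.
    for target in targets:
        if target == "Linear" or layer_name == target:
            return target
    raise ValueError(f"no sparsity target for {layer_name}")

# Fused layers are assumed to inherit the first component's sparsity scheme.
sparsity_targets = sparsity_scheme_map.keys() - set(sparsity_ignore_list)

sparsity_scheme = None
with suppress(ValueError):
    matched = find_matched_target("model.layers.0.self_attn.qkv_proj", sparsity_targets)
    sparsity_scheme = sparsity_scheme_map[matched]
print(sparsity_scheme)  # 'sparse-24', matched via the generic "Linear" target
```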
View File

@ -1,14 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
import re
from typing import Iterable, Optional
from types import MappingProxyType
from typing import Iterable, List, Mapping, Optional
from compressed_tensors import CompressionFormat
from torch.nn import Module
from vllm.model_executor.layers.quantization.utils.quant_utils import (
FUSED_LAYER_NAME_MAPPING)
def is_activation_quantization_format(format: str) -> bool:
_ACTIVATION_QUANTIZATION_FORMATS = [
@ -19,8 +17,11 @@ def is_activation_quantization_format(format: str) -> bool:
return format in _ACTIVATION_QUANTIZATION_FORMATS
def should_ignore_layer(layer_name: Optional[str],
ignore: Iterable[str]) -> bool:
def should_ignore_layer(
layer_name: Optional[str],
ignore: Iterable[str] = tuple(),
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
) -> bool:
if layer_name is None:
return False
@ -32,8 +33,8 @@ def should_ignore_layer(layer_name: Optional[str],
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
if proj_name in FUSED_LAYER_NAME_MAPPING and layer_name not in ignore:
shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
if proj_name in fused_mapping and layer_name not in ignore:
shard_proj_names = fused_mapping[proj_name]
# Convert fused_name --> [shard_names]
shard_names = [
@ -79,55 +80,12 @@ def check_equal_or_regex_match(layer_name: str,
return False
def _handle_fused_layers(func):
"""
Decorator to handle fused layers by mapping vllm fused layer names
to their corresponding unfused layer names for quantization/pruning schemes.
"""
# fused_layer_name -> unfused_layer_name
fused_layer_map = {
"qkv_proj": "q_proj",
"gate_up_proj": "up_proj",
}
def fused_layer_handler(layer_name: Optional[str], module: Module,
targets: Iterable[str]) -> Optional[str]:
"""
Wrapper function specifically designed to support the
find_matched_target function.
It handles cases where the provided layer name corresponds to a
fused layer in vllm, mapping it to its equivalent unfused layer name
based on the predefined fused_layer_map. If the original layer name
raises a ValueError in the wrapped function, this handler
will attempt to resolve the issue by substituting with unfused
layer name.
:param layer_name: Name of the layer, which may be fused.
:param module: An instance of torch.nn.Module.
:param targets: A list of target names or patterns to match.
:return: The result of the wrapped find_matched_target function with
the resolved layer name.
:raises ValueError: If the layer name cannot be resolved to a
valid target.
"""
try:
return func(layer_name, module, targets)
except ValueError:
if layer_name is None:
layer_name = ""
parent_name, fused_proj_name = layer_name.rsplit(".", 1)
unfused_proj_name = fused_layer_map.get(fused_proj_name,
fused_proj_name)
new_layer_name = f"{parent_name}.{unfused_proj_name}"
return func(new_layer_name, module, targets)
return fused_layer_handler
@_handle_fused_layers
def find_matched_target(layer_name: Optional[str], module: Module,
targets: Iterable[str]) -> str:
def find_matched_target(
layer_name: Optional[str],
module: Module,
targets: Iterable[str],
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
) -> str:
"""
Helper function to look up which "target" in the compressed-tensors
config that a layer corresponds to.
@ -141,19 +99,25 @@ def find_matched_target(layer_name: Optional[str], module: Module,
First, we try to match the layer_name with a target
Second, we try to match the module's name with a target
Third, we try to map the layer_name to a list of fused module names.
*All* component module names must match in order for a match to be
successful. A successful match returns the first component target.
:param layer_name: layer name
:param module: torch.nn.Module
:param targets: list of targets to match the layer against
:param fused_mapping: map from fused layer names to their components;
fused layers match only if all of their components match
"""
if layer_name is None:
layer_name = ""
matched_target = (_find_first_match(layer_name, targets)
or _find_first_match(module.__class__.__name__, targets,
True)
or _match_fused_layer(layer_name, targets))
matched_target = (
_find_first_match(layer_name, targets)
or _find_first_match(module.__class__.__name__, targets, True)
or _match_fused_layer(layer_name, targets, fused_mapping))
if matched_target is None:
raise ValueError(
@ -205,11 +169,19 @@ def _is_equal_or_regex_match(value: str,
return False
def _match_fused_layer(layer_name: str,
target_layers: Iterable[str]) -> Optional[str]:
def _match_fused_layer(
layer_name: str, target_layers: Iterable[str],
fused_mapping: Mapping[str, List[str]]) -> Optional[str]:
"""
Match a fused layer name to its corresponding individual layer in
target_layers.
target_layers. Implements an "all" matching strategy: a fused layer matches
iff all of its components match. On a match, the first component's matched
target is returned.
:param layer_name: layer name
:param target_layers: list of targets to match the layer against
:param fused_mapping: map from fused layer names to their components
Examples:
layer_name = "model.layers.0.self_attn.qkv_proj"
@ -217,27 +189,25 @@ def _match_fused_layer(layer_name: str,
"model.layers.0.self_attn.k_proj",
"model.layers.0.self_attn.v_proj"]
"""
# Split into parent path and layer type
# e.g., "model.layers.0.self_attn" and "qkv_proj"
parent_path = ".".join(layer_name.split(".")[:-1])
layer_type = layer_name.split(".")[-1]
if layer_type not in FUSED_LAYER_NAME_MAPPING:
# find layer_name in mapping
fused = next((key for key in fused_mapping if layer_name.endswith(key)),
None)
if fused is None:
return None
possible_layer_types = FUSED_LAYER_NAME_MAPPING[layer_type]
# expand path of unfused components
unfused_paths = [
layer_name.replace(fused, unfused) for unfused in fused_mapping[fused]
]
# Look for a target layer that:
# 1. Has the same parent path
# 2. Ends with one of the possible individual layer types
for target in target_layers:
is_same_parent = parent_path in target
is_matching_type = any(type_suffix in target
for type_suffix in possible_layer_types)
# for each unfused component, find a match in targets
unfused_matches: List[Optional[str]] = []
for unfused in unfused_paths:
for target in target_layers:
if _is_equal_or_regex_match(unfused, target):
unfused_matches.append(target)
break
else:
unfused_matches.append(None)
if is_same_parent and is_matching_type and all(
(f"{parent_path}.{type_suffix}" in target_layers)
for type_suffix in possible_layer_types):
return target
return None
return unfused_matches[0] if all(unfused_matches) else None

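The "all" matching strategy added above can be exercised with this stripped-down, standalone sketch. It uses a plain regex/equality check in place of vLLM's _is_equal_or_regex_match and _find_first_match; names prefixed with underscores here are illustrative, not the library's:

```python
import re
from typing import Iterable, List, Mapping, Optional

FUSED_MAPPING: Mapping[str, List[str]] = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

def _matches(value: str, target: str) -> bool:
    # compressed-tensors style targets: "re:..." is a regex, anything else is exact.
    if target.startswith("re:"):
        return re.match(target[3:], value) is not None
    return value == target

def match_fused_layer(layer_name: str, targets: Iterable[str],
                      fused_mapping: Mapping[str, List[str]]) -> Optional[str]:
    """Return the first component's matched target iff *all* components match."""
    fused = next((k for k in fused_mapping if layer_name.endswith(k)), None)
    if fused is None:
        return None
    unfused_paths = [layer_name.replace(fused, u) for u in fused_mapping[fused]]
    matches = [next((t for t in targets if _matches(p, t)), None)
               for p in unfused_paths]
    return matches[0] if all(matches) else None

targets = ["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"]
print(match_fused_layer("model.layers.0.self_attn.qkv_proj", targets, FUSED_MAPPING))
# re:.*q_proj$  (all of q/k/v matched, so the first component's target is returned)

print(match_fused_layer("model.layers.0.self_attn.qkv_proj",
                        ["re:.*q_proj$", "re:.*k_proj$"], FUSED_MAPPING))
# None  (v_proj has no matching target, so the fused layer does not match)
```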
View File

@ -18,8 +18,6 @@ from vllm.model_executor.layers.quantization.quark.schemes import (
QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8)
from vllm.model_executor.layers.quantization.quark.utils import (
deep_compare, should_ignore_layer)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
FUSED_LAYER_NAME_MAPPING)
from vllm.platforms import current_platform
__all__ = ["QuarkLinearMethod"]
@ -58,7 +56,9 @@ class QuarkConfig(QuantizationConfig):
# Check if the layer is skipped for quantization.
exclude_layers = cast(List[str], self.quant_config.get("exclude"))
if should_ignore_layer(prefix, ignore=exclude_layers):
if should_ignore_layer(prefix,
ignore=exclude_layers,
fused_mapping=self.packed_modules_mapping):
return UnquantizedLinearMethod()
if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix)
@ -201,8 +201,8 @@ class QuarkConfig(QuantizationConfig):
module: torch.nn.Module) -> Dict[str, Any]:
proj_name = layer_name.split(".")[-1]
if proj_name in FUSED_LAYER_NAME_MAPPING:
shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
if proj_name in self.packed_modules_mapping:
shard_proj_names = self.packed_modules_mapping[proj_name]
# Convert fused_name --> [shard_names]
shard_names = [

View File

@ -1,10 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
import re
from typing import Any, Iterable, Optional
from vllm.model_executor.layers.quantization.utils.quant_utils import (
FUSED_LAYER_NAME_MAPPING)
from types import MappingProxyType
from typing import Any, Iterable, List, Mapping, Optional
def deep_compare(dict1: Any, dict2: Any) -> bool:
@ -20,8 +18,11 @@ def deep_compare(dict1: Any, dict2: Any) -> bool:
return dict1 == dict2
def should_ignore_layer(layer_name: Optional[str],
ignore: Iterable[str]) -> bool:
def should_ignore_layer(
layer_name: Optional[str],
ignore: Iterable[str],
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
) -> bool:
if layer_name is None:
return False
@ -33,8 +34,8 @@ def should_ignore_layer(layer_name: Optional[str],
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
if proj_name in FUSED_LAYER_NAME_MAPPING:
shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
if proj_name in fused_mapping:
shard_proj_names = fused_mapping[proj_name]
# Convert fused_name --> [shard_names]
shard_names = [

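A side note on the defaults used throughout these helpers: fused_mapping defaults to MappingProxyType({}) rather than a plain {}. A short standalone illustration of why (a mutable default dict is shared across calls, while the read-only proxy cannot be mutated by accident); the function names here are illustrative only:

```python
from types import MappingProxyType

def risky(fused_mapping={}):
    # The single default dict is shared by every call that omits the argument.
    fused_mapping.setdefault("qkv_proj", []).append("q_proj")
    return fused_mapping

def safe(fused_mapping=MappingProxyType({})):
    # Read-only view: accidental writes raise instead of leaking state.
    return dict(fused_mapping)

risky()
print(risky())   # {'qkv_proj': ['q_proj', 'q_proj']} -- state leaked across calls

try:
    MappingProxyType({})["qkv_proj"] = ["q_proj"]
except TypeError as exc:
    print(exc)   # 'mappingproxy' object does not support item assignment
```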
View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
"""This file is used for /tests and /benchmarks"""
from typing import List, Optional, Tuple
from types import MappingProxyType
from typing import List, Mapping, Optional, Tuple
import numpy
import torch
@ -12,14 +13,6 @@ from vllm.scalar_type import ScalarType, scalar_types
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
# fused_name: List[shard_name]
FUSED_LAYER_NAME_MAPPING = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}
# Normalize the group_shape to the full extent for any dims that are -1
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: Tuple[int,
@ -178,14 +171,23 @@ def unpack_quantized_values_into_int32(w_q: torch.Tensor,
return res.permute(inv_perm)
def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
def is_layer_skipped(
prefix: str,
ignored_layers: List[str],
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
) -> bool:
# prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj
proj_name = prefix.split(".")[-1]
if proj_name in FUSED_LAYER_NAME_MAPPING:
# Fused layers like gate_up_proj or qkv_proj will not be fused
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
if proj_name in fused_mapping:
shard_prefixes = [
prefix.replace(proj_name, shard_proj_name)
for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
for shard_proj_name in fused_mapping[proj_name]
]
is_skipped = None

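A minimal sketch of the is_layer_skipped behaviour shown above, assuming the consistency check hinted at by the comment (every shard of a fused layer must agree on its ignore status); this is a toy reimplementation, not the vLLM function:

```python
from typing import List, Mapping

FUSED_MAPPING: Mapping[str, List[str]] = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
}

def is_layer_skipped_sketch(prefix: str, ignored_layers: List[str],
                            fused_mapping: Mapping[str, List[str]]) -> bool:
    # prefix: model.layers.0.self_attn.qkv_proj -> proj_name: qkv_proj
    proj_name = prefix.split(".")[-1]
    if proj_name not in fused_mapping:
        return prefix in ignored_layers
    # The checkpoint stores the unfused shards, so expand the fused name and
    # require every shard to agree on whether it is ignored.
    shard_prefixes = [prefix.replace(proj_name, shard)
                      for shard in fused_mapping[proj_name]]
    skipped = [p in ignored_layers for p in shard_prefixes]
    if any(skipped) != all(skipped):
        raise ValueError(f"Inconsistent ignore status for shards of {prefix}")
    return all(skipped)

ignored = ["model.layers.0.self_attn.q_proj",
           "model.layers.0.self_attn.k_proj",
           "model.layers.0.self_attn.v_proj"]
print(is_layer_skipped_sketch("model.layers.0.self_attn.qkv_proj",
                              ignored, FUSED_MAPPING))  # True
```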
View File

@ -43,6 +43,7 @@ from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
serialize_vllm_model, tensorizer_weights_iterator)
from vllm.model_executor.model_loader.utils import (ParamMapping,
configure_quant_config,
get_model_architecture,
set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import (
@ -113,6 +114,9 @@ def _initialize_model(
model_config = vllm_config.model_config
model_class, _ = get_model_architecture(model_config)
if vllm_config.quant_config is not None:
configure_quant_config(vllm_config.quant_config, model_class)
signatures = inspect.signature(model_class.__init__)
all_params = [param.name for param in signatures.parameters.values()]
if "vllm_config" in all_params and "prefix" in all_params:

View File

@ -11,6 +11,8 @@ from transformers.dynamic_module_utils import get_class_from_dynamic_module
from vllm.config import ModelConfig, ModelImpl
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.adapters import (as_classification_model,
as_embedding_model,
@ -138,3 +140,23 @@ class ParamMapping:
if module_name.endswith(key):
return key, value
return None
def configure_quant_config(quant_config: QuantizationConfig,
model_class: Type[nn.Module]):
"""
Pass packed_modules_mapping by reference to quant_config so that
quant_config can properly match fused modules.

Note that model attributes are passed by reference to quant_config,
enabling them to be updated by model_class.__new__ (e.g. chatglm, qwen).
"""
packed_mapping = getattr(model_class, "packed_modules_mapping", None)
if packed_mapping is not None:
# pass packed_modules_mapping by reference to quant_config
quant_config.packed_modules_mapping = packed_mapping
else:
logger.warning(
"The model class %s has not defined `packed_modules_mapping`, "
"this may lead to incorrect mapping of quantized or ignored "
"modules", model_class.__name__)

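A small usage sketch of the behaviour added here, using toy classes and a stand-in logger (not the vLLM loader itself): a model class that declares packed_modules_mapping hands it to the quant config by reference, while a model class that does not triggers the warning path.

```python
import logging
from typing import Dict, List, Type

logger = logging.getLogger(__name__)

class ToyQuantConfig:
    packed_modules_mapping: Dict[str, List[str]] = {}

class ToyLlama:
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

class ToyLegacyModel:
    pass  # no packed_modules_mapping defined

def configure_quant_config_sketch(quant_config: ToyQuantConfig,
                                  model_class: Type) -> None:
    packed_mapping = getattr(model_class, "packed_modules_mapping", None)
    if packed_mapping is not None:
        # Passed by reference: later in-place updates on the model class
        # (e.g. by dispatching wrappers) remain visible to the quant config.
        quant_config.packed_modules_mapping = packed_mapping
    else:
        logger.warning(
            "The model class %s has not defined `packed_modules_mapping`; "
            "fused or ignored modules may be matched incorrectly.",
            model_class.__name__)

cfg = ToyQuantConfig()
configure_quant_config_sketch(cfg, ToyLlama)
assert cfg.packed_modules_mapping is ToyLlama.packed_modules_mapping
configure_quant_config_sketch(ToyQuantConfig(), ToyLegacyModel)  # logs a warning
```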
View File

@ -265,12 +265,14 @@ class GLMAttention(nn.Module):
self.total_num_kv_heads,
bias=config.add_bias_linear or config.add_qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.query_key_value",
)
self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim,
config.hidden_size,
bias=config.add_bias_linear,
quant_config=quant_config,
prefix=f"{prefix}.dense",
)
# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
@ -327,6 +329,7 @@ class GLMMLP(nn.Module):
self,
config: ChatGLMConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
@ -338,6 +341,7 @@ class GLMMLP(nn.Module):
[config.ffn_hidden_size] * 2,
bias=config.add_bias_linear,
quant_config=quant_config,
prefix=f"{prefix}.dense_h_to_4h",
)
self.activation_func = SiluAndMul()
@ -348,6 +352,7 @@ class GLMMLP(nn.Module):
config.hidden_size,
bias=config.add_bias_linear,
quant_config=quant_config,
prefix=f"{prefix}.dense_4h_to_h",
)
def forward(self, hidden_states):
@ -396,7 +401,7 @@ class GLMBlock(nn.Module):
config.hidden_size, eps=config.layernorm_epsilon)
# MLP
self.mlp = GLMMLP(config, quant_config)
self.mlp = GLMMLP(config, quant_config, prefix=f"{prefix}.mlp")
def forward(
self,
@ -507,7 +512,8 @@ class ChatGLMModel(nn.Module):
self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
config.hidden_size,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.embedding")
self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num
@ -766,6 +772,7 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
SupportsMultiModal):
# Ensure that the LoRA support check passes when the class is not
# initialized, by setting all these attributes to empty defaults.
# These will be updated when an instance class is selected
packed_modules_mapping = {}
supported_lora_modules = []
embedding_modules = {}
@ -777,9 +784,18 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
prefix: str = "",
) -> None:
config = vllm_config.model_config.hf_config
# Initialize VL
if hasattr(config, "vision_config"):
return ChatGLMV(vllm_config=vllm_config, prefix=prefix)
if hasattr(config, "vision_config"): # noqa: SIM108
instance_cls = ChatGLMV
# Initialize LLM
else:
return ChatGLM(vllm_config=vllm_config, prefix=prefix)
instance_cls = ChatGLM
# quant_config references base class members,
# so update values before init is called
cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
cls.supported_lora_modules += instance_cls.supported_lora_modules
cls.embedding_modules.update(instance_cls.embedding_modules)
cls.embedding_padding_modules += instance_cls.embedding_padding_modules
return instance_cls(vllm_config=vllm_config, prefix=prefix)

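Why the dispatcher above mutates the class-level dict in place instead of rebinding it: configure_quant_config has already captured a reference to the wrapper's (initially empty) mapping before __init__ runs. A toy illustration with made-up class names and an illustrative mapping (the real values live on ChatGLM / ChatGLMV):

```python
class ToyQuantConfig:
    packed_modules_mapping: dict = {}

class ToyLLM:
    # Illustrative values only.
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

class ToyWrapper:
    # Empty placeholder so LoRA checks pass before a concrete class is chosen.
    packed_modules_mapping: dict = {}

    def __new__(cls, *args, **kwargs):
        instance_cls = ToyLLM
        # Mutate in place: the quant config already holds a reference to this
        # dict, so rebinding cls.packed_modules_mapping would go unnoticed.
        cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
        return instance_cls(*args, **kwargs)

quant_config = ToyQuantConfig()
# What configure_quant_config does, before the model is instantiated:
quant_config.packed_modules_mapping = ToyWrapper.packed_modules_mapping
model = ToyWrapper()
print(quant_config.packed_modules_mapping)  # {'qkv_proj': ['q_proj', 'k_proj', 'v_proj']}
print(type(model).__name__)                 # ToyLLM
```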
View File

@ -74,11 +74,13 @@ class Attention(nn.Module):
self.head_dim,
config.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.query_key_value",
)
self.dense = RowParallelLinear(
config.hidden_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.dense",
)
self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim,
@ -101,6 +103,7 @@ class MLP(nn.Module):
self,
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
):
super().__init__()
self.config = config
@ -109,11 +112,13 @@ class MLP(nn.Module):
config.hidden_size,
config.intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -137,7 +142,9 @@ class TransformerLayer(nn.Module):
self.attention = Attention(config,
quant_config=quant_config,
prefix=f"{prefix}.attention")
self.mlp = MLP(config, quant_config=quant_config)
self.mlp = MLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.post_attention_layernorm = LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
@ -164,7 +171,7 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([
TransformerLayer(config,
quant_config=quant_config,
prefix=f"{prefix}.layer.{layer_idx}")
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(config.num_hidden_layers)
])
@ -181,6 +188,7 @@ class GLU(nn.Module):
config,
in_features,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
):
"""
The original implementation is the same as:
@ -222,7 +230,8 @@ class GLU(nn.Module):
self.linear_proj = ReplicatedLinear(in_features,
config.hidden_size,
bias=False,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.linear_proj")
self.norm1 = nn.LayerNorm(config.hidden_size)
self.act1 = nn.GELU()
self.act2 = SiluAndMul()
@ -230,12 +239,15 @@ class GLU(nn.Module):
self.merged_proj = MergedColumnParallelLinear(
config.hidden_size, [config.ffn_hidden_size] * 2,
bias=False,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.merged_proj")
self.dense_4h_to_h = RowParallelLinear(config.ffn_hidden_size,
config.hidden_size,
bias=False,
quant_config=quant_config)
self.dense_4h_to_h = RowParallelLinear(
config.ffn_hidden_size,
config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.dense_4h_to_h")
def forward(self, x):
x, _ = self.linear_proj(x)
@ -262,7 +274,8 @@ class EVA2CLIPModel(nn.Module):
prefix=f"{prefix}.transformer")
self.linear_proj = GLU(config,
in_features=config.hidden_size,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.linear_proj")
self.conv = nn.Conv2d(in_channels=vision_config.hidden_size,
out_channels=config.hidden_size,
kernel_size=2,

View File

@ -1473,6 +1473,7 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA):
"""
# Ensure that the LoRA support check passes when the class is not
# initialized, by setting all these attributes to empty defaults.
# These will be updated when an instance class is selected
packed_modules_mapping = {}
supported_lora_modules = []
embedding_modules = {}
@ -1489,8 +1490,15 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA):
version = str(config.version).split(".")
version = tuple([int(x) for x in version])
# Dispatch class based on version
instance_class = _SUPPORT_VERSION.get(version)
if instance_class is None:
instance_cls = _SUPPORT_VERSION.get(version)
if instance_cls is None:
raise ValueError(
"Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
return instance_class(vllm_config=vllm_config, prefix=prefix)
# quant_config references base class members,
# so update values before init is called
cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
cls.supported_lora_modules += instance_cls.supported_lora_modules
cls.embedding_modules.update(instance_cls.embedding_modules)
cls.embedding_padding_modules += instance_cls.embedding_padding_modules
return instance_cls(vllm_config=vllm_config, prefix=prefix)

View File

@ -1135,6 +1135,7 @@ class QWenLMHeadModel(QWenBaseModel, SupportsMultiModal, SupportsLoRA):
"""
# Ensure that the LoRA support check passes when the class is not
# initialized, by setting all these attributes to empty defaults.
# These will be updated when an instance class is selected
packed_modules_mapping = {}
supported_lora_modules = []
embedding_modules = {}
@ -1146,9 +1147,18 @@ class QWenLMHeadModel(QWenBaseModel, SupportsMultiModal, SupportsLoRA):
prefix: str = "",
) -> QWenBaseModel:
config = vllm_config.model_config.hf_config
# Initialize VL
if hasattr(config, "visual"):
return QWenVL(vllm_config=vllm_config, prefix=prefix)
if hasattr(config, "visual"): # noqa: SIM108
instance_cls = QWenVL
# Initialize LLM
else:
return QWenLLM(vllm_config=vllm_config, prefix=prefix)
instance_cls = QWenLLM
# quant_config references base class members,
# so update values before init is called
cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
cls.supported_lora_modules += instance_cls.supported_lora_modules
cls.embedding_modules.update(instance_cls.embedding_modules)
cls.embedding_padding_modules += instance_cls.embedding_padding_modules
return instance_cls(vllm_config=vllm_config, prefix=prefix)