mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 15:37:10 +08:00
Updated load_weight for Siglip2VisionTransformer
Signed-off-by: Oscar Gonzalez <ogonzal6@alumni.jh.edu>
This commit is contained in:
parent
37a92d952b
commit
0dbe093c56
@ -18,9 +18,6 @@ import torch.nn.functional as F
|
|||||||
from transformers import PretrainedConfig, Qwen3Config
|
from transformers import PretrainedConfig, Qwen3Config
|
||||||
from transformers.image_processing_utils import BatchFeature
|
from transformers.image_processing_utils import BatchFeature
|
||||||
from transformers.tokenization_utils import TensorType
|
from transformers.tokenization_utils import TensorType
|
||||||
from transformers.models.siglip2.modeling_siglip2 import (
|
|
||||||
Siglip2MLP,
|
|
||||||
)
|
|
||||||
from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig
|
from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig
|
||||||
|
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
@ -30,6 +27,7 @@ from vllm.model_executor.models.utils import (
|
|||||||
AutoWeightsLoader,
|
AutoWeightsLoader,
|
||||||
_merge_multimodal_embeddings,
|
_merge_multimodal_embeddings,
|
||||||
maybe_prefix,
|
maybe_prefix,
|
||||||
|
init_vllm_registered_model,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
|
from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
|
||||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||||
@ -54,6 +52,15 @@ from vllm.model_executor.models.interfaces import (
|
|||||||
SupportsPP,
|
SupportsPP,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
|
default_weight_loader,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.models.siglip2navit import Siglip2Encoder
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||||
|
|
||||||
# ===== TensorStream Compatibility Layer for Isaac MRoPE =====
|
# ===== TensorStream Compatibility Layer for Isaac MRoPE =====
|
||||||
# Minimal implementation of TensorStream classes needed for Isaac's 3D positional encoding
|
# Minimal implementation of TensorStream classes needed for Isaac's 3D positional encoding
|
||||||
|
|
||||||
@ -316,9 +323,10 @@ class Siglip2VariableSequenceEmbeddings(nn.Module):
|
|||||||
self.embed_dim = config.hidden_size
|
self.embed_dim = config.hidden_size
|
||||||
self.patch_size = config.patch_size
|
self.patch_size = config.patch_size
|
||||||
|
|
||||||
self.patch_embedding = nn.Linear(
|
self.patch_embedding = ReplicatedLinear(
|
||||||
in_features=config.num_channels * self.patch_size * self.patch_size,
|
input_size=config.num_channels * self.patch_size * self.patch_size,
|
||||||
out_features=self.embed_dim,
|
output_size=self.embed_dim,
|
||||||
|
return_bias=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.num_patches = config.num_patches
|
self.num_patches = config.num_patches
|
||||||
@ -1058,37 +1066,10 @@ class IsaacMultiModalProcessor(BaseMultiModalProcessor):
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
from vllm.model_executor.model_loader.weight_utils import (
|
class Siglip2VisionTransformer(nn.Module):
|
||||||
default_weight_loader,
|
|
||||||
maybe_remap_kv_scale_name,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.models.utils import is_pp_missing_parameter
|
|
||||||
from vllm.model_executor.models.siglip2navit import Siglip2VisionEmbeddings, Siglip2Encoder
|
|
||||||
from vllm.attention.backends.registry import _Backend
|
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
||||||
|
|
||||||
class Siglip2VisionTransformer(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
|
|
||||||
):
|
|
||||||
|
|
||||||
is_pooling_model = True
|
|
||||||
|
|
||||||
merge_by_field_config = True
|
|
||||||
|
|
||||||
packed_modules_mapping = {
|
|
||||||
"qkv_proj": [
|
|
||||||
"q_proj",
|
|
||||||
"k_proj",
|
|
||||||
"v_proj",
|
|
||||||
],
|
|
||||||
"gate_up_proj": [
|
|
||||||
"gate_proj",
|
|
||||||
"up_proj",
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config,
|
config: PixelShuffleSiglip2VisionConfig,
|
||||||
quant_config: QuantizationConfig | None = None,
|
quant_config: QuantizationConfig | None = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
use_data_parallel: bool = False,
|
use_data_parallel: bool = False,
|
||||||
@ -1151,64 +1132,28 @@ class Siglip2VisionTransformer(nn.Module, SupportsMultiModal, SupportsLoRA, Supp
|
|||||||
("qkv_proj", "q_proj", "q"),
|
("qkv_proj", "q_proj", "q"),
|
||||||
("qkv_proj", "k_proj", "k"),
|
("qkv_proj", "k_proj", "k"),
|
||||||
("qkv_proj", "v_proj", "v"),
|
("qkv_proj", "v_proj", "v"),
|
||||||
("gate_up_proj", "gate_proj", 0),
|
|
||||||
("gate_up_proj", "up_proj", 1),
|
|
||||||
]
|
]
|
||||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
params_dict = dict(self.named_parameters())
|
||||||
loaded_params: set[str] = set()
|
loaded_params: set[str] = set()
|
||||||
|
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
if "rotary_emb.inv_freq" in name:
|
|
||||||
continue
|
|
||||||
if self.quant_config is not None and (
|
|
||||||
scale_name := self.quant_config.get_cache_scale(name)
|
|
||||||
):
|
|
||||||
# Loading kv cache quantization scales
|
|
||||||
param = params_dict[scale_name]
|
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
|
||||||
loaded_weight = (
|
|
||||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
|
||||||
)
|
|
||||||
weight_loader(param, loaded_weight)
|
|
||||||
loaded_params.add(scale_name)
|
|
||||||
continue
|
|
||||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
continue
|
continue
|
||||||
name = name.replace(weight_name, param_name)
|
name = name.replace(weight_name, param_name)
|
||||||
# Skip loading extra bias for GPTQ models.
|
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
|
||||||
continue
|
|
||||||
if is_pp_missing_parameter(name, self):
|
|
||||||
continue
|
|
||||||
if name.endswith("scale"):
|
|
||||||
# Remapping the name of FP8 kv-scale.
|
|
||||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
|
||||||
if name is None:
|
|
||||||
continue
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
weight_loader = param.weight_loader
|
||||||
if weight_loader == default_weight_loader:
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
weight_loader(param, loaded_weight)
|
|
||||||
else:
|
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
# Skip loading extra bias for GPTQ models.
|
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
|
||||||
continue
|
|
||||||
# Remapping the name of FP8 kv-scale.
|
|
||||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
|
||||||
if name is None:
|
|
||||||
continue
|
|
||||||
if is_pp_missing_parameter(name, self):
|
|
||||||
continue
|
|
||||||
print(f"qwen2: name={name}")
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
loaded_params.add(name)
|
loaded_params.add(name)
|
||||||
return loaded_params
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
@MULTIMODAL_REGISTRY.register_processor(
|
@MULTIMODAL_REGISTRY.register_processor(
|
||||||
IsaacMultiModalProcessor,
|
IsaacMultiModalProcessor,
|
||||||
info=IsaacProcessingInfo,
|
info=IsaacProcessingInfo,
|
||||||
@ -1217,6 +1162,7 @@ class Siglip2VisionTransformer(nn.Module, SupportsMultiModal, SupportsLoRA, Supp
|
|||||||
class IsaacForConditionalGeneration(
|
class IsaacForConditionalGeneration(
|
||||||
Qwen3ForCausalLM, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
|
Qwen3ForCausalLM, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
|
||||||
):
|
):
|
||||||
|
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"qkv_proj": [
|
"qkv_proj": [
|
||||||
"q_proj",
|
"q_proj",
|
||||||
@ -1230,7 +1176,7 @@ class IsaacForConditionalGeneration(
|
|||||||
}
|
}
|
||||||
|
|
||||||
supports_encoder_tp_data = True
|
supports_encoder_tp_data = True
|
||||||
|
|
||||||
# To ensure correct weight loading and mapping.
|
# To ensure correct weight loading and mapping.
|
||||||
hf_to_vllm_mapper = WeightsMapper(
|
hf_to_vllm_mapper = WeightsMapper(
|
||||||
orig_to_new_prefix={
|
orig_to_new_prefix={
|
||||||
@ -1261,14 +1207,14 @@ class IsaacForConditionalGeneration(
|
|||||||
|
|
||||||
# Initialize the parent class with updated config
|
# Initialize the parent class with updated config
|
||||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||||
|
|
||||||
# Create the language model module to match checkpoint structure
|
# Create the language model module to match checkpoint structure
|
||||||
self.language_model = nn.ModuleDict({
|
self.language_model = nn.ModuleDict({
|
||||||
"embed_tokens": self.model.embed_tokens,
|
"embed_tokens": self.model.embed_tokens,
|
||||||
"layers": self.model.layers,
|
"layers": self.model.layers,
|
||||||
"norm": self.model.norm
|
"norm": self.model.norm
|
||||||
})
|
})
|
||||||
|
|
||||||
config.vision_config.preserve_original_pe = True
|
config.vision_config.preserve_original_pe = True
|
||||||
config.vision_config.use_rope = False
|
config.vision_config.use_rope = False
|
||||||
config.vision_config.hidden_stride = config.vision_config.pixel_shuffle_scale_factor
|
config.vision_config.hidden_stride = config.vision_config.pixel_shuffle_scale_factor
|
||||||
@ -1431,61 +1377,9 @@ class IsaacForConditionalGeneration(
|
|||||||
|
|
||||||
return inputs_embeds
|
return inputs_embeds
|
||||||
|
|
||||||
def merge_qkv_weights(
|
|
||||||
weights: Iterable[tuple[str, torch.Tensor]]
|
|
||||||
) -> Iterable[tuple[str, torch.Tensor]]:
|
|
||||||
"""Merge separate Q, K, V projection weights into QKV format."""
|
|
||||||
|
|
||||||
# Buffer to collect q, k, v weights for each layer
|
|
||||||
qkv_buffer = {}
|
|
||||||
|
|
||||||
for name, tensor in weights:
|
|
||||||
# Check if this is a q/k/v projection weight
|
|
||||||
if '.q_proj.' in name or '.k_proj.' in name or '.v_proj.' in name:
|
|
||||||
# Extract the base name (everything before q/k/v_proj)
|
|
||||||
if '.q_proj.' in name:
|
|
||||||
base_name = name.replace('.q_proj.', '.qkv_proj.')
|
|
||||||
proj_type = 'q'
|
|
||||||
elif '.k_proj.' in name:
|
|
||||||
base_name = name.replace('.k_proj.', '.qkv_proj.')
|
|
||||||
proj_type = 'k'
|
|
||||||
else: # v_proj
|
|
||||||
base_name = name.replace('.v_proj.', '.qkv_proj.')
|
|
||||||
proj_type = 'v'
|
|
||||||
|
|
||||||
# Store in buffer
|
|
||||||
if base_name not in qkv_buffer:
|
|
||||||
qkv_buffer[base_name] = {}
|
|
||||||
qkv_buffer[base_name][proj_type] = tensor
|
|
||||||
|
|
||||||
# If we have all three (q, k, v), merge and yield
|
|
||||||
if len(qkv_buffer[base_name]) == 3:
|
|
||||||
q = qkv_buffer[base_name]['q']
|
|
||||||
k = qkv_buffer[base_name]['k']
|
|
||||||
v = qkv_buffer[base_name]['v']
|
|
||||||
|
|
||||||
# Concatenate along dim 0 for weight, dim agnostic for bias
|
|
||||||
merged = torch.cat([q, k, v], dim=0)
|
|
||||||
yield base_name, merged
|
|
||||||
|
|
||||||
# Clear buffer
|
|
||||||
del qkv_buffer[base_name]
|
|
||||||
else:
|
|
||||||
# Pass through non-qkv weights unchanged
|
|
||||||
yield name, tensor
|
|
||||||
|
|
||||||
# Check if any incomplete qkv sets remain
|
|
||||||
if qkv_buffer:
|
|
||||||
raise ValueError(f"Incomplete QKV weights found: {list(qkv_buffer.keys())}")
|
|
||||||
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||||
skip_prefixes = []
|
skip_prefixes = []
|
||||||
#if self.vision_embedding is None:
|
|
||||||
# skip_prefixes.extend(["vision_embedding."])
|
|
||||||
|
|
||||||
# Usage:
|
|
||||||
#weights = self.merge_qkv_weights(weights)
|
|
||||||
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
|
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
|
||||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user