mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-08 13:37:06 +08:00
Updated to use Siglip2Encoder defined in siglip2navit.py.
Signed-off-by: Oscar Gonzalez <ogonzal6@alumni.jh.edu>
This commit is contained in:
parent
e27cb3c53d
commit
37a92d952b
@ -29,6 +29,7 @@ from vllm.model_executor.models.utils import (
|
||||
WeightsMapper,
|
||||
AutoWeightsLoader,
|
||||
_merge_multimodal_embeddings,
|
||||
maybe_prefix,
|
||||
)
|
||||
from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
@ -308,14 +309,6 @@ class PixelShuffleSiglip2VisionConfig(Siglip2VisionConfig):
|
||||
self.num_patches = num_patches
|
||||
|
||||
|
||||
def create_cumulative_seq_lengths(seq_sizes: torch.Tensor, device: torch.device) -> tuple[torch.Tensor, int]:
|
||||
"""Create cumulative sequence lengths for variable-length attention."""
|
||||
cu_seqlens = torch.zeros(len(seq_sizes) + 1, dtype=torch.int32, device=device)
|
||||
cu_seqlens[1:] = seq_sizes.cumsum(0)
|
||||
max_seqlen = int(seq_sizes.max().item()) if len(seq_sizes) > 0 else 0
|
||||
return cu_seqlens, max_seqlen
|
||||
|
||||
|
||||
class Siglip2VariableSequenceEmbeddings(nn.Module):
|
||||
def __init__(self, config: PixelShuffleSiglip2VisionConfig):
|
||||
super().__init__()
|
||||
@ -380,7 +373,6 @@ class Siglip2VariableSequenceEmbeddings(nn.Module):
|
||||
pos_embeds = self.positional_embeddings(packed_seq_patches)
|
||||
|
||||
# Flatten patch embeddings to match positional embeddings format
|
||||
# From [batch, patches_per_image, embed_dim] to [total_patches, embed_dim]
|
||||
batch_size, patches_per_image, embed_dim = patch_embeds.shape
|
||||
|
||||
# For variable-length attention, we need to reshape to (total_tokens, embed_dim)
|
||||
@ -394,158 +386,6 @@ class Siglip2VariableSequenceEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
class Siglip2VariableLengthAttention(nn.Module):
|
||||
"""Custom attention that supports variable-length sequences with flash attention."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads})."
|
||||
)
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
|
||||
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
|
||||
def forward(self, hidden_states, cu_seqlens=None, max_seqlen=None):
|
||||
batch_size, seq_len, _ = hidden_states.size()
|
||||
|
||||
# For variable-length attention, we need to reshape to (total_tokens, embed_dim)
|
||||
if batch_size != 1:
|
||||
raise ValueError("Variable-length attention expects batch_size=1 for packed sequences")
|
||||
hidden_states = hidden_states.squeeze(0) # Remove batch dimension: (seq_len, embed_dim)
|
||||
|
||||
# Store original dtype
|
||||
orig_dtype = hidden_states.dtype
|
||||
|
||||
# 1. Linear projections
|
||||
Q = self.q_proj(hidden_states) # (seq_len, embed_dim)
|
||||
K = self.k_proj(hidden_states) # (seq_len, embed_dim)
|
||||
V = self.v_proj(hidden_states) # (seq_len, embed_dim)
|
||||
|
||||
# 2. Reshape for multi-head attention: (seq_len, n_heads, head_dim)
|
||||
Q = Q.view(-1, self.num_heads, self.embed_dim // self.num_heads)
|
||||
K = K.view(-1, self.num_heads, self.embed_dim // self.num_heads)
|
||||
V = V.view(-1, self.num_heads, self.embed_dim // self.num_heads)
|
||||
|
||||
# 3. Apply variable-length attention using flash attention
|
||||
attn_output, _, _, _, _ = torch.ops.aten._flash_attention_forward(
|
||||
query=Q,
|
||||
key=K,
|
||||
value=V,
|
||||
cum_seq_q=cu_seqlens,
|
||||
cum_seq_k=cu_seqlens,
|
||||
max_q=max_seqlen,
|
||||
max_k=max_seqlen,
|
||||
dropout_p=self.dropout if self.training else 0.0,
|
||||
is_causal=False,
|
||||
return_debug_mask=False,
|
||||
scale=self.scale,
|
||||
window_size_left=-1,
|
||||
window_size_right=-1,
|
||||
alibi_slopes=None,
|
||||
)
|
||||
|
||||
# 4. Reshape attention output from (seq_len, n_heads, head_dim) to (seq_len, embed_dim)
|
||||
attn_output = attn_output.reshape(seq_len, self.embed_dim)
|
||||
|
||||
# 5. Convert back to original dtype if needed
|
||||
if attn_output.dtype != orig_dtype:
|
||||
attn_output = attn_output.to(orig_dtype)
|
||||
|
||||
# 6. Project output
|
||||
attn_output = self.out_proj(attn_output) # (seq_len, embed_dim)
|
||||
|
||||
# 7. Add back batch dimension for compatibility
|
||||
attn_output = attn_output.unsqueeze(0) # (1, seq_len, embed_dim)
|
||||
|
||||
return attn_output, None
|
||||
|
||||
|
||||
class IsaacSiglip2EncoderLayer(nn.Module):
|
||||
"""Siglip2 encoder layer with variable-length attention."""
|
||||
|
||||
def __init__(self, config: PixelShuffleSiglip2VisionConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.self_attn = Siglip2VariableLengthAttention(config)
|
||||
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.mlp = Siglip2MLP(config) # Use HF's Siglip2MLP
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor = None,
|
||||
max_seqlen: int = None,
|
||||
) -> tuple[torch.FloatTensor]:
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
|
||||
hidden_states, attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return (hidden_states,)
|
||||
|
||||
|
||||
class IsaacEncoder(nn.Module):
|
||||
"""Encoder using Isaac encoder layers with variable-length attention support."""
|
||||
|
||||
def __init__(self, config: PixelShuffleSiglip2VisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layers = nn.ModuleList([IsaacSiglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: int | None = None,
|
||||
output_hidden_states: bool = False,
|
||||
):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
for encoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
return hidden_states, all_hidden_states, None
|
||||
|
||||
|
||||
def create_pixel_shuffle_index_map(
|
||||
seq_sizes: torch.Tensor,
|
||||
token_grids: torch.Tensor,
|
||||
@ -669,52 +509,6 @@ def pixel_shuffle_varlen(
|
||||
out = out.unsqueeze(0)
|
||||
return out
|
||||
|
||||
|
||||
class Siglip2SequenceVisionTransformer(nn.Module):
|
||||
def __init__(self, config: PixelShuffleSiglip2VisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embeddings = Siglip2VariableSequenceEmbeddings(config)
|
||||
self.encoder = IsaacEncoder(config)
|
||||
self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.pixel_shuffle_scale_factor = config.pixel_shuffle_scale_factor
|
||||
|
||||
def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]):
|
||||
seq_patches, token_grids = packed_seq_patches
|
||||
seq_sizes = torch.prod(token_grids, dim=-1)
|
||||
|
||||
# Get embeddings from packed sequence
|
||||
hidden_states = self.embeddings((seq_patches, seq_sizes, token_grids))
|
||||
|
||||
# Add a pseudo batch dimension for the encoder
|
||||
hidden_states = hidden_states.unsqueeze(0)
|
||||
|
||||
# Generate cumulative sequence lengths for variable-length attention
|
||||
cu_seqlens, max_seqlen = create_cumulative_seq_lengths(seq_sizes, hidden_states.device)
|
||||
|
||||
# Pass through encoder with variable-length attention parameters
|
||||
hidden_states, _, _ = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
|
||||
# Apply final layer normalization
|
||||
hidden_states = self.post_layernorm(hidden_states)
|
||||
|
||||
if self.pixel_shuffle_scale_factor > 1:
|
||||
hidden_states = pixel_shuffle_varlen(
|
||||
x=hidden_states,
|
||||
token_grids=token_grids,
|
||||
scale_factor=self.pixel_shuffle_scale_factor,
|
||||
)
|
||||
# Remove the pseudo batch dimension we added earlier
|
||||
hidden_states = hidden_states.squeeze(0)
|
||||
|
||||
# Return the full sequence of embeddings
|
||||
return hidden_states
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Configuration
|
||||
# ============================================================================
|
||||
@ -1009,7 +803,6 @@ class IsaacImageProcessorKwargs(TypedDict, total=False):
|
||||
max_num_patches: int
|
||||
min_num_patches: int
|
||||
pixel_shuffle_scale: int
|
||||
#merge_size: int # kept for parity with other processors that expose it
|
||||
|
||||
|
||||
class IsaacImageProcessor:
|
||||
@ -1265,6 +1058,156 @@ class IsaacMultiModalProcessor(BaseMultiModalProcessor):
|
||||
)
|
||||
]
|
||||
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
maybe_remap_kv_scale_name,
|
||||
)
|
||||
from vllm.model_executor.models.utils import is_pp_missing_parameter
|
||||
from vllm.model_executor.models.siglip2navit import Siglip2VisionEmbeddings, Siglip2Encoder
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
|
||||
class Siglip2VisionTransformer(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
|
||||
):
|
||||
|
||||
is_pooling_model = True
|
||||
|
||||
merge_by_field_config = True
|
||||
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
self.embeddings = Siglip2VariableSequenceEmbeddings(config)
|
||||
self.pixel_shuffle_scale_factor = config.pixel_shuffle_scale_factor
|
||||
self.encoder = Siglip2Encoder(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.encoder",
|
||||
use_data_parallel=use_data_parallel,
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
packed_seq_patches: tuple[torch.Tensor, torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
|
||||
Tensor containing the spatial dimensions (height, width)
|
||||
of the input images.
|
||||
"""
|
||||
|
||||
seq_patches, token_grids = packed_seq_patches
|
||||
seq_sizes = torch.prod(token_grids, dim=-1)
|
||||
|
||||
# Get embeddings from packed sequence
|
||||
hidden_states = self.embeddings((seq_patches, seq_sizes, token_grids))
|
||||
|
||||
grid_thws = torch.tensor([[1, token_grids[0][0].item(), token_grids[0][1].item()]])
|
||||
last_hidden_state = self.encoder(hidden_states, grid_thws)
|
||||
hidden_states = self.post_layernorm(last_hidden_state)
|
||||
|
||||
# Add a pseudo batch dimension for the encoder
|
||||
hidden_states = hidden_states.unsqueeze(0)
|
||||
|
||||
if self.pixel_shuffle_scale_factor > 1:
|
||||
hidden_states = pixel_shuffle_varlen(
|
||||
x=hidden_states,
|
||||
token_grids=token_grids,
|
||||
scale_factor=self.pixel_shuffle_scale_factor,
|
||||
)
|
||||
# Remove the pseudo batch dimension we added earlier
|
||||
hidden_states = hidden_states.squeeze(0)
|
||||
|
||||
#return last_hidden_state
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
if name.endswith("scale"):
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
if weight_loader == default_weight_loader:
|
||||
weight_loader(param, loaded_weight)
|
||||
else:
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
print(f"qwen2: name={name}")
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
IsaacMultiModalProcessor,
|
||||
@ -1274,13 +1217,24 @@ class IsaacMultiModalProcessor(BaseMultiModalProcessor):
|
||||
class IsaacForConditionalGeneration(
|
||||
Qwen3ForCausalLM, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
|
||||
):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
supports_encoder_tp_data = True
|
||||
|
||||
# To ensure correct weight loading and mapping.
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
"model.vision_embedding.": "vision_embedding.",
|
||||
"model.vision_embedding.": "vision_embedding.",
|
||||
}
|
||||
)
|
||||
|
||||
@ -1315,13 +1269,18 @@ class IsaacForConditionalGeneration(
|
||||
"norm": self.model.norm
|
||||
})
|
||||
|
||||
config.vision_config.preserve_original_pe = True
|
||||
config.vision_config.use_rope = False
|
||||
config.vision_config.hidden_stride = config.vision_config.pixel_shuffle_scale_factor
|
||||
config.vision_config.window_size = 32*2
|
||||
config.vision_config.fullatt_block_indexes = None
|
||||
vision_cfg = config.vision_config
|
||||
if vision_cfg is None:
|
||||
raise ValueError("IsaacConfig should always have vision_config")
|
||||
|
||||
hidden_dim = vision_cfg.hidden_size * (vision_cfg.pixel_shuffle_scale_factor**2)
|
||||
self.vision_embedding = nn.Sequential(
|
||||
Siglip2SequenceVisionTransformer(vision_cfg),
|
||||
Siglip2VisionTransformer(vision_cfg, prefix=maybe_prefix(prefix, "vision_embedding")),
|
||||
nn.Linear(
|
||||
hidden_dim,
|
||||
4 * hidden_dim,
|
||||
@ -1472,10 +1431,61 @@ class IsaacForConditionalGeneration(
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
def merge_qkv_weights(
|
||||
weights: Iterable[tuple[str, torch.Tensor]]
|
||||
) -> Iterable[tuple[str, torch.Tensor]]:
|
||||
"""Merge separate Q, K, V projection weights into QKV format."""
|
||||
|
||||
# Buffer to collect q, k, v weights for each layer
|
||||
qkv_buffer = {}
|
||||
|
||||
for name, tensor in weights:
|
||||
# Check if this is a q/k/v projection weight
|
||||
if '.q_proj.' in name or '.k_proj.' in name or '.v_proj.' in name:
|
||||
# Extract the base name (everything before q/k/v_proj)
|
||||
if '.q_proj.' in name:
|
||||
base_name = name.replace('.q_proj.', '.qkv_proj.')
|
||||
proj_type = 'q'
|
||||
elif '.k_proj.' in name:
|
||||
base_name = name.replace('.k_proj.', '.qkv_proj.')
|
||||
proj_type = 'k'
|
||||
else: # v_proj
|
||||
base_name = name.replace('.v_proj.', '.qkv_proj.')
|
||||
proj_type = 'v'
|
||||
|
||||
# Store in buffer
|
||||
if base_name not in qkv_buffer:
|
||||
qkv_buffer[base_name] = {}
|
||||
qkv_buffer[base_name][proj_type] = tensor
|
||||
|
||||
# If we have all three (q, k, v), merge and yield
|
||||
if len(qkv_buffer[base_name]) == 3:
|
||||
q = qkv_buffer[base_name]['q']
|
||||
k = qkv_buffer[base_name]['k']
|
||||
v = qkv_buffer[base_name]['v']
|
||||
|
||||
# Concatenate along dim 0 for weight, dim agnostic for bias
|
||||
merged = torch.cat([q, k, v], dim=0)
|
||||
yield base_name, merged
|
||||
|
||||
# Clear buffer
|
||||
del qkv_buffer[base_name]
|
||||
else:
|
||||
# Pass through non-qkv weights unchanged
|
||||
yield name, tensor
|
||||
|
||||
# Check if any incomplete qkv sets remain
|
||||
if qkv_buffer:
|
||||
raise ValueError(f"Incomplete QKV weights found: {list(qkv_buffer.keys())}")
|
||||
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
skip_prefixes = []
|
||||
if self.vision_embedding is None:
|
||||
skip_prefixes.extend(["vision_embedding."])
|
||||
#if self.vision_embedding is None:
|
||||
# skip_prefixes.extend(["vision_embedding."])
|
||||
|
||||
# Usage:
|
||||
#weights = self.merge_qkv_weights(weights)
|
||||
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user