mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-24 16:42:28 +08:00
[Feature] Enhance Isaac model with vision embedding and attention mechanisms
Signed-off-by: Yang <lymailforjob@gmail.com>
This commit is contained in:
parent
2c13695951
commit
ac8a0b936a
@ -14,15 +14,30 @@ import PIL.Image
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from transformers import PretrainedConfig, Qwen3Config
|
||||
from transformers.image_processing_utils import BatchFeature
|
||||
from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig
|
||||
from transformers.tokenization_utils import TensorType
|
||||
from typing_extensions import TypedDict, Unpack
|
||||
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.layer import (
|
||||
check_upstream_fa_availability,
|
||||
maybe_get_vit_flash_attn_backend,
|
||||
)
|
||||
from vllm.attention.ops.vit_attn_wrappers import (
|
||||
vit_xformers_attn_wrapper,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.distributed import parallel_state
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
@ -36,13 +51,14 @@ from vllm.model_executor.models.interfaces import (
|
||||
)
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
|
||||
from vllm.model_executor.models.siglip2navit import Siglip2Encoder
|
||||
from vllm.model_executor.models.siglip import SiglipMLP
|
||||
from vllm.model_executor.models.utils import (
|
||||
AutoWeightsLoader,
|
||||
WeightsMapper,
|
||||
_merge_multimodal_embeddings,
|
||||
maybe_prefix,
|
||||
)
|
||||
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalDataDict,
|
||||
@ -332,6 +348,16 @@ class PixelShuffleSiglip2VisionConfig(Siglip2VisionConfig):
|
||||
self.num_patches = num_patches
|
||||
|
||||
|
||||
def create_cumulative_seq_lengths(
|
||||
seq_sizes: torch.Tensor, device: torch.device
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
"""Create cumulative sequence lengths for variable-length attention."""
|
||||
cu_seqlens = torch.zeros(len(seq_sizes) + 1, dtype=torch.int32, device=device)
|
||||
cu_seqlens[1:] = seq_sizes.cumsum(0)
|
||||
max_seqlen = int(seq_sizes.max().item()) if len(seq_sizes) > 0 else 0
|
||||
return cu_seqlens, max_seqlen
|
||||
|
||||
|
||||
class Siglip2VariableSequenceEmbeddings(nn.Module):
|
||||
def __init__(self, config: PixelShuffleSiglip2VisionConfig):
|
||||
super().__init__()
|
||||
@ -367,7 +393,7 @@ class Siglip2VariableSequenceEmbeddings(nn.Module):
|
||||
align_corners = False
|
||||
antialias = True
|
||||
for spatial_shape in spatial_shapes:
|
||||
height, width = spatial_shape
|
||||
height, width = int(spatial_shape[0]), int(spatial_shape[1])
|
||||
# Guard to ensure height and width are positive for torch.compile
|
||||
if height > 0 and width > 0:
|
||||
resized_pos_embed = F.interpolate(
|
||||
@ -399,21 +425,16 @@ class Siglip2VariableSequenceEmbeddings(nn.Module):
|
||||
):
|
||||
seq_patches, _seq_sizes, _spatial_shapes = packed_seq_patches
|
||||
|
||||
# Apply patch embeddings
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
patch_embeds = self.patch_embedding(seq_patches.to(dtype=target_dtype))
|
||||
target_weight = self.patch_embedding.weight
|
||||
seq_patches = seq_patches.to(
|
||||
device=target_weight.device, dtype=target_weight.dtype
|
||||
)
|
||||
patch_embeds = self.patch_embedding(seq_patches)
|
||||
pos_embeds = self.positional_embeddings(packed_seq_patches)
|
||||
|
||||
# Flatten patch embeddings to match positional embeddings format
|
||||
batch_size, patches_per_image, embed_dim = patch_embeds.shape
|
||||
|
||||
# For variable-length attention, we need to reshape to (total_tokens, embed_dim)
|
||||
if batch_size != 1:
|
||||
raise ValueError(
|
||||
"Variable-length attention expects batch_size=1 for packed sequences"
|
||||
)
|
||||
|
||||
patch_embeds = patch_embeds.view(batch_size * patches_per_image, embed_dim)
|
||||
if patch_embeds.dim() == 3:
|
||||
patch_embeds = patch_embeds.view(-1, patch_embeds.size(-1))
|
||||
|
||||
# Add positional embeddings to patch embeddings
|
||||
embeddings = patch_embeds + pos_embeds
|
||||
@ -1162,6 +1183,313 @@ class IsaacMultiModalProcessor(BaseMultiModalProcessor):
|
||||
]
|
||||
|
||||
|
||||
def all_gather_interleave(local_tensor: torch.Tensor, hidden_size: int, tp_size: int):
|
||||
"""All-gather the input tensor interleavely across model parallel group."""
|
||||
import torch.distributed as dist
|
||||
|
||||
gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
|
||||
dist.all_gather(
|
||||
gathered_tensors, local_tensor, group=parallel_state.get_tp_group().device_group
|
||||
)
|
||||
|
||||
gathered_tensors_split = [
|
||||
torch.split(tensor, hidden_size // tp_size, -1) for tensor in gathered_tensors
|
||||
]
|
||||
ordered_tensors = [
|
||||
tensor for pair in zip(*gathered_tensors_split) for tensor in pair
|
||||
]
|
||||
return torch.cat(ordered_tensors, dim=-1)
|
||||
|
||||
|
||||
class Siglip2VisionAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PixelShuffleSiglip2VisionConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
use_upstream_fa: bool = False,
|
||||
attn_backend: AttentionBackendEnum | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.tp_size = (
|
||||
1
|
||||
if use_data_parallel
|
||||
else parallel_state.get_tensor_model_parallel_world_size()
|
||||
)
|
||||
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
|
||||
self.hidden_size_per_attention_head = dist_utils.divide(
|
||||
config.hidden_size, config.num_attention_heads
|
||||
)
|
||||
self.num_attention_heads_per_partition = dist_utils.divide(
|
||||
config.num_attention_heads, self.tp_size
|
||||
)
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size=config.hidden_size,
|
||||
head_size=self.hidden_size_per_attention_head,
|
||||
total_num_heads=config.num_attention_heads,
|
||||
total_num_kv_heads=config.num_attention_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
disable_tp=use_data_parallel,
|
||||
)
|
||||
self.out_proj = RowParallelLinear(
|
||||
input_size=config.hidden_size,
|
||||
output_size=config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
disable_tp=use_data_parallel,
|
||||
)
|
||||
|
||||
self.use_upstream_fa = use_upstream_fa
|
||||
self.attn_backend = attn_backend
|
||||
|
||||
if self.attn_backend not in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
} and check_upstream_fa_availability(torch.get_default_dtype()):
|
||||
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
|
||||
self.use_upstream_fa = True
|
||||
if self.attn_backend not in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.XFORMERS,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Isaac vision embedding does not support {self.attn_backend} backend."
|
||||
)
|
||||
self.attn_backend, self.flash_attn_varlen_func = (
|
||||
maybe_get_vit_flash_attn_backend(
|
||||
self.attn_backend,
|
||||
self.use_upstream_fa,
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
)
|
||||
self.is_flash_attn_backend = self.attn_backend in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}
|
||||
|
||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||
seq_len, bs, _ = qkv.shape
|
||||
if self.tp_size > 1:
|
||||
qkv = all_gather_interleave(qkv, self.qkv_proj.hidden_size, self.tp_size)
|
||||
|
||||
q, k, v = qkv.chunk(3, dim=2)
|
||||
|
||||
if self.tp_size > 1:
|
||||
q = dist_utils.split_tensor_along_last_dim(q, self.tp_size)[self.tp_rank]
|
||||
k = dist_utils.split_tensor_along_last_dim(k, self.tp_size)[self.tp_rank]
|
||||
v = dist_utils.split_tensor_along_last_dim(v, self.tp_size)[self.tp_rank]
|
||||
|
||||
new_shape = (
|
||||
seq_len,
|
||||
bs,
|
||||
self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head,
|
||||
)
|
||||
q, k, v = (x.view(*new_shape) for x in (q, k, v))
|
||||
return q, k, v
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
*,
|
||||
cu_seqlens: torch.Tensor,
|
||||
max_seqlen: torch.Tensor | None,
|
||||
seqlens: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, _, _ = hidden_states.shape
|
||||
if batch_size != 1:
|
||||
raise ValueError("packed variable-length attention expects batch_size=1")
|
||||
x = rearrange(hidden_states, "b s d -> s b d")
|
||||
x, _ = self.qkv_proj(x)
|
||||
q, k, v = self.split_qkv(x)
|
||||
q, k, v = (rearrange(t, "s b h d -> b s h d") for t in (q, k, v))
|
||||
|
||||
if self.is_flash_attn_backend:
|
||||
q, k, v = (rearrange(t, "b s ... -> (b s) ...") for t in (q, k, v))
|
||||
output = self.flash_attn_varlen_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q=cu_seqlens,
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=max_seqlen,
|
||||
max_seqlen_k=max_seqlen,
|
||||
dropout_p=0.0,
|
||||
causal=False,
|
||||
)
|
||||
context_layer = rearrange(
|
||||
output, "(b s) h d -> s b (h d)", b=batch_size
|
||||
).contiguous()
|
||||
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
||||
outputs = []
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
start_idx = cu_seqlens[i - 1]
|
||||
end_idx = cu_seqlens[i]
|
||||
q_i = q[:, start_idx:end_idx]
|
||||
k_i = k[:, start_idx:end_idx]
|
||||
v_i = v[:, start_idx:end_idx]
|
||||
q_i, k_i, v_i = (
|
||||
rearrange(tensor, "b s h d -> b h s d")
|
||||
for tensor in (q_i, k_i, v_i)
|
||||
)
|
||||
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
|
||||
output_i = rearrange(output_i, "b h s d -> b s h d")
|
||||
outputs.append(output_i)
|
||||
context_layer = torch.cat(outputs, dim=1)
|
||||
context_layer = rearrange(
|
||||
context_layer, "b s h d -> s b (h d)"
|
||||
).contiguous()
|
||||
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
|
||||
if seqlens is None:
|
||||
raise ValueError("xFormers attention backend requires seqlens tensor.")
|
||||
context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Isaac vision embedding does not support {self.attn_backend} backend."
|
||||
)
|
||||
|
||||
output, _ = self.out_proj(context_layer)
|
||||
output = rearrange(output, "s b d -> b s d")
|
||||
return output
|
||||
|
||||
|
||||
class Siglip2EncoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PixelShuffleSiglip2VisionConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
use_upstream_fa: bool = False,
|
||||
use_data_parallel: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.self_attn = Siglip2VisionAttention(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
use_data_parallel=use_data_parallel,
|
||||
use_upstream_fa=use_upstream_fa,
|
||||
attn_backend=attn_backend,
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.mlp = SiglipMLP(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
*,
|
||||
cu_seqlens: torch.Tensor,
|
||||
max_seqlen: torch.Tensor | None,
|
||||
seqlens: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
seqlens=seqlens,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Siglip2Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PixelShuffleSiglip2VisionConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
num_heads = config.num_attention_heads
|
||||
head_dim = embed_dim // num_heads
|
||||
self.attn_backend = get_vit_attn_backend(
|
||||
head_size=head_dim,
|
||||
dtype=torch.get_default_dtype(),
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
self.use_upstream_fa = False
|
||||
if self.attn_backend not in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
} and check_upstream_fa_availability(torch.get_default_dtype()):
|
||||
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
|
||||
self.use_upstream_fa = True
|
||||
if self.attn_backend not in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.XFORMERS,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Isaac vision embedding does not support {self.attn_backend} backend."
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
Siglip2EncoderLayer(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.layers.{layer_idx}",
|
||||
attn_backend=self.attn_backend,
|
||||
attn_backend_override=attn_backend_override,
|
||||
use_upstream_fa=self.use_upstream_fa,
|
||||
use_data_parallel=use_data_parallel,
|
||||
)
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
*,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None,
|
||||
seqlens: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = inputs_embeds
|
||||
for encoder_layer in self.layers:
|
||||
hidden_states = encoder_layer(
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
seqlens=seqlens,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Siglip2VisionTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -1169,7 +1497,7 @@ class Siglip2VisionTransformer(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -1187,6 +1515,19 @@ class Siglip2VisionTransformer(nn.Module):
|
||||
)
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
def compute_attn_mask_seqlen(
|
||||
self, cu_seqlens: torch.Tensor
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
|
||||
max_seqlen, seqlens = None, None
|
||||
if self.encoder.attn_backend in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
elif self.encoder.attn_backend == AttentionBackendEnum.XFORMERS:
|
||||
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
return max_seqlen, seqlens
|
||||
|
||||
def forward(
|
||||
self,
|
||||
packed_seq_patches: tuple[torch.Tensor, torch.Tensor],
|
||||
@ -1203,15 +1544,20 @@ class Siglip2VisionTransformer(nn.Module):
|
||||
# Get embeddings from packed sequence
|
||||
hidden_states = self.embeddings((seq_patches, seq_sizes, token_grids))
|
||||
|
||||
grid_thws = torch.tensor(
|
||||
[[1, token_grids[0][0].item(), token_grids[0][1].item()]]
|
||||
)
|
||||
last_hidden_state = self.encoder(hidden_states, grid_thws)
|
||||
hidden_states = self.post_layernorm(last_hidden_state)
|
||||
|
||||
# Add a pseudo batch dimension for the encoder
|
||||
hidden_states = hidden_states.unsqueeze(0)
|
||||
|
||||
cu_seqlens, _ = create_cumulative_seq_lengths(seq_sizes, hidden_states.device)
|
||||
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
|
||||
|
||||
hidden_states = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
seqlens=seqlens,
|
||||
)
|
||||
hidden_states = self.post_layernorm(hidden_states)
|
||||
|
||||
if self.pixel_shuffle_scale_factor > 1:
|
||||
hidden_states = pixel_shuffle_varlen(
|
||||
x=hidden_states,
|
||||
@ -1252,6 +1598,44 @@ class Siglip2VisionTransformer(nn.Module):
|
||||
return loaded_params
|
||||
|
||||
|
||||
class IsaacVisionEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vision_cfg: PixelShuffleSiglip2VisionConfig,
|
||||
hidden_dim: int,
|
||||
output_dim: int,
|
||||
prefix: str,
|
||||
):
|
||||
super().__init__()
|
||||
self.transformer = Siglip2VisionTransformer(
|
||||
vision_cfg, prefix=maybe_prefix(prefix, "vision_embedding")
|
||||
)
|
||||
self.linear_fc1 = ColumnParallelLinear(
|
||||
hidden_dim,
|
||||
4 * hidden_dim,
|
||||
bias=False,
|
||||
prefix=maybe_prefix(prefix, "vision_embedding.1"),
|
||||
return_bias=False,
|
||||
)
|
||||
self.act = nn.SiLU()
|
||||
self.linear_fc2 = RowParallelLinear(
|
||||
4 * hidden_dim,
|
||||
output_dim,
|
||||
bias=False,
|
||||
prefix=maybe_prefix(prefix, "vision_embedding.3"),
|
||||
return_bias=False,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(packed_seq_patches)
|
||||
hidden_states = self.linear_fc1(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
IsaacMultiModalProcessor,
|
||||
info=IsaacProcessingInfo,
|
||||
@ -1278,6 +1662,10 @@ class IsaacForConditionalGeneration(
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
"model.vision_embedding.": "vision_embedding.",
|
||||
"vision_embedding.0": "vision_embedding.transformer",
|
||||
"vision_embedding.1": "vision_embedding.linear_fc1",
|
||||
"vision_embedding.2": "vision_embedding.act",
|
||||
"vision_embedding.3": "vision_embedding.linear_fc2",
|
||||
}
|
||||
)
|
||||
|
||||
@ -1325,17 +1713,11 @@ class IsaacForConditionalGeneration(
|
||||
raise ValueError("IsaacConfig should always have vision_config")
|
||||
|
||||
hidden_dim = vision_cfg.hidden_size * (vision_cfg.pixel_shuffle_scale_factor**2)
|
||||
self.vision_embedding = nn.Sequential(
|
||||
Siglip2VisionTransformer(
|
||||
vision_cfg, prefix=maybe_prefix(prefix, "vision_embedding")
|
||||
),
|
||||
nn.Linear(
|
||||
hidden_dim,
|
||||
4 * hidden_dim,
|
||||
bias=False,
|
||||
),
|
||||
nn.SiLU(),
|
||||
nn.Linear(4 * hidden_dim, config.hidden_size, bias=False),
|
||||
self.vision_embedding = IsaacVisionEmbedding(
|
||||
vision_cfg=vision_cfg,
|
||||
hidden_dim=hidden_dim,
|
||||
output_dim=config.hidden_size,
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
def get_mrope_input_positions(
|
||||
@ -1502,6 +1884,6 @@ class IsaacForConditionalGeneration(
|
||||
"""
|
||||
return MultiModelKeys.from_string_field(
|
||||
language_model="language_model",
|
||||
connector="vision_embedding.3", # The final linear layer
|
||||
connector="vision_embedding.linear_fc2", # The final linear layer
|
||||
tower_model="vision_embedding",
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user