mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-05 02:09:07 +08:00
[Feature] Enhance Isaac model with vision embedding and attention mechanisms
Signed-off-by: Yang <lymailforjob@gmail.com>
This commit is contained in:
parent
2c13695951
commit
ac8a0b936a
@ -14,15 +14,30 @@ import PIL.Image
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
from transformers import PretrainedConfig, Qwen3Config
|
from transformers import PretrainedConfig, Qwen3Config
|
||||||
from transformers.image_processing_utils import BatchFeature
|
from transformers.image_processing_utils import BatchFeature
|
||||||
from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig
|
from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig
|
||||||
from transformers.tokenization_utils import TensorType
|
from transformers.tokenization_utils import TensorType
|
||||||
from typing_extensions import TypedDict, Unpack
|
from typing_extensions import TypedDict, Unpack
|
||||||
|
|
||||||
from vllm.attention.backends.registry import _Backend
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
|
from vllm.attention.layer import (
|
||||||
|
check_upstream_fa_availability,
|
||||||
|
maybe_get_vit_flash_attn_backend,
|
||||||
|
)
|
||||||
|
from vllm.attention.ops.vit_attn_wrappers import (
|
||||||
|
vit_xformers_attn_wrapper,
|
||||||
|
)
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
from vllm.distributed import parallel_state
|
||||||
|
from vllm.distributed import utils as dist_utils
|
||||||
|
from vllm.model_executor.layers.linear import (
|
||||||
|
ColumnParallelLinear,
|
||||||
|
QKVParallelLinear,
|
||||||
|
ReplicatedLinear,
|
||||||
|
RowParallelLinear,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.model_loader.weight_utils import (
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
default_weight_loader,
|
default_weight_loader,
|
||||||
@ -36,13 +51,14 @@ from vllm.model_executor.models.interfaces import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||||
from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
|
from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
|
||||||
from vllm.model_executor.models.siglip2navit import Siglip2Encoder
|
from vllm.model_executor.models.siglip import SiglipMLP
|
||||||
from vllm.model_executor.models.utils import (
|
from vllm.model_executor.models.utils import (
|
||||||
AutoWeightsLoader,
|
AutoWeightsLoader,
|
||||||
WeightsMapper,
|
WeightsMapper,
|
||||||
_merge_multimodal_embeddings,
|
_merge_multimodal_embeddings,
|
||||||
maybe_prefix,
|
maybe_prefix,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.inputs import (
|
from vllm.multimodal.inputs import (
|
||||||
MultiModalDataDict,
|
MultiModalDataDict,
|
||||||
@ -332,6 +348,16 @@ class PixelShuffleSiglip2VisionConfig(Siglip2VisionConfig):
|
|||||||
self.num_patches = num_patches
|
self.num_patches = num_patches
|
||||||
|
|
||||||
|
|
||||||
|
def create_cumulative_seq_lengths(
|
||||||
|
seq_sizes: torch.Tensor, device: torch.device
|
||||||
|
) -> tuple[torch.Tensor, int]:
|
||||||
|
"""Create cumulative sequence lengths for variable-length attention."""
|
||||||
|
cu_seqlens = torch.zeros(len(seq_sizes) + 1, dtype=torch.int32, device=device)
|
||||||
|
cu_seqlens[1:] = seq_sizes.cumsum(0)
|
||||||
|
max_seqlen = int(seq_sizes.max().item()) if len(seq_sizes) > 0 else 0
|
||||||
|
return cu_seqlens, max_seqlen
|
||||||
|
|
||||||
|
|
||||||
class Siglip2VariableSequenceEmbeddings(nn.Module):
|
class Siglip2VariableSequenceEmbeddings(nn.Module):
|
||||||
def __init__(self, config: PixelShuffleSiglip2VisionConfig):
|
def __init__(self, config: PixelShuffleSiglip2VisionConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -367,7 +393,7 @@ class Siglip2VariableSequenceEmbeddings(nn.Module):
|
|||||||
align_corners = False
|
align_corners = False
|
||||||
antialias = True
|
antialias = True
|
||||||
for spatial_shape in spatial_shapes:
|
for spatial_shape in spatial_shapes:
|
||||||
height, width = spatial_shape
|
height, width = int(spatial_shape[0]), int(spatial_shape[1])
|
||||||
# Guard to ensure height and width are positive for torch.compile
|
# Guard to ensure height and width are positive for torch.compile
|
||||||
if height > 0 and width > 0:
|
if height > 0 and width > 0:
|
||||||
resized_pos_embed = F.interpolate(
|
resized_pos_embed = F.interpolate(
|
||||||
@ -399,21 +425,16 @@ class Siglip2VariableSequenceEmbeddings(nn.Module):
|
|||||||
):
|
):
|
||||||
seq_patches, _seq_sizes, _spatial_shapes = packed_seq_patches
|
seq_patches, _seq_sizes, _spatial_shapes = packed_seq_patches
|
||||||
|
|
||||||
# Apply patch embeddings
|
target_weight = self.patch_embedding.weight
|
||||||
target_dtype = self.patch_embedding.weight.dtype
|
seq_patches = seq_patches.to(
|
||||||
patch_embeds = self.patch_embedding(seq_patches.to(dtype=target_dtype))
|
device=target_weight.device, dtype=target_weight.dtype
|
||||||
|
)
|
||||||
|
patch_embeds = self.patch_embedding(seq_patches)
|
||||||
pos_embeds = self.positional_embeddings(packed_seq_patches)
|
pos_embeds = self.positional_embeddings(packed_seq_patches)
|
||||||
|
|
||||||
# Flatten patch embeddings to match positional embeddings format
|
# Flatten patch embeddings to match positional embeddings format
|
||||||
batch_size, patches_per_image, embed_dim = patch_embeds.shape
|
if patch_embeds.dim() == 3:
|
||||||
|
patch_embeds = patch_embeds.view(-1, patch_embeds.size(-1))
|
||||||
# For variable-length attention, we need to reshape to (total_tokens, embed_dim)
|
|
||||||
if batch_size != 1:
|
|
||||||
raise ValueError(
|
|
||||||
"Variable-length attention expects batch_size=1 for packed sequences"
|
|
||||||
)
|
|
||||||
|
|
||||||
patch_embeds = patch_embeds.view(batch_size * patches_per_image, embed_dim)
|
|
||||||
|
|
||||||
# Add positional embeddings to patch embeddings
|
# Add positional embeddings to patch embeddings
|
||||||
embeddings = patch_embeds + pos_embeds
|
embeddings = patch_embeds + pos_embeds
|
||||||
@ -1162,6 +1183,313 @@ class IsaacMultiModalProcessor(BaseMultiModalProcessor):
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def all_gather_interleave(local_tensor: torch.Tensor, hidden_size: int, tp_size: int):
|
||||||
|
"""All-gather the input tensor interleavely across model parallel group."""
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
|
||||||
|
dist.all_gather(
|
||||||
|
gathered_tensors, local_tensor, group=parallel_state.get_tp_group().device_group
|
||||||
|
)
|
||||||
|
|
||||||
|
gathered_tensors_split = [
|
||||||
|
torch.split(tensor, hidden_size // tp_size, -1) for tensor in gathered_tensors
|
||||||
|
]
|
||||||
|
ordered_tensors = [
|
||||||
|
tensor for pair in zip(*gathered_tensors_split) for tensor in pair
|
||||||
|
]
|
||||||
|
return torch.cat(ordered_tensors, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class Siglip2VisionAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: PixelShuffleSiglip2VisionConfig,
|
||||||
|
quant_config: QuantizationConfig | None = None,
|
||||||
|
*,
|
||||||
|
prefix: str = "",
|
||||||
|
use_data_parallel: bool = False,
|
||||||
|
use_upstream_fa: bool = False,
|
||||||
|
attn_backend: AttentionBackendEnum | None = None,
|
||||||
|
attn_backend_override: AttentionBackendEnum | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.tp_size = (
|
||||||
|
1
|
||||||
|
if use_data_parallel
|
||||||
|
else parallel_state.get_tensor_model_parallel_world_size()
|
||||||
|
)
|
||||||
|
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
|
||||||
|
self.hidden_size_per_attention_head = dist_utils.divide(
|
||||||
|
config.hidden_size, config.num_attention_heads
|
||||||
|
)
|
||||||
|
self.num_attention_heads_per_partition = dist_utils.divide(
|
||||||
|
config.num_attention_heads, self.tp_size
|
||||||
|
)
|
||||||
|
|
||||||
|
self.qkv_proj = QKVParallelLinear(
|
||||||
|
hidden_size=config.hidden_size,
|
||||||
|
head_size=self.hidden_size_per_attention_head,
|
||||||
|
total_num_heads=config.num_attention_heads,
|
||||||
|
total_num_kv_heads=config.num_attention_heads,
|
||||||
|
bias=True,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.qkv_proj",
|
||||||
|
disable_tp=use_data_parallel,
|
||||||
|
)
|
||||||
|
self.out_proj = RowParallelLinear(
|
||||||
|
input_size=config.hidden_size,
|
||||||
|
output_size=config.hidden_size,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.out_proj",
|
||||||
|
disable_tp=use_data_parallel,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.use_upstream_fa = use_upstream_fa
|
||||||
|
self.attn_backend = attn_backend
|
||||||
|
|
||||||
|
if self.attn_backend not in {
|
||||||
|
AttentionBackendEnum.FLASH_ATTN,
|
||||||
|
AttentionBackendEnum.ROCM_AITER_FA,
|
||||||
|
} and check_upstream_fa_availability(torch.get_default_dtype()):
|
||||||
|
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
|
||||||
|
self.use_upstream_fa = True
|
||||||
|
if self.attn_backend not in {
|
||||||
|
AttentionBackendEnum.FLASH_ATTN,
|
||||||
|
AttentionBackendEnum.TORCH_SDPA,
|
||||||
|
AttentionBackendEnum.XFORMERS,
|
||||||
|
AttentionBackendEnum.ROCM_AITER_FA,
|
||||||
|
}:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Isaac vision embedding does not support {self.attn_backend} backend."
|
||||||
|
)
|
||||||
|
self.attn_backend, self.flash_attn_varlen_func = (
|
||||||
|
maybe_get_vit_flash_attn_backend(
|
||||||
|
self.attn_backend,
|
||||||
|
self.use_upstream_fa,
|
||||||
|
attn_backend_override=attn_backend_override,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.is_flash_attn_backend = self.attn_backend in {
|
||||||
|
AttentionBackendEnum.FLASH_ATTN,
|
||||||
|
AttentionBackendEnum.ROCM_AITER_FA,
|
||||||
|
}
|
||||||
|
|
||||||
|
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||||
|
seq_len, bs, _ = qkv.shape
|
||||||
|
if self.tp_size > 1:
|
||||||
|
qkv = all_gather_interleave(qkv, self.qkv_proj.hidden_size, self.tp_size)
|
||||||
|
|
||||||
|
q, k, v = qkv.chunk(3, dim=2)
|
||||||
|
|
||||||
|
if self.tp_size > 1:
|
||||||
|
q = dist_utils.split_tensor_along_last_dim(q, self.tp_size)[self.tp_rank]
|
||||||
|
k = dist_utils.split_tensor_along_last_dim(k, self.tp_size)[self.tp_rank]
|
||||||
|
v = dist_utils.split_tensor_along_last_dim(v, self.tp_size)[self.tp_rank]
|
||||||
|
|
||||||
|
new_shape = (
|
||||||
|
seq_len,
|
||||||
|
bs,
|
||||||
|
self.num_attention_heads_per_partition,
|
||||||
|
self.hidden_size_per_attention_head,
|
||||||
|
)
|
||||||
|
q, k, v = (x.view(*new_shape) for x in (q, k, v))
|
||||||
|
return q, k, v
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
*,
|
||||||
|
cu_seqlens: torch.Tensor,
|
||||||
|
max_seqlen: torch.Tensor | None,
|
||||||
|
seqlens: torch.Tensor | None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
batch_size, _, _ = hidden_states.shape
|
||||||
|
if batch_size != 1:
|
||||||
|
raise ValueError("packed variable-length attention expects batch_size=1")
|
||||||
|
x = rearrange(hidden_states, "b s d -> s b d")
|
||||||
|
x, _ = self.qkv_proj(x)
|
||||||
|
q, k, v = self.split_qkv(x)
|
||||||
|
q, k, v = (rearrange(t, "s b h d -> b s h d") for t in (q, k, v))
|
||||||
|
|
||||||
|
if self.is_flash_attn_backend:
|
||||||
|
q, k, v = (rearrange(t, "b s ... -> (b s) ...") for t in (q, k, v))
|
||||||
|
output = self.flash_attn_varlen_func(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
cu_seqlens_q=cu_seqlens,
|
||||||
|
cu_seqlens_k=cu_seqlens,
|
||||||
|
max_seqlen_q=max_seqlen,
|
||||||
|
max_seqlen_k=max_seqlen,
|
||||||
|
dropout_p=0.0,
|
||||||
|
causal=False,
|
||||||
|
)
|
||||||
|
context_layer = rearrange(
|
||||||
|
output, "(b s) h d -> s b (h d)", b=batch_size
|
||||||
|
).contiguous()
|
||||||
|
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
||||||
|
outputs = []
|
||||||
|
for i in range(1, len(cu_seqlens)):
|
||||||
|
start_idx = cu_seqlens[i - 1]
|
||||||
|
end_idx = cu_seqlens[i]
|
||||||
|
q_i = q[:, start_idx:end_idx]
|
||||||
|
k_i = k[:, start_idx:end_idx]
|
||||||
|
v_i = v[:, start_idx:end_idx]
|
||||||
|
q_i, k_i, v_i = (
|
||||||
|
rearrange(tensor, "b s h d -> b h s d")
|
||||||
|
for tensor in (q_i, k_i, v_i)
|
||||||
|
)
|
||||||
|
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
|
||||||
|
output_i = rearrange(output_i, "b h s d -> b s h d")
|
||||||
|
outputs.append(output_i)
|
||||||
|
context_layer = torch.cat(outputs, dim=1)
|
||||||
|
context_layer = rearrange(
|
||||||
|
context_layer, "b s h d -> s b (h d)"
|
||||||
|
).contiguous()
|
||||||
|
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
|
||||||
|
if seqlens is None:
|
||||||
|
raise ValueError("xFormers attention backend requires seqlens tensor.")
|
||||||
|
context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Isaac vision embedding does not support {self.attn_backend} backend."
|
||||||
|
)
|
||||||
|
|
||||||
|
output, _ = self.out_proj(context_layer)
|
||||||
|
output = rearrange(output, "s b d -> b s d")
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class Siglip2EncoderLayer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: PixelShuffleSiglip2VisionConfig,
|
||||||
|
quant_config: QuantizationConfig | None = None,
|
||||||
|
*,
|
||||||
|
prefix: str = "",
|
||||||
|
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
|
||||||
|
attn_backend_override: AttentionBackendEnum | None = None,
|
||||||
|
use_upstream_fa: bool = False,
|
||||||
|
use_data_parallel: bool = False,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.embed_dim = config.hidden_size
|
||||||
|
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||||
|
self.self_attn = Siglip2VisionAttention(
|
||||||
|
config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
|
use_data_parallel=use_data_parallel,
|
||||||
|
use_upstream_fa=use_upstream_fa,
|
||||||
|
attn_backend=attn_backend,
|
||||||
|
attn_backend_override=attn_backend_override,
|
||||||
|
)
|
||||||
|
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||||
|
self.mlp = SiglipMLP(
|
||||||
|
config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.mlp",
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
*,
|
||||||
|
cu_seqlens: torch.Tensor,
|
||||||
|
max_seqlen: torch.Tensor | None,
|
||||||
|
seqlens: torch.Tensor | None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.layer_norm1(hidden_states)
|
||||||
|
hidden_states = self.self_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
seqlens=seqlens,
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.layer_norm2(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Siglip2Encoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: PixelShuffleSiglip2VisionConfig,
|
||||||
|
quant_config: QuantizationConfig | None = None,
|
||||||
|
*,
|
||||||
|
prefix: str = "",
|
||||||
|
use_data_parallel: bool = False,
|
||||||
|
attn_backend_override: AttentionBackendEnum | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
embed_dim = config.hidden_size
|
||||||
|
num_heads = config.num_attention_heads
|
||||||
|
head_dim = embed_dim // num_heads
|
||||||
|
self.attn_backend = get_vit_attn_backend(
|
||||||
|
head_size=head_dim,
|
||||||
|
dtype=torch.get_default_dtype(),
|
||||||
|
attn_backend_override=attn_backend_override,
|
||||||
|
)
|
||||||
|
self.use_upstream_fa = False
|
||||||
|
if self.attn_backend not in {
|
||||||
|
AttentionBackendEnum.FLASH_ATTN,
|
||||||
|
AttentionBackendEnum.ROCM_AITER_FA,
|
||||||
|
} and check_upstream_fa_availability(torch.get_default_dtype()):
|
||||||
|
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
|
||||||
|
self.use_upstream_fa = True
|
||||||
|
if self.attn_backend not in {
|
||||||
|
AttentionBackendEnum.FLASH_ATTN,
|
||||||
|
AttentionBackendEnum.TORCH_SDPA,
|
||||||
|
AttentionBackendEnum.XFORMERS,
|
||||||
|
AttentionBackendEnum.ROCM_AITER_FA,
|
||||||
|
}:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Isaac vision embedding does not support {self.attn_backend} backend."
|
||||||
|
)
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
Siglip2EncoderLayer(
|
||||||
|
config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.layers.{layer_idx}",
|
||||||
|
attn_backend=self.attn_backend,
|
||||||
|
attn_backend_override=attn_backend_override,
|
||||||
|
use_upstream_fa=self.use_upstream_fa,
|
||||||
|
use_data_parallel=use_data_parallel,
|
||||||
|
)
|
||||||
|
for layer_idx in range(config.num_hidden_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
inputs_embeds: torch.Tensor,
|
||||||
|
*,
|
||||||
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
|
max_seqlen: torch.Tensor | None = None,
|
||||||
|
seqlens: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
for encoder_layer in self.layers:
|
||||||
|
hidden_states = encoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
seqlens=seqlens,
|
||||||
|
)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class Siglip2VisionTransformer(nn.Module):
|
class Siglip2VisionTransformer(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -1169,7 +1497,7 @@ class Siglip2VisionTransformer(nn.Module):
|
|||||||
quant_config: QuantizationConfig | None = None,
|
quant_config: QuantizationConfig | None = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
use_data_parallel: bool = False,
|
use_data_parallel: bool = False,
|
||||||
attn_backend_override: _Backend | None = None,
|
attn_backend_override: AttentionBackendEnum | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
@ -1187,6 +1515,19 @@ class Siglip2VisionTransformer(nn.Module):
|
|||||||
)
|
)
|
||||||
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
|
def compute_attn_mask_seqlen(
|
||||||
|
self, cu_seqlens: torch.Tensor
|
||||||
|
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
|
||||||
|
max_seqlen, seqlens = None, None
|
||||||
|
if self.encoder.attn_backend in {
|
||||||
|
AttentionBackendEnum.FLASH_ATTN,
|
||||||
|
AttentionBackendEnum.ROCM_AITER_FA,
|
||||||
|
}:
|
||||||
|
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||||
|
elif self.encoder.attn_backend == AttentionBackendEnum.XFORMERS:
|
||||||
|
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||||
|
return max_seqlen, seqlens
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
packed_seq_patches: tuple[torch.Tensor, torch.Tensor],
|
packed_seq_patches: tuple[torch.Tensor, torch.Tensor],
|
||||||
@ -1203,15 +1544,20 @@ class Siglip2VisionTransformer(nn.Module):
|
|||||||
# Get embeddings from packed sequence
|
# Get embeddings from packed sequence
|
||||||
hidden_states = self.embeddings((seq_patches, seq_sizes, token_grids))
|
hidden_states = self.embeddings((seq_patches, seq_sizes, token_grids))
|
||||||
|
|
||||||
grid_thws = torch.tensor(
|
|
||||||
[[1, token_grids[0][0].item(), token_grids[0][1].item()]]
|
|
||||||
)
|
|
||||||
last_hidden_state = self.encoder(hidden_states, grid_thws)
|
|
||||||
hidden_states = self.post_layernorm(last_hidden_state)
|
|
||||||
|
|
||||||
# Add a pseudo batch dimension for the encoder
|
# Add a pseudo batch dimension for the encoder
|
||||||
hidden_states = hidden_states.unsqueeze(0)
|
hidden_states = hidden_states.unsqueeze(0)
|
||||||
|
|
||||||
|
cu_seqlens, _ = create_cumulative_seq_lengths(seq_sizes, hidden_states.device)
|
||||||
|
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
|
||||||
|
|
||||||
|
hidden_states = self.encoder(
|
||||||
|
inputs_embeds=hidden_states,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
seqlens=seqlens,
|
||||||
|
)
|
||||||
|
hidden_states = self.post_layernorm(hidden_states)
|
||||||
|
|
||||||
if self.pixel_shuffle_scale_factor > 1:
|
if self.pixel_shuffle_scale_factor > 1:
|
||||||
hidden_states = pixel_shuffle_varlen(
|
hidden_states = pixel_shuffle_varlen(
|
||||||
x=hidden_states,
|
x=hidden_states,
|
||||||
@ -1252,6 +1598,44 @@ class Siglip2VisionTransformer(nn.Module):
|
|||||||
return loaded_params
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
|
class IsaacVisionEmbedding(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vision_cfg: PixelShuffleSiglip2VisionConfig,
|
||||||
|
hidden_dim: int,
|
||||||
|
output_dim: int,
|
||||||
|
prefix: str,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.transformer = Siglip2VisionTransformer(
|
||||||
|
vision_cfg, prefix=maybe_prefix(prefix, "vision_embedding")
|
||||||
|
)
|
||||||
|
self.linear_fc1 = ColumnParallelLinear(
|
||||||
|
hidden_dim,
|
||||||
|
4 * hidden_dim,
|
||||||
|
bias=False,
|
||||||
|
prefix=maybe_prefix(prefix, "vision_embedding.1"),
|
||||||
|
return_bias=False,
|
||||||
|
)
|
||||||
|
self.act = nn.SiLU()
|
||||||
|
self.linear_fc2 = RowParallelLinear(
|
||||||
|
4 * hidden_dim,
|
||||||
|
output_dim,
|
||||||
|
bias=False,
|
||||||
|
prefix=maybe_prefix(prefix, "vision_embedding.3"),
|
||||||
|
return_bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.transformer(packed_seq_patches)
|
||||||
|
hidden_states = self.linear_fc1(hidden_states)
|
||||||
|
hidden_states = self.act(hidden_states)
|
||||||
|
hidden_states = self.linear_fc2(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@MULTIMODAL_REGISTRY.register_processor(
|
@MULTIMODAL_REGISTRY.register_processor(
|
||||||
IsaacMultiModalProcessor,
|
IsaacMultiModalProcessor,
|
||||||
info=IsaacProcessingInfo,
|
info=IsaacProcessingInfo,
|
||||||
@ -1278,6 +1662,10 @@ class IsaacForConditionalGeneration(
|
|||||||
hf_to_vllm_mapper = WeightsMapper(
|
hf_to_vllm_mapper = WeightsMapper(
|
||||||
orig_to_new_prefix={
|
orig_to_new_prefix={
|
||||||
"model.vision_embedding.": "vision_embedding.",
|
"model.vision_embedding.": "vision_embedding.",
|
||||||
|
"vision_embedding.0": "vision_embedding.transformer",
|
||||||
|
"vision_embedding.1": "vision_embedding.linear_fc1",
|
||||||
|
"vision_embedding.2": "vision_embedding.act",
|
||||||
|
"vision_embedding.3": "vision_embedding.linear_fc2",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1325,17 +1713,11 @@ class IsaacForConditionalGeneration(
|
|||||||
raise ValueError("IsaacConfig should always have vision_config")
|
raise ValueError("IsaacConfig should always have vision_config")
|
||||||
|
|
||||||
hidden_dim = vision_cfg.hidden_size * (vision_cfg.pixel_shuffle_scale_factor**2)
|
hidden_dim = vision_cfg.hidden_size * (vision_cfg.pixel_shuffle_scale_factor**2)
|
||||||
self.vision_embedding = nn.Sequential(
|
self.vision_embedding = IsaacVisionEmbedding(
|
||||||
Siglip2VisionTransformer(
|
vision_cfg=vision_cfg,
|
||||||
vision_cfg, prefix=maybe_prefix(prefix, "vision_embedding")
|
hidden_dim=hidden_dim,
|
||||||
),
|
output_dim=config.hidden_size,
|
||||||
nn.Linear(
|
prefix=prefix,
|
||||||
hidden_dim,
|
|
||||||
4 * hidden_dim,
|
|
||||||
bias=False,
|
|
||||||
),
|
|
||||||
nn.SiLU(),
|
|
||||||
nn.Linear(4 * hidden_dim, config.hidden_size, bias=False),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_mrope_input_positions(
|
def get_mrope_input_positions(
|
||||||
@ -1502,6 +1884,6 @@ class IsaacForConditionalGeneration(
|
|||||||
"""
|
"""
|
||||||
return MultiModelKeys.from_string_field(
|
return MultiModelKeys.from_string_field(
|
||||||
language_model="language_model",
|
language_model="language_model",
|
||||||
connector="vision_embedding.3", # The final linear layer
|
connector="vision_embedding.linear_fc2", # The final linear layer
|
||||||
tower_model="vision_embedding",
|
tower_model="vision_embedding",
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user