mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-03 06:47:03 +08:00
fix cr
Signed-off-by: Yang <lymailforjob@gmail.com>
This commit is contained in:
parent
4cdd788dd0
commit
c7c3853e9e
@ -18,11 +18,8 @@ from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfi
|
||||
from transformers.tokenization_utils import TensorType
|
||||
from typing_extensions import TypedDict, Unpack
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.layer import (
|
||||
maybe_get_vit_flash_attn_backend,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||
from vllm.config import MultiModalConfig, VllmConfig
|
||||
from vllm.config.model import ModelConfig
|
||||
from vllm.distributed import parallel_state
|
||||
from vllm.distributed import utils as dist_utils
|
||||
@ -51,7 +48,6 @@ from vllm.model_executor.models.utils import (
|
||||
init_vllm_registered_model,
|
||||
maybe_prefix,
|
||||
)
|
||||
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalDataDict,
|
||||
@ -97,11 +93,15 @@ class PixelShuffleSiglip2VisionConfig(Siglip2VisionConfig):
|
||||
|
||||
def create_cumulative_seq_lengths(
    seq_sizes: torch.Tensor, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
    """Create cumulative sequence lengths for variable-length attention.

    Args:
        seq_sizes: Per-sequence lengths; indexed along dim 0 (assumed to be a
            1-D integer tensor -- TODO confirm against callers).
        device: Device on which ``cu_seqlens`` (and the empty-input fallback
            for ``max_seqlen``) is allocated.

    Returns:
        A ``(cu_seqlens, max_seqlen)`` pair. ``cu_seqlens`` is an int32 tensor
        with ``len(seq_sizes) + 1`` entries whose first entry is 0 and whose
        remaining entries are the running sum of ``seq_sizes``. ``max_seqlen``
        is a 0-dim tensor holding the largest entry of ``seq_sizes``, or an
        int32 zero tensor when ``seq_sizes`` is empty.
    """
    cu_seqlens = torch.zeros(len(seq_sizes) + 1, dtype=torch.int32, device=device)
    cu_seqlens[1:] = seq_sizes.cumsum(0)
    # Keep the maximum as a tensor rather than calling ``.item()``; ``.max()``
    # raises on an empty tensor, hence the explicit fallback.
    max_seqlen = (
        seq_sizes.max()
        if len(seq_sizes) > 0
        else torch.tensor(0, dtype=torch.int32, device=device)
    )
    return cu_seqlens, max_seqlen
|
||||
|
||||
|
||||
@ -763,9 +763,6 @@ class IsaacImageProcessor:
|
||||
class IsaacProcessor:
|
||||
"""Processor wrapper (tokenizer + IsaacImageProcessor)."""
|
||||
|
||||
attributes = ["tokenizer"]
|
||||
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
|
||||
|
||||
def __init__(self, image_processor=None, tokenizer=None, **kwargs):
|
||||
self.image_token = kwargs.pop("image_token", "<image>")
|
||||
self.image_processor = image_processor or IsaacImageProcessor(kwargs)
|
||||
@ -963,24 +960,6 @@ class IsaacMultiModalProcessor(BaseMultiModalProcessor):
|
||||
]
|
||||
|
||||
|
||||
def all_gather_interleave(
    local_tensor: torch.Tensor, hidden_size: int, tp_size: int
) -> torch.Tensor:
    """All-gather the input tensor interleavely across model parallel group.

    Each rank contributes ``local_tensor``; the gathered tensors are split
    along the last dimension into chunks of ``hidden_size // tp_size`` and
    re-concatenated so that the i-th chunk of every rank is grouped together
    (rank0-chunk0, rank1-chunk0, ..., rank0-chunk1, rank1-chunk1, ...).

    Args:
        local_tensor: This rank's tensor; every rank is expected to pass a
            tensor of the same shape (required by ``all_gather``).
        hidden_size: Full hidden size used to derive the chunk width.
        tp_size: Number of tensors to gather (presumably the tensor-parallel
            world size -- verify against caller).

    Returns:
        The interleaved concatenation of all gathered chunks along the last
        dimension.
    """
    import torch.distributed as dist

    gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
    dist.all_gather(
        gathered_tensors, local_tensor, group=parallel_state.get_tp_group().device_group
    )

    gathered_tensors_split = [
        torch.split(tensor, hidden_size // tp_size, -1) for tensor in gathered_tensors
    ]
    # zip(*...) walks the chunk index first and the rank second, which yields
    # the interleaved ordering.
    ordered_tensors = [
        tensor for pair in zip(*gathered_tensors_split) for tensor in pair
    ]
    return torch.cat(ordered_tensors, dim=-1)
|
||||
|
||||
|
||||
class Siglip2VisionAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -988,12 +967,15 @@ class Siglip2VisionAttention(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend: AttentionBackendEnum | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
use_data_parallel = (
|
||||
multimodal_config.mm_encoder_tp_mode == "data"
|
||||
if multimodal_config
|
||||
else False
|
||||
)
|
||||
self.tp_size = (
|
||||
1
|
||||
if use_data_parallel
|
||||
@ -1025,26 +1007,12 @@ class Siglip2VisionAttention(nn.Module):
|
||||
disable_tp=use_data_parallel,
|
||||
)
|
||||
|
||||
self.attn_backend = attn_backend
|
||||
|
||||
if self.attn_backend not in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Isaac vision embedding does not support {self.attn_backend} backend."
|
||||
)
|
||||
self.attn_backend, self.flash_attn_varlen_func = (
|
||||
maybe_get_vit_flash_attn_backend(
|
||||
self.attn_backend,
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
self.attn = MMEncoderAttention(
|
||||
num_heads=self.num_attention_heads_per_partition,
|
||||
head_size=self.hidden_size_per_attention_head,
|
||||
prefix=f"{prefix}.attn",
|
||||
multimodal_config=multimodal_config,
|
||||
)
|
||||
self.is_flash_attn_backend = self.attn_backend in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}
|
||||
|
||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||
seq_len, bs, _ = qkv.shape
|
||||
@ -1064,7 +1032,6 @@ class Siglip2VisionAttention(nn.Module):
|
||||
*,
|
||||
cu_seqlens: torch.Tensor,
|
||||
max_seqlen: torch.Tensor | None,
|
||||
seqlens: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, _, _ = hidden_states.shape
|
||||
if batch_size != 1:
|
||||
@ -1074,45 +1041,14 @@ class Siglip2VisionAttention(nn.Module):
|
||||
q, k, v = self.split_qkv(x)
|
||||
q, k, v = (rearrange(t, "s b h d -> b s h d") for t in (q, k, v))
|
||||
|
||||
if self.is_flash_attn_backend:
|
||||
q, k, v = (rearrange(t, "b s ... -> (b s) ...") for t in (q, k, v))
|
||||
output = self.flash_attn_varlen_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q=cu_seqlens,
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=max_seqlen,
|
||||
max_seqlen_k=max_seqlen,
|
||||
dropout_p=0.0,
|
||||
causal=False,
|
||||
)
|
||||
context_layer = rearrange(
|
||||
output, "(b s) h d -> s b (h d)", b=batch_size
|
||||
).contiguous()
|
||||
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
||||
outputs = []
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
start_idx = cu_seqlens[i - 1]
|
||||
end_idx = cu_seqlens[i]
|
||||
q_i = q[:, start_idx:end_idx]
|
||||
k_i = k[:, start_idx:end_idx]
|
||||
v_i = v[:, start_idx:end_idx]
|
||||
q_i, k_i, v_i = (
|
||||
rearrange(tensor, "b s h d -> b h s d")
|
||||
for tensor in (q_i, k_i, v_i)
|
||||
)
|
||||
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
|
||||
output_i = rearrange(output_i, "b h s d -> b s h d")
|
||||
outputs.append(output_i)
|
||||
context_layer = torch.cat(outputs, dim=1)
|
||||
context_layer = rearrange(
|
||||
context_layer, "b s h d -> s b (h d)"
|
||||
).contiguous()
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Isaac vision embedding does not support {self.attn_backend} backend."
|
||||
)
|
||||
context_layer = self.attn(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
|
||||
|
||||
output, _ = self.out_proj(context_layer)
|
||||
output = rearrange(output, "s b d -> b s d")
|
||||
@ -1126,9 +1062,7 @@ class Siglip2EncoderLayer(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
use_data_parallel: bool = False,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
@ -1137,9 +1071,7 @@ class Siglip2EncoderLayer(nn.Module):
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
use_data_parallel=use_data_parallel,
|
||||
attn_backend=attn_backend,
|
||||
attn_backend_override=attn_backend_override,
|
||||
multimodal_config=multimodal_config,
|
||||
)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.mlp = SiglipMLP(
|
||||
@ -1154,7 +1086,6 @@ class Siglip2EncoderLayer(nn.Module):
|
||||
*,
|
||||
cu_seqlens: torch.Tensor,
|
||||
max_seqlen: torch.Tensor | None,
|
||||
seqlens: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
|
||||
@ -1163,7 +1094,6 @@ class Siglip2EncoderLayer(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
seqlens=seqlens,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
@ -1182,36 +1112,17 @@ class Siglip2Encoder(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
num_heads = config.num_attention_heads
|
||||
head_dim = embed_dim // num_heads
|
||||
self.attn_backend = get_vit_attn_backend(
|
||||
head_size=head_dim,
|
||||
dtype=torch.get_default_dtype(),
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
if self.attn_backend not in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Isaac vision embedding does not support {self.attn_backend} backend."
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
Siglip2EncoderLayer(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.layers.{layer_idx}",
|
||||
attn_backend=self.attn_backend,
|
||||
attn_backend_override=attn_backend_override,
|
||||
use_data_parallel=use_data_parallel,
|
||||
multimodal_config=multimodal_config,
|
||||
)
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
]
|
||||
@ -1223,7 +1134,6 @@ class Siglip2Encoder(nn.Module):
|
||||
*,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None,
|
||||
seqlens: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = inputs_embeds
|
||||
for encoder_layer in self.layers:
|
||||
@ -1231,7 +1141,6 @@ class Siglip2Encoder(nn.Module):
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
seqlens=seqlens,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
@ -1242,8 +1151,7 @@ class Siglip2VisionTransformer(nn.Module):
|
||||
config: PixelShuffleSiglip2VisionConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -1256,22 +1164,10 @@ class Siglip2VisionTransformer(nn.Module):
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.encoder",
|
||||
use_data_parallel=use_data_parallel,
|
||||
attn_backend_override=attn_backend_override,
|
||||
multimodal_config=multimodal_config,
|
||||
)
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
def compute_attn_mask_seqlen(
    self, cu_seqlens: torch.Tensor
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
    """Derive the attention length metadata for the encoder's backend.

    Args:
        cu_seqlens: Cumulative sequence lengths; consecutive differences are
            taken as the individual sequence lengths.

    Returns:
        A ``(max_seqlen, seqlens)`` pair. ``max_seqlen`` is the largest
        consecutive difference of ``cu_seqlens`` when the encoder backend is
        ``FLASH_ATTN`` or ``ROCM_AITER_FA`` and ``None`` otherwise.
        ``seqlens`` is always ``None`` here.
    """
    max_seqlen, seqlens = None, None
    if self.encoder.attn_backend in {
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.ROCM_AITER_FA,
    }:
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
    return max_seqlen, seqlens
|
||||
|
||||
def forward(
|
||||
self,
|
||||
packed_seq_patches: tuple[torch.Tensor, torch.Tensor],
|
||||
@ -1291,14 +1187,14 @@ class Siglip2VisionTransformer(nn.Module):
|
||||
# Add a pseudo batch dimension for the encoder
|
||||
hidden_states = hidden_states.unsqueeze(0)
|
||||
|
||||
cu_seqlens, _ = create_cumulative_seq_lengths(seq_sizes, hidden_states.device)
|
||||
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
|
||||
cu_seqlens, max_seqlen = create_cumulative_seq_lengths(
|
||||
seq_sizes, hidden_states.device
|
||||
)
|
||||
|
||||
hidden_states = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
seqlens=seqlens,
|
||||
)
|
||||
hidden_states = self.post_layernorm(hidden_states)
|
||||
|
||||
@ -1349,18 +1245,22 @@ class IsaacVisionEmbedding(nn.Module):
|
||||
hidden_dim: int,
|
||||
output_dim: int,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.transformer = Siglip2VisionTransformer(
|
||||
vision_cfg, prefix=maybe_prefix(prefix, "vision_embedding")
|
||||
vision_cfg,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "0"),
|
||||
multimodal_config=multimodal_config,
|
||||
)
|
||||
self.linear_fc1 = ColumnParallelLinear(
|
||||
hidden_dim,
|
||||
4 * hidden_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "vision_embedding.1"),
|
||||
prefix=maybe_prefix(prefix, "1"),
|
||||
return_bias=False,
|
||||
)
|
||||
self.act = nn.SiLU()
|
||||
@ -1369,7 +1269,7 @@ class IsaacVisionEmbedding(nn.Module):
|
||||
output_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "vision_embedding.3"),
|
||||
prefix=maybe_prefix(prefix, "3"),
|
||||
return_bias=False,
|
||||
)
|
||||
|
||||
@ -1457,11 +1357,6 @@ class IsaacForConditionalGeneration(
|
||||
vision_cfg = config.vision_config
|
||||
if vision_cfg is None:
|
||||
raise ValueError("IsaacConfig should always have vision_config")
|
||||
vision_cfg.preserve_original_pe = True
|
||||
vision_cfg.use_rope = False
|
||||
vision_cfg.hidden_stride = vision_cfg.pixel_shuffle_scale_factor
|
||||
vision_cfg.window_size = 32 * 2
|
||||
vision_cfg.fullatt_block_indexes = None
|
||||
attn_impl = (
|
||||
config.vision_attn_implementation
|
||||
if config.vision_attn_implementation is not None
|
||||
@ -1476,6 +1371,7 @@ class IsaacForConditionalGeneration(
|
||||
hidden_dim=hidden_dim,
|
||||
output_dim=config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=self.multimodal_config,
|
||||
prefix=maybe_prefix(prefix, "vision_embedding"),
|
||||
)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user