diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 6b0ceaf21951..b60fefdda279 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -538,7 +538,7 @@ Specified using `--task generate`.
 | `MllamaForConditionalGeneration` | Llama 3.2 | T + I+ | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | |
 | `MolmoForCausalLM` | Molmo | T + I+ | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `NVLM_D_Model` | NVLM-D 1.0 | T + I+ | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ |
-| `Ovis` | Ovis2, Ovis1.6 | T + I+ | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | | ✅︎ |
+| `Ovis` | Ovis2, Ovis1.6 | T + I+ | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ |
 | `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + IE | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ |
 | `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + IE+ | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ |
 | `Phi4MMForCausalLM` | Phi-4-multimodal | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py
index 5346d67b10d1..e6410ab068d2 100644
--- a/tests/distributed/test_pipeline_parallel.py
+++ b/tests/distributed/test_pipeline_parallel.py
@@ -227,6 +227,7 @@ MULTIMODAL_MODELS = {
     "llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(),
     "openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(),
     "allenai/Molmo-7B-D-0924": PPTestSettings.fast(),
+    "AIDC-AI/Ovis2-1B": PPTestSettings.fast(),
     "microsoft/Phi-3.5-vision-instruct": PPTestSettings.fast(),
     "mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"),
     "Qwen/Qwen-VL-Chat": PPTestSettings.fast(),
diff --git a/vllm/model_executor/models/aimv2.py b/vllm/model_executor/models/aimv2.py
index aefd6c973755..2e2a18abd03d 100644
--- a/vllm/model_executor/models/aimv2.py
+++ b/vllm/model_executor/models/aimv2.py
@@ -2,16 +2,23 @@
 # A modified implementation of the AIMv2 Transformer
 # inserted here also the image tokenizer used by Ovis2
+from collections.abc import Iterable
 from typing import Optional
 
 import torch
 import torch.nn as nn
-from torch.nn import functional as F
 
+from vllm.attention.layer import MultiHeadAttention
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.utils import divide
+from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs.ovis import AIMv2Config
 
 
@@ -24,29 +31,27 @@ class AIMv2SwiGLUFFN(nn.Module):
         in_features = config.hidden_size
         bias = config.use_bias
 
-        # TODO(Isotr0py): investigate if we can add TP to visual tokenizer
-        self.fc1 = ReplicatedLinear(in_features,
-                                    hidden_features,
-                                    bias=bias,
-                                    quant_config=quant_config,
-                                    prefix=f"{prefix}.fc1")
-        self.fc2 = ReplicatedLinear(hidden_features,
-                                    in_features,
-                                    bias=bias,
-                                    quant_config=quant_config,
-                                    prefix=f"{prefix}.fc2")
prefix=f"{prefix}.fc2") - self.fc3 = ReplicatedLinear(in_features, - hidden_features, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.fc3") + self.fc13 = MergedColumnParallelLinear( + in_features, + [hidden_features] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.fc13", + ) + self.fc2 = RowParallelLinear( + input_size=hidden_features, + output_size=in_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + self.act_fn = SiluAndMul() def forward(self, x: torch.Tensor) -> torch.Tensor: - x_parallel, _ = self.fc1(x) - gate, _ = self.fc3(x) - x_parallel = F.silu(x_parallel) * gate - out, _ = self.fc2(x_parallel) - return out + x, _ = self.fc13(x) + x = self.act_fn(x) + x, _ = self.fc2(x) + return x class AIMv2PatchEmbed(nn.Module): @@ -90,39 +95,45 @@ class AIMv2Attention(nn.Module): def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str): super().__init__() - dim = config.hidden_size - - # TODO(Isotr0py): investigate if we can add TP to visual tokenizer + self.config = config + self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads - self.qkv = ReplicatedLinear(dim, dim * 3, bias=config.qkv_bias) - # self.qkv = QKVParallelLinear( - # hidden_size=dim, - # head_size=dim // config.num_attention_heads, - # total_num_heads=config.num_attention_heads, - # bias=config.qkv_bias, - # quant_config=quant_config, - # prefix=f"{prefix}.qkv") - self.proj = ReplicatedLinear(dim, dim, bias=config.use_bias) - # self.proj = RowParallelLinear(input_size=dim, - # output_size=dim, - # bias = config.use_bias, - # quant_config=quant_config, - # prefix=f"{prefix}.proj") + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + "embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 - def forward( # todo might implement multiple attn implementations - self, - x: torch.Tensor, - mask: Optional[torch.Tensor] = None) -> torch.Tensor: - B, N, C = x.shape + self.qkv = QKVParallelLinear( + hidden_size=self.embed_dim, + head_size=self.head_dim, + total_num_heads=self.num_heads, + bias=config.qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + ) + + self.proj = RowParallelLinear( + input_size=self.embed_dim, + output_size=self.embed_dim, + bias=config.use_bias, + quant_config=quant_config, + prefix=f"{prefix}.proj", + ) + + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + + self.attn = MultiHeadAttention(self.num_heads_per_partition, + self.head_dim, self.scale) + + def forward(self, x: torch.Tensor) -> torch.Tensor: qkv, _ = self.qkv(x) + q, k, v = qkv.chunk(3, dim=-1) - qkv = qkv.reshape(B, N, 3, self.num_heads, - C // self.num_heads).permute(2, 0, 3, 1, 4) - - q, k, v = qkv.unbind(0) - - x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) - x = x.transpose(1, 2).contiguous().reshape(B, N, C) + x = self.attn(q, k, v) x, _ = self.proj(x) return x @@ -141,37 +152,40 @@ class AIMv2Block(nn.Module): prefix=f"{prefix}.mlp") self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - def forward(self, - x: torch.Tensor, - mask: Optional[torch.Tensor] = None) -> torch.Tensor: - x = x + self.attn(self.norm_1.forward_native(x), mask) + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.attn(self.norm_1.forward_native(x)) 
         x = x + self.mlp(self.norm_2.forward_native(x))
         return x
 
 
 class AIMv2Transformer(nn.Module):
 
-    def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
-                 prefix: str):
+    def __init__(
+        self,
+        config: AIMv2Config,
+        quant_config: QuantizationConfig,
+        *,
+        require_post_norm: Optional[bool] = None,
+        prefix: str = "",
+    ):
         super().__init__()
 
         self.blocks = nn.ModuleList([
             AIMv2Block(config, quant_config, prefix=f"{prefix}.blocks.{i}")
             for i in range(config.num_hidden_layers)
         ])
-        self.post_trunk_norm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
+        if require_post_norm:
+            self.post_trunk_norm = RMSNorm(config.hidden_size,
+                                           eps=config.rms_norm_eps)
+        else:
+            self.post_trunk_norm = None
 
-    def forward(
-        self,
-        tokens: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
         # they take the -1 as the ref embeddings, like a clip skip
         for block in self.blocks:
-            tokens = block(tokens, mask)
-        # NO NORM IN THE OG IMPLEMENTATION
-        # tokens = self.post_trunk_norm(tokens)
+            tokens = block(tokens)
+        if self.post_trunk_norm is not None:
+            tokens = self.post_trunk_norm(tokens)
         return tokens
 
 
@@ -180,20 +194,52 @@ class AIMv2Model(torch.nn.Module):
 
     def __init__(self,
                  config: AIMv2Config,
                  quant_config: QuantizationConfig,
+                 *,
+                 require_post_norm: Optional[bool] = None,
                  prefix: str = ""):
         super().__init__()
         self.preprocessor = AIMv2ViTPreprocessor(config)
         self.trunk = AIMv2Transformer(config,
                                       quant_config=quant_config,
+                                      require_post_norm=require_post_norm,
                                       prefix=f"{prefix}.trunk")
 
-    def forward(
-        self,
-        pixel_values: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         x = self.preprocessor(pixel_values)
-        x = self.trunk(x, mask)
+        x = self.trunk(x)
 
         return x
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".fc13", ".fc1", 0),
+            (".fc13", ".fc3", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+
+        for name, loaded_weight in weights:
+            # post_trunk_norm is optional in AIMv2Model
+            if (name.startswith("trunk.post_trunk_norm")
+                    and self.trunk.post_trunk_norm is None):
+                continue
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py
index e8f3ae2156e0..9fd528fd7977 100644
--- a/vllm/model_executor/models/clip.py
+++ b/vllm/model_executor/models/clip.py
@@ -106,7 +106,6 @@ class CLIPAttention(nn.Module):
                 f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                 f" {self.num_heads}).")
         self.scale = self.head_dim**-0.5
-        self.dropout = config.attention_dropout
 
         self.qkv_proj = QKVParallelLinear(
             hidden_size=self.embed_dim,
@@ -129,10 +128,6 @@ class CLIPAttention(nn.Module):
         self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                        self.head_dim, self.scale)
 
-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads,
-                           self.head_dim).transpose(1, 2).contiguous()
-
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/vllm/model_executor/models/ovis.py b/vllm/model_executor/models/ovis.py
index e03705d48f3e..232a63c50689 100644
--- a/vllm/model_executor/models/ovis.py
+++ b/vllm/model_executor/models/ovis.py
@@ -30,6 +30,9 @@ from vllm.config import VllmConfig
 from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+from vllm.model_executor.layers.quantization.gptq_marlin import (
+    GPTQMarlinConfig)
 from vllm.model_executor.models.aimv2 import AIMv2Model
 from vllm.model_executor.models.siglip import SiglipVisionModel
 from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn,
@@ -48,7 +51,7 @@ from vllm.transformers_utils.configs.ovis import (BaseVisualTokenizerConfig,
                                                   OvisConfig)
 from vllm.transformers_utils.processors.ovis import OvisProcessor
 
-from .interfaces import MultiModalEmbeddings, SupportsMultiModal
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import merge_multimodal_embeddings
 
 # Cannot find the following number from hf config.
@@ -106,12 +109,14 @@ class VisualTokenizer(torch.nn.Module):
         config: BaseVisualTokenizerConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
-    ):
+    ) -> nn.Module:
         model_type = config.backbone_config.model_type
         if model_type == "aimv2":
+            # No post rms_norm in Ovis2's AIMv2 ViT.
             return AIMv2Model(
                 config=config.backbone_config,
                 quant_config=quant_config,
+                require_post_norm=False,
                 prefix=prefix,
             )
         elif model_type == "siglip_vision_model":
@@ -124,14 +129,14 @@ class VisualTokenizer(torch.nn.Module):
             f"Unsupported visual tokenizer model_type: {model_type}")
 
     @property
-    def dtype(self):
+    def dtype(self) -> torch.dtype:
         return next(self.head.parameters()).dtype
 
     @property
-    def device(self):
+    def device(self) -> torch.device:
         return next(self.head.parameters()).device
 
-    def tokenize(self, logits):
+    def tokenize(self, logits: torch.Tensor) -> torch.Tensor:
         if self.config.tokenize_function == 'softmax':
             tokens = softmax(logits, dim=-1)
         elif self.config.tokenize_function == 'gumbel_argmax':
@@ -144,7 +149,7 @@ class VisualTokenizer(torch.nn.Module):
             f'or st_argmax, but got {self.config.tokenize_function}')
         return tokens
 
-    def encode(self, pixel_values):
+    def encode(self, pixel_values: torch.Tensor) -> torch.Tensor:
         features = self.backbone(pixel_values)
         if self.config.drop_cls_token:
             features = features[:, 1:, :]
@@ -395,7 +400,7 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]):
 @MULTIMODAL_REGISTRY.register_processor(OvisMultiModalProcessor,
                                         info=OvisProcessingInfo,
                                         dummy_inputs=OvisDummyInputsBuilder)
-class Ovis(nn.Module, SupportsMultiModal):
+class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -410,7 +415,7 @@ class Ovis(nn.Module, SupportsMultiModal):
 
         self.visual_tokenizer = VisualTokenizer(
             config=config.visual_tokenizer_config,
-            quant_config=quant_config,
+            quant_config=self._maybe_ignore_quant_config(quant_config),
             prefix=f"{prefix}.visual_tokenizer",
         )
 
@@ -421,9 +426,16 @@ class Ovis(nn.Module, SupportsMultiModal):
         text_model_type = self.config.get_text_config().model_type
         self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]
 
-        # TODO(Isotr0py): PP support
-        # self.make_empty_intermediate_tensors = (
-        #     self.language_model.make_empty_intermediate_tensors)
+        self.make_empty_intermediate_tensors = (
+            self.get_language_model().make_empty_intermediate_tensors)
+
+    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
+        # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
+        # seems to avoid vision encoder sections for some models.
+        # See: https://huggingface.co/AIDC-AI/Ovis2-2B-GPTQ-Int4
+        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
+            return None
+        return quant_config
 
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[OvisImagePatchInputs]:
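
Note on the `AIMv2SwiGLUFFN` change above: `fc1` and `fc3` are stacked into a single `fc13` projection (`MergedColumnParallelLinear`), and `SiluAndMul` applies SiLU to the first half of the fused output and multiplies it by the second half, which is why `stacked_params_mapping` loads `.fc1` into shard 0 and `.fc3` into shard 1. The snippet below is a minimal standalone sketch of that equivalence using plain `torch.nn.Linear` stand-ins and made-up sizes (no vLLM parallel layers or distributed init), not the actual vLLM modules:

```python
# Sketch only: plain-PyTorch stand-ins with hypothetical sizes, illustrating
# why stacking fc1 (shard 0) and fc3 (shard 1) into one fc13 weight plus a
# SiluAndMul-style split reproduces the original silu(fc1(x)) * fc3(x) -> fc2.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
in_features, hidden_features = 16, 32

fc1 = torch.nn.Linear(in_features, hidden_features, bias=False)
fc3 = torch.nn.Linear(in_features, hidden_features, bias=False)
fc2 = torch.nn.Linear(hidden_features, in_features, bias=False)
x = torch.randn(2, in_features)

# Original (unfused) formulation from the removed code path.
reference = fc2(F.silu(fc1(x)) * fc3(x))

# Fused formulation: one projection with the stacked weight, then split.
fc13_weight = torch.cat([fc1.weight, fc3.weight], dim=0)
gate_up = x @ fc13_weight.t()
a, b = gate_up.chunk(2, dim=-1)  # a = fc1(x), b = fc3(x)
fused = fc2(F.silu(a) * b)

assert torch.allclose(reference, fused, atol=1e-6)
```

Keeping the MLP column-parallel (`fc13`) then row-parallel (`fc2`), and likewise `qkv` then `proj` in `AIMv2Attention`, leaves the intermediate activations sharded across tensor-parallel ranks, so each sub-block only needs the all-reduce performed at the `RowParallelLinear` output.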