diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 6b0ceaf21951..b60fefdda279 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -538,7 +538,7 @@ Specified using `--task generate`.
| `MllamaForConditionalGeneration` | Llama 3.2 | T + I+ | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | |
| `MolmoForCausalLM` | Molmo | T + I+ | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `NVLM_D_Model` | NVLM-D 1.0 | T + I+ | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ |
-| `Ovis` | Ovis2, Ovis1.6 | T + I+ | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | | ✅︎ |
+| `Ovis` | Ovis2, Ovis1.6 | T + I+ | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ |
| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + IE | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ |
| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + IE+ | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ |
| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py
index 5346d67b10d1..e6410ab068d2 100644
--- a/tests/distributed/test_pipeline_parallel.py
+++ b/tests/distributed/test_pipeline_parallel.py
@@ -227,6 +227,7 @@ MULTIMODAL_MODELS = {
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(),
"openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(),
"allenai/Molmo-7B-D-0924": PPTestSettings.fast(),
+ "AIDC-AI/Ovis2-1B": PPTestSettings.fast(),
"microsoft/Phi-3.5-vision-instruct": PPTestSettings.fast(),
"mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"),
"Qwen/Qwen-VL-Chat": PPTestSettings.fast(),
diff --git a/vllm/model_executor/models/aimv2.py b/vllm/model_executor/models/aimv2.py
index aefd6c973755..2e2a18abd03d 100644
--- a/vllm/model_executor/models/aimv2.py
+++ b/vllm/model_executor/models/aimv2.py
@@ -2,16 +2,23 @@
# A modified implementation of the AIMv2 Transformer
# inserted here also the image tokenizer used by Ovis2
+from collections.abc import Iterable
from typing import Optional
import torch
import torch.nn as nn
-from torch.nn import functional as F
+from vllm.attention.layer import MultiHeadAttention
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.utils import divide
+from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.transformers_utils.configs.ovis import AIMv2Config
@@ -24,29 +31,27 @@ class AIMv2SwiGLUFFN(nn.Module):
in_features = config.hidden_size
bias = config.use_bias
- # TODO(Isotr0py): investigate if we can add TP to visual tokenizer
- self.fc1 = ReplicatedLinear(in_features,
- hidden_features,
- bias=bias,
- quant_config=quant_config,
- prefix=f"{prefix}.fc1")
- self.fc2 = ReplicatedLinear(hidden_features,
- in_features,
- bias=bias,
- quant_config=quant_config,
- prefix=f"{prefix}.fc2")
- self.fc3 = ReplicatedLinear(in_features,
- hidden_features,
- bias=bias,
- quant_config=quant_config,
- prefix=f"{prefix}.fc3")
+ self.fc13 = MergedColumnParallelLinear(
+ in_features,
+ [hidden_features] * 2,
+ bias=bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.fc13",
+ )
+ self.fc2 = RowParallelLinear(
+ input_size=hidden_features,
+ output_size=in_features,
+ bias=bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.fc2",
+ )
+ self.act_fn = SiluAndMul()
def forward(self, x: torch.Tensor) -> torch.Tensor:
- x_parallel, _ = self.fc1(x)
- gate, _ = self.fc3(x)
- x_parallel = F.silu(x_parallel) * gate
- out, _ = self.fc2(x_parallel)
- return out
+ x, _ = self.fc13(x)
+ x = self.act_fn(x)
+ x, _ = self.fc2(x)
+ return x
class AIMv2PatchEmbed(nn.Module):
@@ -90,39 +95,45 @@ class AIMv2Attention(nn.Module):
def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
prefix: str):
super().__init__()
- dim = config.hidden_size
-
- # TODO(Isotr0py): investigate if we can add TP to visual tokenizer
+ self.config = config
+ self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
- self.qkv = ReplicatedLinear(dim, dim * 3, bias=config.qkv_bias)
- # self.qkv = QKVParallelLinear(
- # hidden_size=dim,
- # head_size=dim // config.num_attention_heads,
- # total_num_heads=config.num_attention_heads,
- # bias=config.qkv_bias,
- # quant_config=quant_config,
- # prefix=f"{prefix}.qkv")
- self.proj = ReplicatedLinear(dim, dim, bias=config.use_bias)
- # self.proj = RowParallelLinear(input_size=dim,
- # output_size=dim,
- # bias = config.use_bias,
- # quant_config=quant_config,
- # prefix=f"{prefix}.proj")
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ "embed_dim must be divisible by num_heads "
+ f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads}).")
+ self.scale = self.head_dim**-0.5
- def forward( # todo might implement multiple attn implementations
- self,
- x: torch.Tensor,
- mask: Optional[torch.Tensor] = None) -> torch.Tensor:
- B, N, C = x.shape
+ self.qkv = QKVParallelLinear(
+ hidden_size=self.embed_dim,
+ head_size=self.head_dim,
+ total_num_heads=self.num_heads,
+ bias=config.qkv_bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv",
+ )
+
+ self.proj = RowParallelLinear(
+ input_size=self.embed_dim,
+ output_size=self.embed_dim,
+ bias=config.use_bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.proj",
+ )
+
+ self.tp_size = get_tensor_model_parallel_world_size()
+ self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
+
+ self.attn = MultiHeadAttention(self.num_heads_per_partition,
+ self.head_dim, self.scale)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
qkv, _ = self.qkv(x)
+ q, k, v = qkv.chunk(3, dim=-1)
- qkv = qkv.reshape(B, N, 3, self.num_heads,
- C // self.num_heads).permute(2, 0, 3, 1, 4)
-
- q, k, v = qkv.unbind(0)
-
- x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
- x = x.transpose(1, 2).contiguous().reshape(B, N, C)
+ x = self.attn(q, k, v)
x, _ = self.proj(x)
return x
@@ -141,37 +152,40 @@ class AIMv2Block(nn.Module):
prefix=f"{prefix}.mlp")
self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- def forward(self,
- x: torch.Tensor,
- mask: Optional[torch.Tensor] = None) -> torch.Tensor:
- x = x + self.attn(self.norm_1.forward_native(x), mask)
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = x + self.attn(self.norm_1.forward_native(x))
x = x + self.mlp(self.norm_2.forward_native(x))
return x
class AIMv2Transformer(nn.Module):
- def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
- prefix: str):
+ def __init__(
+ self,
+ config: AIMv2Config,
+ quant_config: QuantizationConfig,
+ *,
+ require_post_norm: Optional[bool] = None,
+ prefix: str = "",
+ ):
super().__init__()
self.blocks = nn.ModuleList([
AIMv2Block(config, quant_config, prefix=f"{prefix}.blocks.{i}")
for i in range(config.num_hidden_layers)
])
- self.post_trunk_norm = RMSNorm(config.hidden_size,
- eps=config.rms_norm_eps)
+ if require_post_norm:
+ self.post_trunk_norm = RMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+ else:
+ self.post_trunk_norm = None
- def forward(
- self,
- tokens: torch.Tensor,
- mask: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
+ def forward(self, tokens: torch.Tensor) -> torch.Tensor:
# they take the -1 as the ref embeddings, like a clip skip
for block in self.blocks:
- tokens = block(tokens, mask)
- # NO NORM IN THE OG IMPLEMENTATION
- # tokens = self.post_trunk_norm(tokens)
+ tokens = block(tokens)
+ if self.post_trunk_norm is not None:
+ tokens = self.post_trunk_norm(tokens)
return tokens
@@ -180,20 +194,52 @@ class AIMv2Model(torch.nn.Module):
def __init__(self,
config: AIMv2Config,
quant_config: QuantizationConfig,
+ *,
+ require_post_norm: Optional[bool] = None,
prefix: str = ""):
super().__init__()
self.preprocessor = AIMv2ViTPreprocessor(config)
self.trunk = AIMv2Transformer(config,
quant_config=quant_config,
+ require_post_norm=require_post_norm,
prefix=f"{prefix}.trunk")
- def forward(
- self,
- pixel_values: torch.Tensor,
- mask: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
x = self.preprocessor(pixel_values)
- x = self.trunk(x, mask)
+ x = self.trunk(x)
return x
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ (".fc13", ".fc1", 0),
+ (".fc13", ".fc3", 1),
+ ]
+ params_dict = dict(self.named_parameters())
+ loaded_params: set[str] = set()
+
+ for name, loaded_weight in weights:
+ # post_layernorm is optional in SiglipVisionModel
+ if (name.startswith("trunk.post_trunk_norm")
+ and self.trunk.post_trunk_norm is None):
+ continue
+
+ for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py
index e8f3ae2156e0..9fd528fd7977 100644
--- a/vllm/model_executor/models/clip.py
+++ b/vllm/model_executor/models/clip.py
@@ -106,7 +106,6 @@ class CLIPAttention(nn.Module):
f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads}).")
self.scale = self.head_dim**-0.5
- self.dropout = config.attention_dropout
self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim,
@@ -129,10 +128,6 @@ class CLIPAttention(nn.Module):
self.attn = MultiHeadAttention(self.num_heads_per_partition,
self.head_dim, self.scale)
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads,
- self.head_dim).transpose(1, 2).contiguous()
-
def forward(
self,
hidden_states: torch.Tensor,
diff --git a/vllm/model_executor/models/ovis.py b/vllm/model_executor/models/ovis.py
index e03705d48f3e..232a63c50689 100644
--- a/vllm/model_executor/models/ovis.py
+++ b/vllm/model_executor/models/ovis.py
@@ -30,6 +30,9 @@ from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
+from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+from vllm.model_executor.layers.quantization.gptq_marlin import (
+ GPTQMarlinConfig)
from vllm.model_executor.models.aimv2 import AIMv2Model
from vllm.model_executor.models.siglip import SiglipVisionModel
from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn,
@@ -48,7 +51,7 @@ from vllm.transformers_utils.configs.ovis import (BaseVisualTokenizerConfig,
OvisConfig)
from vllm.transformers_utils.processors.ovis import OvisProcessor
-from .interfaces import MultiModalEmbeddings, SupportsMultiModal
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import merge_multimodal_embeddings
# Cannot find the following number from hf config.
@@ -106,12 +109,14 @@ class VisualTokenizer(torch.nn.Module):
config: BaseVisualTokenizerConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
- ):
+ ) -> nn.Module:
model_type = config.backbone_config.model_type
if model_type == "aimv2":
+ # No post rms_norm in Ovis2's AIMv2 ViT.
return AIMv2Model(
config=config.backbone_config,
quant_config=quant_config,
+ require_post_norm=False,
prefix=prefix,
)
elif model_type == "siglip_vision_model":
@@ -124,14 +129,14 @@ class VisualTokenizer(torch.nn.Module):
f"Unsupported visual tokenizer model_type: {model_type}")
@property
- def dtype(self):
+ def dtype(self) -> torch.dtype:
return next(self.head.parameters()).dtype
@property
- def device(self):
+ def device(self) -> torch.device:
return next(self.head.parameters()).device
- def tokenize(self, logits):
+ def tokenize(self, logits: torch.Tensor) -> torch.Tensor:
if self.config.tokenize_function == 'softmax':
tokens = softmax(logits, dim=-1)
elif self.config.tokenize_function == 'gumbel_argmax':
@@ -144,7 +149,7 @@ class VisualTokenizer(torch.nn.Module):
f'or st_argmax, but got {self.config.tokenize_function}')
return tokens
- def encode(self, pixel_values):
+ def encode(self, pixel_values: torch.Tensor) -> torch.Tensor:
features = self.backbone(pixel_values)
if self.config.drop_cls_token:
features = features[:, 1:, :]
@@ -395,7 +400,7 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]):
@MULTIMODAL_REGISTRY.register_processor(OvisMultiModalProcessor,
info=OvisProcessingInfo,
dummy_inputs=OvisDummyInputsBuilder)
-class Ovis(nn.Module, SupportsMultiModal):
+class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -410,7 +415,7 @@ class Ovis(nn.Module, SupportsMultiModal):
self.visual_tokenizer = VisualTokenizer(
config=config.visual_tokenizer_config,
- quant_config=quant_config,
+ quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=f"{prefix}.visual_tokenizer",
)
@@ -421,9 +426,16 @@ class Ovis(nn.Module, SupportsMultiModal):
text_model_type = self.config.get_text_config().model_type
self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]
- # TODO(Isotr0py): PP support
- # self.make_empty_intermediate_tensors = (
- # self.language_model.make_empty_intermediate_tensors)
+ self.make_empty_intermediate_tensors = (
+ self.get_language_model().make_empty_intermediate_tensors)
+
+ def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
+ # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
+ # seems to avoid vision encoder sections for some models.
+ # See: https://huggingface.co/AIDC-AI/Ovis2-2B-GPTQ-Int4
+ if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
+ return None
+ return quant_config
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[OvisImagePatchInputs]: