[VLM] Add PP support and fix GPTQ inference for Ovis models (#18958)

Signed-off-by: isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Author: Isotr0py
Date: 2025-05-31 01:11:44 +08:00
Committed by: GitHub
Parent: f49239cb45
Commit: 5a8641638a
5 changed files with 145 additions and 91 deletions


@@ -538,7 +538,7 @@ Specified using `--task generate`.
 | `MllamaForConditionalGeneration` | Llama 3.2 | T + I<sup>+</sup> | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | |
 | `MolmoForCausalLM` | Molmo | T + I<sup>+</sup> | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ |
-| `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | | ✅︎ |
+| `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ |
 | `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + I<sup>E</sup> | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ |
 | `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ |
 | `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
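
The column that gains a check mark for `Ovis` above is pipeline parallelism (PP). As a usage sketch only (not part of this diff), a PP launch follows the usual vLLM engine arguments; the example below assumes two visible GPUs and uses a text-only prompt for brevity:

```python
from vllm import LLM, SamplingParams

# Sketch: run Ovis2 across two pipeline stages (assumes 2 GPUs are available).
llm = LLM(
    model="AIDC-AI/Ovis2-1B",
    pipeline_parallel_size=2,
    trust_remote_code=True,  # may not be required depending on your setup
)
outputs = llm.generate(["Summarize what pipeline parallelism does."],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```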


@@ -227,6 +227,7 @@ MULTIMODAL_MODELS = {
     "llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(),
     "openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(),
     "allenai/Molmo-7B-D-0924": PPTestSettings.fast(),
+    "AIDC-AI/Ovis2-1B": PPTestSettings.fast(),
     "microsoft/Phi-3.5-vision-instruct": PPTestSettings.fast(),
     "mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"),
     "Qwen/Qwen-VL-Chat": PPTestSettings.fast(),


@@ -2,16 +2,23 @@
 # A modified implementation of the AIMv2 Transformer
 # inserted here also the image tokenizer used by Ovis2
+from collections.abc import Iterable
 from typing import Optional

 import torch
 import torch.nn as nn
-from torch.nn import functional as F

+from vllm.attention.layer import MultiHeadAttention
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.utils import divide
+from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs.ovis import AIMv2Config
@@ -24,29 +31,27 @@ class AIMv2SwiGLUFFN(nn.Module):
         in_features = config.hidden_size
         bias = config.use_bias

-        # TODO(Isotr0py): investigate if we can add TP to visual tokenizer
-        self.fc1 = ReplicatedLinear(in_features,
-                                    hidden_features,
-                                    bias=bias,
-                                    quant_config=quant_config,
-                                    prefix=f"{prefix}.fc1")
-        self.fc2 = ReplicatedLinear(hidden_features,
-                                    in_features,
-                                    bias=bias,
-                                    quant_config=quant_config,
-                                    prefix=f"{prefix}.fc2")
-        self.fc3 = ReplicatedLinear(in_features,
-                                    hidden_features,
-                                    bias=bias,
-                                    quant_config=quant_config,
-                                    prefix=f"{prefix}.fc3")
+        self.fc13 = MergedColumnParallelLinear(
+            in_features,
+            [hidden_features] * 2,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc13",
+        )
+        self.fc2 = RowParallelLinear(
+            input_size=hidden_features,
+            output_size=in_features,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc2",
+        )
+        self.act_fn = SiluAndMul()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x_parallel, _ = self.fc1(x)
-        gate, _ = self.fc3(x)
-        x_parallel = F.silu(x_parallel) * gate
-        out, _ = self.fc2(x_parallel)
-        return out
+        x, _ = self.fc13(x)
+        x = self.act_fn(x)
+        x, _ = self.fc2(x)
+        return x


 class AIMv2PatchEmbed(nn.Module):
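
The fc1/fc3 merge above is a pure refactor of the SwiGLU computation: vLLM's `SiluAndMul` applies SiLU to the first half of the last dimension and multiplies it by the second half, which reproduces the old `F.silu(fc1(x)) * fc3(x)` as long as fc1 lands in shard 0 and fc3 in shard 1 of the merged weight (that ordering is enforced by the `load_weights` mapping later in this file). A self-contained sketch in plain PyTorch, without the vLLM layers:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_in, d_hidden = 8, 16
x = torch.randn(2, d_in)

# Old formulation: separate gate (fc1) and up (fc3) projections.
fc1 = torch.nn.Linear(d_in, d_hidden, bias=False)
fc3 = torch.nn.Linear(d_in, d_hidden, bias=False)
ref = F.silu(fc1(x)) * fc3(x)

# Merged formulation: one matmul producing [gate | up], then
# silu(first half) * second half, i.e. what SiluAndMul computes.
fc13_weight = torch.cat([fc1.weight, fc3.weight], dim=0)  # shard 0 = fc1, shard 1 = fc3
gate, up = (x @ fc13_weight.t()).chunk(2, dim=-1)
merged = F.silu(gate) * up

assert torch.allclose(ref, merged, atol=1e-6)
```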
@@ -90,39 +95,45 @@ class AIMv2Attention(nn.Module):

     def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
                  prefix: str):
         super().__init__()
-        dim = config.hidden_size
-
-        # TODO(Isotr0py): investigate if we can add TP to visual tokenizer
+        self.config = config
+        self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
-        self.qkv = ReplicatedLinear(dim, dim * 3, bias=config.qkv_bias)
-        # self.qkv = QKVParallelLinear(
-        #     hidden_size=dim,
-        #     head_size=dim // config.num_attention_heads,
-        #     total_num_heads=config.num_attention_heads,
-        #     bias=config.qkv_bias,
-        #     quant_config=quant_config,
-        #     prefix=f"{prefix}.qkv")
-
-        self.proj = ReplicatedLinear(dim, dim, bias=config.use_bias)
-        # self.proj = RowParallelLinear(input_size=dim,
-        #                               output_size=dim,
-        #                               bias = config.use_bias,
-        #                               quant_config=quant_config,
-        #                               prefix=f"{prefix}.proj")
+        self.head_dim = self.embed_dim // self.num_heads
+        if self.head_dim * self.num_heads != self.embed_dim:
+            raise ValueError(
+                "embed_dim must be divisible by num_heads "
+                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                f" {self.num_heads}).")
+        self.scale = self.head_dim**-0.5
+
+        self.qkv = QKVParallelLinear(
+            hidden_size=self.embed_dim,
+            head_size=self.head_dim,
+            total_num_heads=self.num_heads,
+            bias=config.qkv_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv",
+        )
+        self.proj = RowParallelLinear(
+            input_size=self.embed_dim,
+            output_size=self.embed_dim,
+            bias=config.use_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.proj",
+        )
+
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
+        self.attn = MultiHeadAttention(self.num_heads_per_partition,
+                                       self.head_dim, self.scale)

-    def forward(  # todo might implement multiple attn implementations
-            self,
-            x: torch.Tensor,
-            mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        B, N, C = x.shape
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         qkv, _ = self.qkv(x)
-
-        qkv = qkv.reshape(B, N, 3, self.num_heads,
-                          C // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv.unbind(0)
-
-        x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
-        x = x.transpose(1, 2).contiguous().reshape(B, N, C)
+        q, k, v = qkv.chunk(3, dim=-1)
+        x = self.attn(q, k, v)
         x, _ = self.proj(x)
         return x
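
The new `forward` hands q/k/v of shape `(B, N, hidden)` to vLLM's `MultiHeadAttention`, which takes care of the per-head reshape and attention-backend dispatch that the old code did inline. Roughly (plain PyTorch, shapes chosen only for illustration), the unmasked equivalent is:

```python
import torch
import torch.nn.functional as F

B, N, C, num_heads = 2, 5, 32, 4
head_dim = C // num_heads
qkv = torch.randn(B, N, 3 * C)  # stands in for the fused QKV projection output

q, k, v = qkv.chunk(3, dim=-1)  # what the new forward does

def to_heads(t: torch.Tensor) -> torch.Tensor:
    # (B, N, C) -> (B, num_heads, N, head_dim)
    return t.view(B, N, num_heads, head_dim).transpose(1, 2)

out = F.scaled_dot_product_attention(to_heads(q), to_heads(k), to_heads(v))
out = out.transpose(1, 2).reshape(B, N, C)
print(out.shape)  # torch.Size([2, 5, 32])
```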
@@ -141,37 +152,40 @@ class AIMv2Block(nn.Module):
                               prefix=f"{prefix}.mlp")
         self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

-    def forward(self,
-                x: torch.Tensor,
-                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        x = x + self.attn(self.norm_1.forward_native(x), mask)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x + self.attn(self.norm_1.forward_native(x))
         x = x + self.mlp(self.norm_2.forward_native(x))
         return x


 class AIMv2Transformer(nn.Module):

-    def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
-                 prefix: str):
+    def __init__(
+        self,
+        config: AIMv2Config,
+        quant_config: QuantizationConfig,
+        *,
+        require_post_norm: Optional[bool] = None,
+        prefix: str = "",
+    ):
         super().__init__()

         self.blocks = nn.ModuleList([
             AIMv2Block(config, quant_config, prefix=f"{prefix}.blocks.{i}")
             for i in range(config.num_hidden_layers)
         ])
-        self.post_trunk_norm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
+        if require_post_norm:
+            self.post_trunk_norm = RMSNorm(config.hidden_size,
+                                           eps=config.rms_norm_eps)
+        else:
+            self.post_trunk_norm = None

-    def forward(
-        self,
-        tokens: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
         # they take the -1 as the ref embeddings, like a clip skip
         for block in self.blocks:
-            tokens = block(tokens, mask)
-        # NO NORM IN THE OG IMPLEMENTATION
-        # tokens = self.post_trunk_norm(tokens)
+            tokens = block(tokens)
+        if self.post_trunk_norm is not None:
+            tokens = self.post_trunk_norm(tokens)
         return tokens
@@ -180,20 +194,52 @@ class AIMv2Model(torch.nn.Module):

     def __init__(self,
                  config: AIMv2Config,
                  quant_config: QuantizationConfig,
+                 *,
+                 require_post_norm: Optional[bool] = None,
                  prefix: str = ""):
         super().__init__()
         self.preprocessor = AIMv2ViTPreprocessor(config)
         self.trunk = AIMv2Transformer(config,
                                       quant_config=quant_config,
+                                      require_post_norm=require_post_norm,
                                       prefix=f"{prefix}.trunk")

-    def forward(
-        self,
-        pixel_values: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         x = self.preprocessor(pixel_values)
-        x = self.trunk(x, mask)
+        x = self.trunk(x)
         return x
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".fc13", ".fc1", 0),
+            (".fc13", ".fc3", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            # post_layernorm is optional in SiglipVisionModel
+            if (name.startswith("trunk.post_trunk_norm")
+                    and self.trunk.post_trunk_norm is None):
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
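
The `stacked_params_mapping` above is what lets the merged `fc13` module consume checkpoints that still store separate `fc1`/`fc3` tensors: matching names are rewritten and routed to the corresponding shard of the merged parameter. A toy illustration of just the renaming step (the checkpoint keys here are invented for the example):

```python
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    (".fc13", ".fc1", 0),
    (".fc13", ".fc3", 1),
]

def remap(ckpt_name: str) -> tuple[str, int | None]:
    """Return (model parameter name, shard id) for a checkpoint weight name."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in ckpt_name:
            return ckpt_name.replace(weight_name, param_name), shard_id
    return ckpt_name, None  # unstacked weights load as-is

print(remap("trunk.blocks.0.mlp.fc1.weight"))   # ('trunk.blocks.0.mlp.fc13.weight', 0)
print(remap("trunk.blocks.0.mlp.fc3.weight"))   # ('trunk.blocks.0.mlp.fc13.weight', 1)
print(remap("trunk.blocks.0.attn.qkv.weight"))  # ('trunk.blocks.0.attn.qkv.weight', None)
```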


@@ -106,7 +106,6 @@ class CLIPAttention(nn.Module):
                 f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                 f" {self.num_heads}).")
         self.scale = self.head_dim**-0.5
-        self.dropout = config.attention_dropout

         self.qkv_proj = QKVParallelLinear(
             hidden_size=self.embed_dim,
@@ -129,10 +128,6 @@
         self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                        self.head_dim, self.scale)

-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads,
-                           self.head_dim).transpose(1, 2).contiguous()
-
     def forward(
         self,
         hidden_states: torch.Tensor,


@@ -30,6 +30,9 @@ from vllm.config import VllmConfig
 from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+from vllm.model_executor.layers.quantization.gptq_marlin import (
+    GPTQMarlinConfig)
 from vllm.model_executor.models.aimv2 import AIMv2Model
 from vllm.model_executor.models.siglip import SiglipVisionModel
 from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn,
@@ -48,7 +51,7 @@ from vllm.transformers_utils.configs.ovis import (BaseVisualTokenizerConfig,
                                                   OvisConfig)
 from vllm.transformers_utils.processors.ovis import OvisProcessor

-from .interfaces import MultiModalEmbeddings, SupportsMultiModal
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import merge_multimodal_embeddings

 # Cannot find the following number from hf config.
@@ -106,12 +109,14 @@ class VisualTokenizer(torch.nn.Module):
         config: BaseVisualTokenizerConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
-    ):
+    ) -> nn.Module:
         model_type = config.backbone_config.model_type
         if model_type == "aimv2":
+            # No post rms_norm in Ovis2's AIMv2 ViT.
             return AIMv2Model(
                 config=config.backbone_config,
                 quant_config=quant_config,
+                require_post_norm=False,
                 prefix=prefix,
             )
         elif model_type == "siglip_vision_model":
@@ -124,14 +129,14 @@
                 f"Unsupported visual tokenizer model_type: {model_type}")

     @property
-    def dtype(self):
+    def dtype(self) -> torch.dtype:
         return next(self.head.parameters()).dtype

     @property
-    def device(self):
+    def device(self) -> torch.device:
         return next(self.head.parameters()).device

-    def tokenize(self, logits):
+    def tokenize(self, logits: torch.Tensor) -> torch.Tensor:
         if self.config.tokenize_function == 'softmax':
             tokens = softmax(logits, dim=-1)
         elif self.config.tokenize_function == 'gumbel_argmax':
@@ -144,7 +149,7 @@
                 f'or st_argmax, but got {self.config.tokenize_function}')
         return tokens

-    def encode(self, pixel_values):
+    def encode(self, pixel_values: torch.Tensor) -> torch.Tensor:
         features = self.backbone(pixel_values)
         if self.config.drop_cls_token:
             features = features[:, 1:, :]
@@ -395,7 +400,7 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]):
 @MULTIMODAL_REGISTRY.register_processor(OvisMultiModalProcessor,
                                         info=OvisProcessingInfo,
                                         dummy_inputs=OvisDummyInputsBuilder)
-class Ovis(nn.Module, SupportsMultiModal):
+class Ovis(nn.Module, SupportsMultiModal, SupportsPP):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -410,7 +415,7 @@ class Ovis(nn.Module, SupportsMultiModal):

         self.visual_tokenizer = VisualTokenizer(
             config=config.visual_tokenizer_config,
-            quant_config=quant_config,
+            quant_config=self._maybe_ignore_quant_config(quant_config),
             prefix=f"{prefix}.visual_tokenizer",
         )
@@ -421,9 +426,16 @@ class Ovis(nn.Module, SupportsMultiModal):
         text_model_type = self.config.get_text_config().model_type
         self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]

-        # TODO(Isotr0py): PP support
-        # self.make_empty_intermediate_tensors = (
-        #     self.language_model.make_empty_intermediate_tensors)
+        self.make_empty_intermediate_tensors = (
+            self.get_language_model().make_empty_intermediate_tensors)
+
+    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
+        # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
+        # seems to avoid vision encoder sections for some models.
+        # See: https://huggingface.co/AIDC-AI/Ovis2-2B-GPTQ-Int4
+        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
+            return None
+        return quant_config

     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[OvisImagePatchInputs]:
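
Net effect of the GPTQ change: for GPTQ and GPTQ-Marlin checkpoints the visual tokenizer is constructed without a quantization config (AutoGPTQ left those weights unquantized, per the comment above), while the language model keeps its quantized layers. A hedged usage sketch against the checkpoint referenced in that comment:

```python
from vllm import LLM

# Sketch only: the vision encoder loads in full precision via
# _maybe_ignore_quant_config, the LLM part stays GPTQ-quantized.
llm = LLM(model="AIDC-AI/Ovis2-2B-GPTQ-Int4", max_model_len=4096)
print(llm.generate(["Describe this model in one sentence."])[0].outputs[0].text)
```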