[VLM] Add PP support and fix GPTQ inference for Ovis models (#18958)

Signed-off-by: isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <2037008807@qq.com>

parent f49239cb45
commit 5a8641638a
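As context for the change, a minimal offline-inference sketch of what this commit enables: running an Ovis checkpoint with pipeline parallelism. It is illustrative only; the model name, parallel sizes, and the `<image>` prompt placeholder are assumptions, not part of this commit.

```python
# Sketch only: exercising pipeline parallelism for an Ovis checkpoint after
# this change. Model name, sizes, and prompt format are illustrative.
from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(
    model="AIDC-AI/Ovis2-1B",
    trust_remote_code=True,       # Ovis ships custom config/processor code
    pipeline_parallel_size=2,     # split the language model into 2 PP stages
)

image = Image.new("RGB", (448, 448))  # stand-in for a real input image
outputs = llm.generate(
    {
        "prompt": "<image>\nDescribe this image.",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(max_tokens=32),
)
print(outputs[0].outputs[0].text)
```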
@@ -538,7 +538,7 @@ Specified using `--task generate`.
 | `MllamaForConditionalGeneration` | Llama 3.2 | T + I<sup>+</sup> | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | |
 | `MolmoForCausalLM` | Molmo | T + I<sup>+</sup> | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ |
-| `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | | ✅︎ |
+| `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ |
 | `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + I<sup>E</sup> | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ |
 | `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ |
 | `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
@@ -227,6 +227,7 @@ MULTIMODAL_MODELS = {
     "llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(),
     "openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(),
     "allenai/Molmo-7B-D-0924": PPTestSettings.fast(),
+    "AIDC-AI/Ovis2-1B": PPTestSettings.fast(),
     "microsoft/Phi-3.5-vision-instruct": PPTestSettings.fast(),
     "mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"),
     "Qwen/Qwen-VL-Chat": PPTestSettings.fast(),
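For reference, a hedged sketch of how the new registry entry would typically be exercised. The test module path and the `-k` filter below are assumptions (the hunk only shows the `MULTIMODAL_MODELS` dict), and the run needs a multi-GPU machine.

```python
# Sketch only: running the pipeline-parallel test for the new Ovis entry.
# The module path is assumed, not shown in this hunk; requires >= 2 GPUs.
import pytest

raise SystemExit(pytest.main([
    "tests/distributed/test_pipeline_parallel.py",
    "-k", "Ovis2-1B",
]))
```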
@@ -2,16 +2,23 @@

 # A modified implementation of the AIMv2 Transformer
 # inserted here also the image tokenizer used by Ovis2
+from collections.abc import Iterable
 from typing import Optional

 import torch
 import torch.nn as nn
-from torch.nn import functional as F

+from vllm.attention.layer import MultiHeadAttention
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.utils import divide
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs.ovis import AIMv2Config
@@ -24,29 +31,27 @@ class AIMv2SwiGLUFFN(nn.Module):
         in_features = config.hidden_size
         bias = config.use_bias

-        # TODO(Isotr0py): investigate if we can add TP to visual tokenizer
-        self.fc1 = ReplicatedLinear(in_features,
-                                    hidden_features,
-                                    bias=bias,
-                                    quant_config=quant_config,
-                                    prefix=f"{prefix}.fc1")
-        self.fc2 = ReplicatedLinear(hidden_features,
-                                    in_features,
-                                    bias=bias,
-                                    quant_config=quant_config,
-                                    prefix=f"{prefix}.fc2")
-        self.fc3 = ReplicatedLinear(in_features,
-                                    hidden_features,
-                                    bias=bias,
-                                    quant_config=quant_config,
-                                    prefix=f"{prefix}.fc3")
+        self.fc13 = MergedColumnParallelLinear(
+            in_features,
+            [hidden_features] * 2,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc13",
+        )
+        self.fc2 = RowParallelLinear(
+            input_size=hidden_features,
+            output_size=in_features,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc2",
+        )
+        self.act_fn = SiluAndMul()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x_parallel, _ = self.fc1(x)
-        gate, _ = self.fc3(x)
-        x_parallel = F.silu(x_parallel) * gate
-        out, _ = self.fc2(x_parallel)
-        return out
+        x, _ = self.fc13(x)
+        x = self.act_fn(x)
+        x, _ = self.fc2(x)
+        return x


 class AIMv2PatchEmbed(nn.Module):
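To make the refactor above concrete, here is a small plain-PyTorch sketch (made-up shapes, no vLLM layers or tensor parallelism) showing why a single merged gate/up projection followed by a `SiluAndMul`-style activation reproduces the old `fc1`/`fc3`/`F.silu` path.

```python
# Sketch only: the merged "fc13" projection computes the same thing as the old
# separate fc1/fc3 projections. Shapes and weights here are made up.
import torch
import torch.nn.functional as F

hidden, inter = 8, 16
x = torch.randn(2, 4, hidden)

w1 = torch.randn(inter, hidden)   # old fc1 weight (branch passed through SiLU)
w3 = torch.randn(inter, hidden)   # old fc3 weight (multiplicative branch)
w2 = torch.randn(hidden, inter)   # down projection (fc2), unchanged

# Old path: two matmuls, then silu(fc1(x)) * fc3(x), then fc2.
old = F.linear(F.silu(F.linear(x, w1)) * F.linear(x, w3), w2)

# New path: one matmul against the stacked [w1; w3] weight ("fc13"), then split
# the doubled output in half -- exactly the layout SiluAndMul consumes.
w13 = torch.cat([w1, w3], dim=0)
gate_up = F.linear(x, w13)
gate, up = gate_up.chunk(2, dim=-1)
new = F.linear(F.silu(gate) * up, w2)

assert torch.allclose(old, new, atol=1e-5)
```

Merging the two projections into one MergedColumnParallelLinear issues a single GEMM instead of two and gives the layer a column-parallel sharding scheme, which is why the separate ReplicatedLinear layers were dropped.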
@@ -90,39 +95,45 @@ class AIMv2Attention(nn.Module):
     def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
                  prefix: str):
         super().__init__()
-        dim = config.hidden_size
-
-        # TODO(Isotr0py): investigate if we can add TP to visual tokenizer
+        self.config = config
+        self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
-        self.qkv = ReplicatedLinear(dim, dim * 3, bias=config.qkv_bias)
-        # self.qkv = QKVParallelLinear(
-        #     hidden_size=dim,
-        #     head_size=dim // config.num_attention_heads,
-        #     total_num_heads=config.num_attention_heads,
-        #     bias=config.qkv_bias,
-        #     quant_config=quant_config,
-        #     prefix=f"{prefix}.qkv")
-        self.proj = ReplicatedLinear(dim, dim, bias=config.use_bias)
-        # self.proj = RowParallelLinear(input_size=dim,
-        #                               output_size=dim,
-        #                               bias = config.use_bias,
-        #                               quant_config=quant_config,
-        #                               prefix=f"{prefix}.proj")
+        self.head_dim = self.embed_dim // self.num_heads
+        if self.head_dim * self.num_heads != self.embed_dim:
+            raise ValueError(
+                "embed_dim must be divisible by num_heads "
+                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                f" {self.num_heads}).")
+        self.scale = self.head_dim**-0.5

-    def forward(  # todo might implement multiple attn implementations
-            self,
-            x: torch.Tensor,
-            mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        B, N, C = x.shape
+        self.qkv = QKVParallelLinear(
+            hidden_size=self.embed_dim,
+            head_size=self.head_dim,
+            total_num_heads=self.num_heads,
+            bias=config.qkv_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv",
+        )
+
+        self.proj = RowParallelLinear(
+            input_size=self.embed_dim,
+            output_size=self.embed_dim,
+            bias=config.use_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.proj",
+        )
+
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
+
+        self.attn = MultiHeadAttention(self.num_heads_per_partition,
+                                       self.head_dim, self.scale)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         qkv, _ = self.qkv(x)
-
-        qkv = qkv.reshape(B, N, 3, self.num_heads,
-                          C // self.num_heads).permute(2, 0, 3, 1, 4)
-
-        q, k, v = qkv.unbind(0)
-
-        x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
-        x = x.transpose(1, 2).contiguous().reshape(B, N, C)
+        q, k, v = qkv.chunk(3, dim=-1)
+
+        x = self.attn(q, k, v)
         x, _ = self.proj(x)
         return x
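As a plain-PyTorch illustration (made-up shapes, single GPU, no tensor parallelism) of the new forward path: the fused QKV output is split with `chunk(3, dim=-1)`, and per-head scaled dot-product attention over the reshaped heads gives the same result the old code spelled out manually with `reshape`/`permute`/`unbind`.

```python
# Sketch only: chunk(3) on the fused QKV output plus standard multi-head SDPA
# matches the old hand-written reshape/permute path. Shapes are made up.
import torch
import torch.nn.functional as F

B, N, heads, head_dim = 2, 5, 4, 8
C = heads * head_dim
x = torch.randn(B, N, C)
w_qkv = torch.randn(3 * C, C)        # fused QKV weight, rows ordered [q, k, v]
qkv = F.linear(x, w_qkv)             # (B, N, 3C)

# New-style split, as in the updated forward().
q, k, v = qkv.chunk(3, dim=-1)

def split_heads(t: torch.Tensor) -> torch.Tensor:
    return t.reshape(B, N, heads, head_dim).transpose(1, 2)  # (B, heads, N, hd)

out_new = F.scaled_dot_product_attention(split_heads(q), split_heads(k),
                                          split_heads(v))
out_new = out_new.transpose(1, 2).reshape(B, N, C)

# Old-style split for comparison: reshape the fused tensor and unbind q/k/v.
qkv_old = qkv.reshape(B, N, 3, heads, head_dim).permute(2, 0, 3, 1, 4)
q2, k2, v2 = qkv_old.unbind(0)
out_old = F.scaled_dot_product_attention(q2, k2, v2)
out_old = out_old.transpose(1, 2).contiguous().reshape(B, N, C)

assert torch.allclose(out_new, out_old, atol=1e-5)
```

In the actual layer, the MultiHeadAttention module constructed above is given the per-rank head count (num_heads_per_partition) and performs the equivalent head reshaping internally, so the forward only needs the flat q/k/v split.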
@@ -141,37 +152,40 @@ class AIMv2Block(nn.Module):
                                prefix=f"{prefix}.mlp")
         self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

-    def forward(self,
-                x: torch.Tensor,
-                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        x = x + self.attn(self.norm_1.forward_native(x), mask)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x + self.attn(self.norm_1.forward_native(x))
         x = x + self.mlp(self.norm_2.forward_native(x))
         return x


 class AIMv2Transformer(nn.Module):

-    def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
-                 prefix: str):
+    def __init__(
+        self,
+        config: AIMv2Config,
+        quant_config: QuantizationConfig,
+        *,
+        require_post_norm: Optional[bool] = None,
+        prefix: str = "",
+    ):
         super().__init__()

         self.blocks = nn.ModuleList([
             AIMv2Block(config, quant_config, prefix=f"{prefix}.blocks.{i}")
             for i in range(config.num_hidden_layers)
         ])
-        self.post_trunk_norm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
+        if require_post_norm:
+            self.post_trunk_norm = RMSNorm(config.hidden_size,
+                                           eps=config.rms_norm_eps)
+        else:
+            self.post_trunk_norm = None

-    def forward(
-        self,
-        tokens: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
         # they take the -1 as the ref embeddings, like a clip skip
         for block in self.blocks:
-            tokens = block(tokens, mask)
-        # NO NORM IN THE OG IMPLEMENTATION
-        # tokens = self.post_trunk_norm(tokens)
+            tokens = block(tokens)
+        if self.post_trunk_norm is not None:
+            tokens = self.post_trunk_norm(tokens)
         return tokens
@@ -180,20 +194,52 @@ class AIMv2Model(torch.nn.Module):
     def __init__(self,
                  config: AIMv2Config,
                  quant_config: QuantizationConfig,
+                 *,
+                 require_post_norm: Optional[bool] = None,
                  prefix: str = ""):
         super().__init__()
         self.preprocessor = AIMv2ViTPreprocessor(config)
         self.trunk = AIMv2Transformer(config,
                                       quant_config=quant_config,
+                                      require_post_norm=require_post_norm,
                                       prefix=f"{prefix}.trunk")

-    def forward(
-        self,
-        pixel_values: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:

         x = self.preprocessor(pixel_values)
-        x = self.trunk(x, mask)
+        x = self.trunk(x)

         return x
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".fc13", ".fc1", 0),
+            (".fc13", ".fc3", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+
+        for name, loaded_weight in weights:
+            # post_layernorm is optional in SiglipVisionModel
+            if (name.startswith("trunk.post_trunk_norm")
+                    and self.trunk.post_trunk_norm is None):
+                continue
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
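The new load_weights and its stacked_params_mapping are what let existing checkpoints, which still store separate fc1 and fc3 tensors, populate the merged fc13 parameter. Below is a rough, self-contained sketch of that rename-then-shard idea; the parameter names and shapes are illustrative, and this is not vLLM's real weight_loader machinery.

```python
# Sketch only: packing a checkpoint's separate "fc1"/"fc3" weights into one
# merged "fc13" tensor (shard 0 = fc1, shard 1 = fc3), mimicking the idea
# behind stacked_params_mapping. Names and shapes are illustrative.
import torch

hidden, inter = 8, 16
checkpoint = {
    "trunk.blocks.0.mlp.fc1.weight": torch.randn(inter, hidden),
    "trunk.blocks.0.mlp.fc3.weight": torch.randn(inter, hidden),
    "trunk.blocks.0.mlp.fc2.weight": torch.randn(hidden, inter),
}
params = {
    "trunk.blocks.0.mlp.fc13.weight": torch.empty(2 * inter, hidden),
    "trunk.blocks.0.mlp.fc2.weight": torch.empty(hidden, inter),
}

stacked_params_mapping = [(".fc13", ".fc1", 0), (".fc13", ".fc3", 1)]

for name, weight in checkpoint.items():
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            merged = params[name.replace(weight_name, param_name)]
            # copy this tensor into its half of the merged parameter
            merged[shard_id * inter:(shard_id + 1) * inter].copy_(weight)
            break
    else:
        params[name].copy_(weight)

assert torch.equal(params["trunk.blocks.0.mlp.fc13.weight"][:inter],
                   checkpoint["trunk.blocks.0.mlp.fc1.weight"])
```

The post_trunk_norm guard in load_weights pairs with require_post_norm=False in the Ovis path: any trailing-norm weights present in a checkpoint are simply skipped when the module was built without that norm.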
@@ -106,7 +106,6 @@ class CLIPAttention(nn.Module):
                 f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                 f" {self.num_heads}).")
         self.scale = self.head_dim**-0.5
-        self.dropout = config.attention_dropout

         self.qkv_proj = QKVParallelLinear(
             hidden_size=self.embed_dim,
@@ -129,10 +128,6 @@ class CLIPAttention(nn.Module):
         self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                        self.head_dim, self.scale)

-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads,
-                           self.head_dim).transpose(1, 2).contiguous()
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -30,6 +30,9 @@ from vllm.config import VllmConfig
 from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+from vllm.model_executor.layers.quantization.gptq_marlin import (
+    GPTQMarlinConfig)
 from vllm.model_executor.models.aimv2 import AIMv2Model
 from vllm.model_executor.models.siglip import SiglipVisionModel
 from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn,
@@ -48,7 +51,7 @@ from vllm.transformers_utils.configs.ovis import (BaseVisualTokenizerConfig,
                                                    OvisConfig)
 from vllm.transformers_utils.processors.ovis import OvisProcessor

-from .interfaces import MultiModalEmbeddings, SupportsMultiModal
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import merge_multimodal_embeddings

 # Cannot find the following number from hf config.
@@ -106,12 +109,14 @@ class VisualTokenizer(torch.nn.Module):
         config: BaseVisualTokenizerConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
-    ):
+    ) -> nn.Module:
         model_type = config.backbone_config.model_type
         if model_type == "aimv2":
+            # No post rms_norm in Ovis2's AIMv2 ViT.
             return AIMv2Model(
                 config=config.backbone_config,
                 quant_config=quant_config,
+                require_post_norm=False,
                 prefix=prefix,
             )
         elif model_type == "siglip_vision_model":
@@ -124,14 +129,14 @@ class VisualTokenizer(torch.nn.Module):
             f"Unsupported visual tokenizer model_type: {model_type}")

     @property
-    def dtype(self):
+    def dtype(self) -> torch.dtype:
         return next(self.head.parameters()).dtype

     @property
-    def device(self):
+    def device(self) -> torch.device:
         return next(self.head.parameters()).device

-    def tokenize(self, logits):
+    def tokenize(self, logits: torch.Tensor) -> torch.Tensor:
         if self.config.tokenize_function == 'softmax':
             tokens = softmax(logits, dim=-1)
         elif self.config.tokenize_function == 'gumbel_argmax':
@@ -144,7 +149,7 @@ class VisualTokenizer(torch.nn.Module):
             f'or st_argmax, but got {self.config.tokenize_function}')
         return tokens

-    def encode(self, pixel_values):
+    def encode(self, pixel_values: torch.Tensor) -> torch.Tensor:
         features = self.backbone(pixel_values)
         if self.config.drop_cls_token:
             features = features[:, 1:, :]
@@ -395,7 +400,7 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]):
 @MULTIMODAL_REGISTRY.register_processor(OvisMultiModalProcessor,
                                         info=OvisProcessingInfo,
                                         dummy_inputs=OvisDummyInputsBuilder)
-class Ovis(nn.Module, SupportsMultiModal):
+class Ovis(nn.Module, SupportsMultiModal, SupportsPP):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -410,7 +415,7 @@ class Ovis(nn.Module, SupportsMultiModal):

         self.visual_tokenizer = VisualTokenizer(
             config=config.visual_tokenizer_config,
-            quant_config=quant_config,
+            quant_config=self._maybe_ignore_quant_config(quant_config),
             prefix=f"{prefix}.visual_tokenizer",
         )

@@ -421,9 +426,16 @@ class Ovis(nn.Module, SupportsMultiModal):
         text_model_type = self.config.get_text_config().model_type
         self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]

-        # TODO(Isotr0py): PP support
-        # self.make_empty_intermediate_tensors = (
-        #     self.language_model.make_empty_intermediate_tensors)
+        self.make_empty_intermediate_tensors = (
+            self.get_language_model().make_empty_intermediate_tensors)
+
+    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
+        # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
+        # seems to avoid vision encoder sections for some models.
+        # See: https://huggingface.co/AIDC-AI/Ovis2-2B-GPTQ-Int4
+        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
+            return None
+        return quant_config

     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[OvisImagePatchInputs]:
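The GPTQ half of the commit title comes down to _maybe_ignore_quant_config: AutoGPTQ-exported Ovis checkpoints leave the visual tokenizer unquantized, so passing the GPTQ config into it would build quantized layers with no matching weights in the checkpoint. After this change, loading such a checkpoint should work directly. The sketch below is illustrative only, using the model name cited in the diff comment; it is not part of the commit.

```python
# Sketch only: loading a GPTQ-quantized Ovis checkpoint after this fix.
# vLLM detects the GPTQ scheme from the checkpoint's quantization config;
# the visual tokenizer is now built without it (see _maybe_ignore_quant_config).
from vllm import LLM, SamplingParams

llm = LLM(
    model="AIDC-AI/Ovis2-2B-GPTQ-Int4",  # checkpoint named in the diff comment
    trust_remote_code=True,
)
outputs = llm.generate("Describe the Ovis architecture in one sentence.",
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```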