[Model] Support SigLIP encoder and alternative decoders for LLaVA models (#7153)

Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Authored by Cyrus Leung on 2024-08-06 16:55:31 +08:00; committed by GitHub
parent 9118217f58
commit 1f26efbb3a
14 changed files with 453 additions and 267 deletions


@ -20,6 +20,9 @@ sentence-transformers # required for embedding
compressed-tensors==0.4.0 # required for compressed-tensors
timm # required for internvl test
# TODO: Add this after fully implementing llava(mantis)
# git+https://github.com/TIGER-AI-Lab/Mantis.git # required for llava(mantis) test
# Benchmarking
aiohttp


@ -1,10 +1,11 @@
from typing import List, Optional, Tuple, Type
import pytest
from transformers import AutoTokenizer
from transformers import AutoConfig, AutoTokenizer
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close
@ -18,9 +19,11 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"USER: <image>\nWhat is the season?\nASSISTANT:",
})
IMAGE_TOKEN_ID = 32000
models = ["llava-hf/llava-1.5-7b-hf"]
models = [
"llava-hf/llava-1.5-7b-hf",
# TODO: Get this model to produce meaningful output in vLLM
# "TIGER-Lab/Mantis-8B-siglip-llama3",
]
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
@ -29,12 +32,15 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output
config = AutoConfig.from_pretrained(model)
image_token_id = config.image_token_index
tokenizer = AutoTokenizer.from_pretrained(model)
eos_token_id = tokenizer.eos_token_id
hf_output_ids = [
token_id for idx, token_id in enumerate(output_ids)
if token_id != IMAGE_TOKEN_ID or output_ids[idx - 1] != IMAGE_TOKEN_ID
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
]
assert output_str[0] == " "
@ -67,6 +73,17 @@ def run_test(
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
# NOTE: For local use; this isn't tested in CI yet (see TODO above)
if model.startswith("TIGER-Lab/Mantis"):
from mantis.models.mllava import MLlavaProcessor
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
mantis_processor = MLlavaProcessor.from_pretrained(
model, torch_dtype=torch_dtype)
assert isinstance(mantis_processor, MLlavaProcessor)
else:
mantis_processor = None
images = [asset.pil_image for asset in image_assets]
inputs_per_image = [(
@ -94,6 +111,15 @@ def run_test(
]
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
if mantis_processor is not None:
def process(*args, **kwargs):
output = mantis_processor(*args, **kwargs)
output["pixel_values"] = output["pixel_values"].to(torch_dtype)
return output
hf_model.processor = process
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
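As a side note on the vllm_to_hf_output sanitization above: runs of consecutive image placeholder tokens are collapsed into a single token so that vLLM's output IDs line up with HF's. A minimal, self-contained illustration (32000 is the LLaVA-1.5 image token from the original constant; the other token IDs are made up):

# Illustration only: token IDs other than the image token are arbitrary.
image_token_id = 32000
output_ids = [1, 32000, 32000, 32000, 319, 29901]
hf_output_ids = [
    token_id for idx, token_id in enumerate(output_ids)
    if token_id != image_token_id or output_ids[idx - 1] != image_token_id
]
print(hf_output_ids)  # [1, 32000, 319, 29901]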


@ -1,7 +1,7 @@
from typing import List, Optional, Tuple, Type, overload
import pytest
from transformers import AutoTokenizer
from transformers import AutoConfig, AutoTokenizer
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
@ -23,8 +23,6 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
f"{_PREFACE} USER: <image>\nWhat is the season? ASSISTANT:",
})
IMAGE_TOKEN_ID = 32000
models = ["llava-hf/llava-v1.6-vicuna-7b-hf"]
@ -34,12 +32,15 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output
config = AutoConfig.from_pretrained(model)
image_token_id = config.image_token_index
tokenizer = AutoTokenizer.from_pretrained(model)
eos_token_id = tokenizer.eos_token_id
hf_output_ids = [
token_id for idx, token_id in enumerate(output_ids)
if token_id != IMAGE_TOKEN_ID or output_ids[idx - 1] != IMAGE_TOKEN_ID
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
]
assert output_str[0] == " "


@ -2,7 +2,7 @@ import os
from typing import List, Optional, Tuple, Type
import pytest
from transformers import AutoTokenizer
from transformers import AutoConfig, AutoTokenizer
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
@ -20,8 +20,6 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"What is in the picture?",
})
IMAGE_TOKEN_ID = 257152
models = ["google/paligemma-3b-mix-224"]
# ROCm Triton FA can run into compilation issues with these models due to,
@ -37,12 +35,15 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output
config = AutoConfig.from_pretrained(model)
image_token_id = config.image_token_index
tokenizer = AutoTokenizer.from_pretrained(model)
eos_token_id = tokenizer.eos_token_id
hf_output_ids = [
token_id for idx, token_id in enumerate(output_ids)
if token_id != IMAGE_TOKEN_ID or output_ids[idx - 1] != IMAGE_TOKEN_ID
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
]
hf_output_str = output_str


@ -6,4 +6,4 @@ from vllm.model_executor.models import _MODELS, ModelRegistry
@pytest.mark.parametrize("model_cls", _MODELS)
def test_registry_imports(model_cls):
# Ensure all model classes can be imported successfully
ModelRegistry.load_model_cls(model_cls)
ModelRegistry.resolve_model_cls([model_cls])


@ -16,7 +16,7 @@ import numpy as np
import torch
from huggingface_hub import HfApi, hf_hub_download
from torch import nn
from transformers import AutoModelForCausalLM
from transformers import AutoModelForCausalLM, PretrainedConfig
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, MultiModalConfig,
@ -143,6 +143,22 @@ def _get_model_initialization_kwargs(
return extra_kwargs
def build_model(model_class: Type[nn.Module], hf_config: PretrainedConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig], *,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
scheduler_config: Optional[SchedulerConfig]) -> nn.Module:
extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config,
multimodal_config,
scheduler_config)
return model_class(config=hf_config,
cache_config=cache_config,
quant_config=quant_config,
**extra_kwargs)
def _initialize_model(
model_config: ModelConfig,
load_config: LoadConfig,
@ -151,15 +167,17 @@ def _initialize_model(
cache_config: CacheConfig,
scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module:
"""Initialize a model with the given configurations."""
model_class = get_model_architecture(model_config)[0]
quant_config = _get_quantization_config(model_config, load_config)
model_class, _ = get_model_architecture(model_config)
return model_class(config=model_config.hf_config,
return build_model(
model_class,
model_config.hf_config,
quant_config=_get_quantization_config(model_config, load_config),
lora_config=lora_config,
multimodal_config=multimodal_config,
cache_config=cache_config,
quant_config=quant_config,
**_get_model_initialization_kwargs(
model_class, lora_config, multimodal_config,
scheduler_config))
scheduler_config=scheduler_config,
)
class BaseModelLoader(ABC):


@ -28,13 +28,7 @@ def get_model_architecture(
and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"]
for arch in architectures:
model_cls = ModelRegistry.load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
return ModelRegistry.resolve_model_cls(architectures)
def get_architecture_class_name(model_config: ModelConfig) -> str:


@ -1,6 +1,6 @@
import functools
import importlib
from typing import Dict, List, Optional, Type
from typing import Dict, List, Optional, Tuple, Type
import torch.nn as nn
@ -126,7 +126,7 @@ class ModelRegistry:
return getattr(module, model_cls_name, None)
@staticmethod
def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch in _OOT_MODELS:
return _OOT_MODELS[model_arch]
if model_arch not in _MODELS:
@ -143,6 +143,18 @@ class ModelRegistry:
return ModelRegistry._get_model(model_arch)
@staticmethod
def resolve_model_cls(
architectures: List[str]) -> Tuple[Type[nn.Module], str]:
for arch in architectures:
model_cls = ModelRegistry._try_load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
@staticmethod
def get_supported_archs() -> List[str]:
return list(_MODELS.keys())
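The registry change above replaces the Optional-returning load_model_cls with resolve_model_cls, which walks a list of architecture names and either returns the first supported model class together with its name or raises a ValueError listing the supported architectures. A minimal usage sketch (the architecture string is just an example; any entry of hf_config.architectures works the same way):

from vllm.model_executor.models import ModelRegistry

# Raises ValueError if none of the given architectures is registered.
model_cls, arch = ModelRegistry.resolve_model_cls(
    ["LlavaForConditionalGeneration"])
assert arch == "LlavaForConditionalGeneration"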


@ -1,6 +1,6 @@
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
from typing import Optional
from typing import Iterable, Optional, Tuple
import torch
import torch.nn as nn
@ -14,6 +14,7 @@ from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
from vllm.sequence import SequenceData
@ -32,7 +33,7 @@ def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:
def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
return get_clip_num_patches(image_size=hf_config.image_size,
patch_size=hf_config.patch_size)
patch_size=hf_config.patch_size) + 1
def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
@ -291,3 +292,22 @@ class CLIPVisionModel(nn.Module):
@property
def device(self):
return next(self.parameters()).device
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
layer_count = len(self.vision_model.encoder.layers)
for name, loaded_weight in weights:
# post_layernorm is not needed in CLIPVisionModel
if "vision_model.post_layernorm" in name:
continue
# omit layers when num_hidden_layers_override is set
if "vision_model.encoder.layers." in name:
layer_idx = int(name.split(".")[3])
if layer_idx >= layer_count:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
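Note that get_clip_image_feature_size above now adds one to the patch count to include the class embedding; the LLaVA code later subtracts it again when vision_feature_select_strategy is "default". A pure-arithmetic sketch, assuming the CLIP-ViT-L/14-336 geometry used by LLaVA-1.5 (image size 336, patch size 14; the sizes are illustrative):

# Mirrors get_clip_num_patches / get_clip_image_feature_size.
image_size, patch_size = 336, 14
num_patches = (image_size // patch_size) ** 2   # 576
feature_size = num_patches + 1                  # 577, class token included
print(feature_size - 1)                         # 576 image tokens under the "default" strategy
print(feature_size)                             # 577 under the "full" strategy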


@ -18,7 +18,6 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.intern_vit import InternVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@ -29,7 +28,8 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_num_patches)
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings
from .utils import (filter_weights, init_vllm_registered_model,
merge_vision_embeddings)
IMG_START = '<img>'
IMG_END = '</img>'
@ -283,10 +283,8 @@ class InternVLChatModel(nn.Module, SupportsVision):
self.vision_model = InternVisionModel(
config.vision_config, num_hidden_layers_override=num_hidden_layers)
llm_class = ModelRegistry.load_model_cls(
config.text_config.architectures[0])
self.language_model = llm_class(config.text_config, cache_config,
quant_config)
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)
vit_hidden_size = config.vision_config.hidden_size
llm_hidden_size = config.text_config.hidden_size
@ -415,24 +413,16 @@ class InternVLChatModel(nn.Module, SupportsVision):
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def _filter_weights(self, weights: Iterable[Tuple[str, torch.Tensor]],
prefix: str):
for name, loaded_weight in weights:
name = name.split(".")
if prefix == name.pop(0):
name = ".".join(name)
yield name, loaded_weight
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components
vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
# load vision encoder
vit_weights = self._filter_weights(vit_weights, "vision_model")
vit_weights = filter_weights(vit_weights, "vision_model")
self.vision_model.load_weights(vit_weights)
# load mlp projector
mlp_weights = self._filter_weights(mlp_weights, "mlp1")
mlp_weights = filter_weights(mlp_weights, "mlp1")
mlp_params_dict = dict(self.mlp1.named_parameters())
for name, loaded_weight in mlp_weights:
param = mlp_params_dict[name]
@ -441,5 +431,5 @@ class InternVLChatModel(nn.Module, SupportsVision):
weight_loader(param, loaded_weight)
# load llm backbone
llm_weights = self._filter_weights(llm_weights, "language_model")
llm_weights = filter_weights(llm_weights, "language_model")
self.language_model.load_weights(llm_weights)


@ -1,34 +1,30 @@
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
import itertools
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
import torch
import torch.nn as nn
from transformers import CLIPVisionConfig, LlavaConfig
from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_max_clip_image_tokens, input_processor_for_clip)
from .clip import (CLIPVisionModel, dummy_image_for_clip,
dummy_seq_data_for_clip, get_max_clip_image_tokens,
input_processor_for_clip)
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
}
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
input_processor_for_siglip)
from .utils import (filter_weights, init_vllm_registered_model,
merge_vision_embeddings)
# TODO(xwjiang): Run benchmark and decide if TP.
@ -67,25 +63,48 @@ def get_max_llava_image_tokens(ctx: InputContext):
vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig):
return get_max_clip_image_tokens(vision_config)
num_image_tokens = get_max_clip_image_tokens(vision_config)
elif isinstance(vision_config, SiglipVisionConfig):
num_image_tokens = get_max_siglip_image_tokens(vision_config)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
strategy = hf_config.vision_feature_select_strategy
if strategy == "default":
return num_image_tokens - 1
elif strategy == "full":
return num_image_tokens
else:
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def dummy_data_for_llava(ctx: InputContext, seq_len: int):
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
image_feature_size = get_max_llava_image_tokens(ctx)
if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip(
vision_config,
seq_len,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_clip(vision_config)
return seq_data, mm_data
elif isinstance(vision_config, SiglipVisionConfig):
seq_data = dummy_seq_data_for_siglip(
vision_config,
seq_len,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_siglip(vision_config)
return seq_data, mm_data
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
@ -100,12 +119,49 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
image_feature_size = get_max_llava_image_tokens(ctx)
if isinstance(vision_config, CLIPVisionConfig):
return input_processor_for_clip(
model_config,
vision_config,
llm_inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
elif isinstance(vision_config, SiglipVisionConfig):
return input_processor_for_siglip(
model_config,
vision_config,
llm_inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def _init_vision_tower(hf_config: LlavaConfig):
vision_config = hf_config.vision_config
# Initialize the vision tower only up to the required feature layer
vision_feature_layer = hf_config.vision_feature_layer
if vision_feature_layer < 0:
num_hidden_layers = hf_config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
if isinstance(vision_config, CLIPVisionConfig):
return CLIPVisionModel(
vision_config,
num_hidden_layers_override=num_hidden_layers,
)
elif isinstance(vision_config, SiglipVisionConfig):
return SiglipVisionModel(
vision_config,
num_hidden_layers_override=num_hidden_layers,
)
msg = f"Unsupported vision config: {type(vision_config)}"
@ -128,36 +184,15 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
self.config = config
self.multimodal_config = multimodal_config
# Initialize the vision tower only up to the required feature layer
vision_feature_layer = config.vision_feature_layer
if vision_feature_layer < 0:
num_hidden_layers = config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = CLIPVisionModel(
config.vision_config, num_hidden_layers_override=num_hidden_layers)
self.vision_tower = _init_vision_tower(config)
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act)
self.quant_config = quant_config
self.language_model = LlamaModel(config.text_config, cache_config,
quant_config)
self.unpadded_vocab_size = config.text_config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.text_config.hidden_size,
org_num_embeddings=self.language_model.org_vocab_size,
quant_config=quant_config)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.text_config.vocab_size,
logit_scale)
self.sampler = Sampler()
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
@ -198,8 +233,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
def _image_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
pixel_values: torch.Tensor,
) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
@ -272,7 +310,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_vision_embeddings(
input_ids, inputs_embeds, vision_embeddings,
@ -282,7 +321,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
else:
inputs_embeds = None
hidden_states = self.language_model(input_ids,
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
@ -293,57 +332,33 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# only doing this for language model part for now.
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
# post_layernorm is not needed in CLIPVisionModel
if "vision_model.post_layernorm" in name:
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
use_default_weight_loading = False
if "vision" in name:
if self.vision_tower is not None:
# We only do sharding for language model and
# not vision model for now.
use_default_weight_loading = True
else:
for (param_name, weight_name,
shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
use_default_weight_loading = True
if use_default_weight_loading and name in params_dict:
param = params_dict[name]
# prepare weight iterators for components
vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
# load vision encoder
vit_weights = filter_weights(vit_weights, "vision_tower")
self.vision_tower.load_weights(vit_weights)
# load mlp projector
mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in mlp_weights:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load llm backbone
llm_weights = filter_weights(llm_weights, "language_model")
self.language_model.load_weights(llm_weights)
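The _init_vision_tower helper above translates vision_feature_layer into the number of encoder layers that actually get built, so the unused tail of the vision tower is never instantiated. A small worked example with typical (but here hypothetical) LLaVA-1.5 values, 24 CLIP encoder layers and vision_feature_layer = -2:

# Mirrors the arithmetic in _init_vision_tower; the config values are illustrative.
config_num_hidden_layers = 24
vision_feature_layer = -2
if vision_feature_layer < 0:
    num_hidden_layers = config_num_hidden_layers + vision_feature_layer + 1
else:
    num_hidden_layers = vision_feature_layer + 1
print(num_hidden_layers)  # 23

The load_weights methods in the CLIP and SigLIP towers then skip any checkpoint weights whose encoder layer index is at or beyond this count, so the dropped layer is never materialized.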


@ -1,9 +1,10 @@
import itertools
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig, LlavaNextConfig
from transformers import CLIPVisionConfig, LlavaNextConfig, SiglipVisionConfig
from transformers.models.llava_next.modeling_llava_next import (
get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired
@ -12,23 +13,23 @@ from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
from .clip import (CLIPVisionModel, dummy_image_for_clip,
dummy_seq_data_for_clip, get_clip_image_feature_size,
get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsVision
from .llava import LlavaMultiModalProjector
from .utils import merge_vision_embeddings
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (filter_weights, init_vllm_registered_model,
merge_vision_embeddings)
logger = init_logger(__name__)
@ -104,7 +105,24 @@ def get_llava_next_image_feature_size(
image_size=vision_config.image_size,
patch_size=vision_config.patch_size,
)
base_feature_size = num_patches * num_patches
base_feature_size = get_clip_image_feature_size(vision_config)
elif isinstance(vision_config, SiglipVisionConfig):
num_patches = get_siglip_patch_grid_length(
image_size=vision_config.image_size,
patch_size=vision_config.patch_size,
)
base_feature_size = get_siglip_image_feature_size(vision_config)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
strategy = hf_config.vision_feature_select_strategy
if strategy == "default":
base_feature_size -= 1
elif strategy == "full":
pass
else:
raise ValueError(f"Unexpected select feature strategy: {strategy}")
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_size=(input_height, input_width),
@ -116,18 +134,13 @@ def get_llava_next_image_feature_size(
unpadded_feature_size,
newline_feature_size,
) = _get_llava_next_num_unpadded_features(input_height, input_width,
num_patches,
num_patch_height,
num_patches, num_patch_height,
num_patch_width)
return unpadded_feature_size + newline_feature_size + base_feature_size
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def get_max_llava_next_image_tokens(ctx: InputContext):
return get_llava_next_image_feature_size(
ctx.get_hf_config(LlavaNextConfig),
input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
@ -155,6 +168,21 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)
return seq_data, mm_data
elif isinstance(vision_config, SiglipVisionConfig):
seq_data = dummy_seq_data_for_siglip(
vision_config,
seq_len,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_siglip(
vision_config,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)
return seq_data, mm_data
msg = f"Unsupported vision config: {type(vision_config)}"
@ -194,6 +222,40 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
elif isinstance(vision_config, SiglipVisionConfig):
return input_processor_for_siglip(
model_config,
vision_config,
llm_inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def _init_vision_tower(hf_config: LlavaNextConfig):
vision_config = hf_config.vision_config
# Initialize the vision tower only up to the required feature layer
vision_feature_layer = hf_config.vision_feature_layer
if vision_feature_layer < 0:
num_hidden_layers = hf_config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
if isinstance(vision_config, CLIPVisionConfig):
return CLIPVisionModel(
vision_config,
num_hidden_layers_override=num_hidden_layers,
)
elif isinstance(vision_config, SiglipVisionConfig):
return SiglipVisionModel(
vision_config,
num_hidden_layers_override=num_hidden_layers,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
@ -215,36 +277,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
self.config = config
self.multimodal_config = multimodal_config
# Initialize the vision tower only up to the required feature layer
vision_feature_layer = config.vision_feature_layer
if vision_feature_layer < 0:
num_hidden_layers = config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = CLIPVisionModel(
config.vision_config, num_hidden_layers_override=num_hidden_layers)
self.vision_tower = _init_vision_tower(config)
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act)
self.quant_config = quant_config
self.language_model = LlamaModel(config.text_config, cache_config,
quant_config)
self.unpadded_vocab_size = config.text_config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.text_config.hidden_size,
org_num_embeddings=self.language_model.org_vocab_size,
quant_config=quant_config)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.text_config.vocab_size,
logit_scale)
self.sampler = Sampler()
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size))
@ -310,8 +351,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
def _image_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
pixel_values: torch.Tensor,
) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
@ -496,7 +540,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_vision_embeddings(
input_ids, inputs_embeds, vision_embeddings,
@ -506,7 +551,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
else:
inputs_embeds = None
hidden_states = self.language_model(input_ids,
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
@ -517,57 +562,43 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# only doing this for language model part for now.
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
# post_layernorm is not needed in CLIPVisionModel
if "vision_model.post_layernorm" in name:
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
use_default_weight_loading = False
if "vision" in name:
if self.vision_tower is not None:
# We only do sharding for language model and
# not vision model for now.
use_default_weight_loading = True
else:
for (param_name, weight_name,
shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
use_default_weight_loading = True
if use_default_weight_loading and name in params_dict:
param = params_dict[name]
# prepare weight iterators for components
vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
weights, 4)
# load vision encoder
vit_weights = filter_weights(vit_weights, "vision_tower")
self.vision_tower.load_weights(vit_weights)
# load mlp projector
mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in mlp_weights:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load newline
newline_weights = filter_weights(newline_weights, "image_newline")
for name, loaded_weight in newline_weights:
assert name == ""
param = self.image_newline
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load llm backbone
llm_weights = filter_weights(llm_weights, "language_model")
self.language_model.load_weights(llm_weights)
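One subtlety in the LLaVA-NeXT load_weights above: image_newline is a bare nn.Parameter rather than a submodule, so once filter_weights strips the "image_newline" prefix nothing is left of the name, which is exactly what the assert name == "" checks. A tiny illustration (the tensor shape is a placeholder):

import torch

from vllm.model_executor.models.utils import filter_weights

weights = [("image_newline", torch.zeros(4096))]
name, loaded_weight = next(filter_weights(weights, "image_newline"))
assert name == ""  # the entire checkpoint name was the prefix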


@ -2,12 +2,12 @@
within a vision language model."""
import math
from typing import Optional, Tuple
from typing import Iterable, Optional, Tuple
import torch
from PIL import Image
from torch import nn
from transformers import SiglipConfig, SiglipVisionConfig
from transformers import SiglipVisionConfig
from transformers.models.siglip.modeling_siglip import SiglipAttention
from vllm_flash_attn import flash_attn_func
from xformers.ops import memory_efficient_attention
@ -22,13 +22,15 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
from vllm.sequence import SequenceData
def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0
# Since interpolation is applied, the image size need not be divisible
# assert image_size % patch_size == 0
return image_size // patch_size
@ -454,7 +456,7 @@ class SiglipEncoderLayer(nn.Module):
def __init__(
self,
config: SiglipConfig,
config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
@ -474,7 +476,7 @@ class SiglipEncoderLayer(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
) -> Tuple[torch.Tensor]:
) -> Tuple[torch.Tensor, None]:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
@ -493,22 +495,27 @@ class SiglipEncoder(nn.Module):
def __init__(
self,
config: SiglipConfig,
config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None,
):
super().__init__()
self.config = config
if num_hidden_layers_override is None:
num_hidden_layers = config.num_hidden_layers
else:
num_hidden_layers = num_hidden_layers_override
self.layers = nn.ModuleList([
SiglipEncoderLayer(
config,
quant_config=quant_config,
) for _ in range(config.num_hidden_layers)
SiglipEncoderLayer(config, quant_config=quant_config)
for _ in range(num_hidden_layers)
])
def forward(
self,
inputs_embeds: torch.Tensor,
) -> Tuple:
) -> torch.Tensor:
hidden_states = inputs_embeds
for encoder_layer in self.layers:
hidden_states, _ = encoder_layer(hidden_states)
@ -553,6 +560,7 @@ class SiglipVisionTransformer(nn.Module):
self,
config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None,
):
super().__init__()
self.config = config
@ -562,6 +570,7 @@ class SiglipVisionTransformer(nn.Module):
self.encoder = SiglipEncoder(
config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
)
self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps)
@ -600,11 +609,13 @@ class SiglipVisionModel(nn.Module):
self,
config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None,
):
super().__init__()
self.vision_model = SiglipVisionTransformer(
config,
quant_config,
num_hidden_layers_override=num_hidden_layers_override,
)
def get_input_embeddings(self) -> nn.Module:
@ -619,3 +630,19 @@ class SiglipVisionModel(nn.Module):
pixel_values=pixel_values,
interpolate_pos_encoding=interpolate_pos_encoding,
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
layer_count = len(self.vision_model.encoder.layers)
for name, loaded_weight in weights:
# omit layers when num_hidden_layers_override is set
if "vision_model.encoder.layers." in name:
layer_idx = int(name.split(".")[3])
if layer_idx >= layer_count:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
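The SigLIP changes above thread num_hidden_layers_override through SiglipEncoder, SiglipVisionTransformer and SiglipVisionModel, matching the existing CLIP support, and add a load_weights that ignores checkpoint layers beyond the truncated depth. A hedged construction sketch; it assumes it runs inside a vLLM worker where the tensor-parallel state is already initialized, and uses the default SiglipVisionConfig (12 encoder layers) purely for illustration:

from transformers import SiglipVisionConfig

from vllm.model_executor.models.siglip import SiglipVisionModel

config = SiglipVisionConfig()  # 12 hidden layers by default
vision_tower = SiglipVisionModel(config, num_hidden_layers_override=11)
assert len(vision_tower.vision_model.encoder.layers) == 11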


@ -1,22 +1,70 @@
from typing import Dict, List, Protocol, Tuple
from typing import Dict, Iterable, List, Optional, Protocol, Tuple
import torch
import torch.nn as nn
from torch.func import functional_call
from transformers import PretrainedConfig
from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
SchedulerConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.loader import build_model
from vllm.model_executor.models import ModelRegistry
from vllm.multimodal import BatchedTensors
from vllm.utils import is_pin_memory_available
def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
"""
Helper function to load weights for inner vLLM models.
See also:
:ref:`init_vllm_registered_model`
"""
for name, loaded_weight in weights:
name = name.split(".")
if prefix == name.pop(0):
name = ".".join(name)
yield name, loaded_weight
def init_vllm_registered_model(
hf_config: PretrainedConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig],
*,
lora_config: Optional[LoRAConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
scheduler_config: Optional[SchedulerConfig] = None,
) -> nn.Module:
"""
Helper function to initialize an inner model registered to vLLM,
based on the arguments passed to the outer vLLM model.
"""
model_class, _ = ModelRegistry.resolve_model_cls(hf_config.architectures)
return build_model(
model_class,
hf_config,
cache_config,
quant_config,
lora_config=lora_config,
multimodal_config=multimodal_config,
scheduler_config=scheduler_config,
)
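filter_weights above (moved here from the InternVL implementation) keeps only the entries whose first dotted component equals the given prefix and strips that component, so each sub-module's load_weights sees checkpoint names relative to itself, while init_vllm_registered_model resolves and builds the inner decoder through the registry. A short illustration of the filtering with made-up parameter names:

import torch

from vllm.model_executor.models.utils import filter_weights

weights = [
    ("vision_tower.vision_model.embeddings.patch_embedding.weight", torch.zeros(1)),
    ("language_model.model.embed_tokens.weight", torch.zeros(1)),
]
print([name for name, _ in filter_weights(weights, "language_model")])
# ['model.embed_tokens.weight']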
def merge_vision_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
vision_embeddings: BatchedTensors,
image_token_id: int) -> torch.Tensor:
"""
Merge `vision_embeddings` into `inputs_embeds` by overwriting the positions
in `inputs_embeds` corresponding to placeholder image tokens in `input_ids`.
Merge ``vision_embeddings`` into ``inputs_embeds`` by overwriting the
positions in ``inputs_embeds`` corresponding to placeholder image tokens in
``input_ids``.
Note:
This updates `inputs_embeds` in place.
This updates ``inputs_embeds`` in place.
"""
mask = (input_ids == image_token_id)
num_expected_tokens = mask.sum()