[Model] Support SigLIP encoder and alternative decoders for LLaVA models (#7153)
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
commit 1f26efbb3a
parent 9118217f58
@@ -20,6 +20,9 @@ sentence-transformers # required for embedding
 compressed-tensors==0.4.0 # required for compressed-tensors
 timm # required for internvl test

+# TODO: Add this after fully implementing llava(mantis)
+# git+https://github.com/TIGER-AI-Lab/Mantis.git # required for llava(mantis) test
+
 # Benchmarking
 aiohttp

@@ -1,10 +1,11 @@
 from typing import List, Optional, Tuple, Type

 import pytest
-from transformers import AutoTokenizer
+from transformers import AutoConfig, AutoTokenizer

 from vllm.multimodal.utils import rescale_image_size
 from vllm.sequence import SampleLogprobs
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

 from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
 from .utils import check_logprobs_close
@@ -18,9 +19,11 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
     "USER: <image>\nWhat is the season?\nASSISTANT:",
 })

-IMAGE_TOKEN_ID = 32000
-
-models = ["llava-hf/llava-1.5-7b-hf"]
+models = [
+    "llava-hf/llava-1.5-7b-hf",
+    # TODO: Get this model to produce meaningful output in vLLM
+    # "TIGER-Lab/Mantis-8B-siglip-llama3",
+]


 def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
@@ -29,12 +32,15 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
     """Sanitize vllm output to be comparable with hf output."""
     output_ids, output_str, out_logprobs = vllm_output

+    config = AutoConfig.from_pretrained(model)
+    image_token_id = config.image_token_index
+
     tokenizer = AutoTokenizer.from_pretrained(model)
     eos_token_id = tokenizer.eos_token_id

     hf_output_ids = [
         token_id for idx, token_id in enumerate(output_ids)
-        if token_id != IMAGE_TOKEN_ID or output_ids[idx - 1] != IMAGE_TOKEN_ID
+        if token_id != image_token_id or output_ids[idx - 1] != image_token_id
     ]

     assert output_str[0] == " "
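Note: the test above now derives the image placeholder id from each checkpoint's Hugging Face config instead of a hard-coded `IMAGE_TOKEN_ID`, so the same sanitizer works for CLIP- and SigLIP-based variants. A minimal sketch of the idea (the model name is simply the one used in this test):

    from transformers import AutoConfig

    # Each LLaVA-style checkpoint records its own placeholder id.
    config = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")
    image_token_id = config.image_token_index  # 32000 for this checkpoint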
@@ -67,6 +73,17 @@ def run_test(
     Note, the text input is also adjusted to abide by vllm contract.
     The text output is sanitized to be able to compare with hf.
     """
+    # NOTE: For local use; this isn't tested in CI yet (see TODO above)
+    if model.startswith("TIGER-Lab/Mantis"):
+        from mantis.models.mllava import MLlavaProcessor
+
+        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
+        mantis_processor = MLlavaProcessor.from_pretrained(
+            model, torch_dtype=torch_dtype)
+        assert isinstance(mantis_processor, MLlavaProcessor)
+    else:
+        mantis_processor = None
+
     images = [asset.pil_image for asset in image_assets]

     inputs_per_image = [(
@@ -94,6 +111,15 @@ def run_test(
     ]

     with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
+        if mantis_processor is not None:
+
+            def process(*args, **kwargs):
+                output = mantis_processor(*args, **kwargs)
+                output["pixel_values"] = output["pixel_values"].to(torch_dtype)
+                return output
+
+            hf_model.processor = process
+
         hf_outputs_per_image = [
             hf_model.generate_greedy_logprobs_limit(prompts,
                                                     max_tokens,
@@ -1,7 +1,7 @@
 from typing import List, Optional, Tuple, Type, overload

 import pytest
-from transformers import AutoTokenizer
+from transformers import AutoConfig, AutoTokenizer

 from vllm.multimodal.utils import rescale_image_size
 from vllm.sequence import SampleLogprobs
@@ -23,8 +23,6 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
     f"{_PREFACE} USER: <image>\nWhat is the season? ASSISTANT:",
 })

-IMAGE_TOKEN_ID = 32000
-
 models = ["llava-hf/llava-v1.6-vicuna-7b-hf"]


@@ -34,12 +32,15 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
     """Sanitize vllm output to be comparable with hf output."""
     output_ids, output_str, out_logprobs = vllm_output

+    config = AutoConfig.from_pretrained(model)
+    image_token_id = config.image_token_index
+
     tokenizer = AutoTokenizer.from_pretrained(model)
     eos_token_id = tokenizer.eos_token_id

     hf_output_ids = [
         token_id for idx, token_id in enumerate(output_ids)
-        if token_id != IMAGE_TOKEN_ID or output_ids[idx - 1] != IMAGE_TOKEN_ID
+        if token_id != image_token_id or output_ids[idx - 1] != image_token_id
     ]

     assert output_str[0] == " "
@@ -2,7 +2,7 @@ import os
 from typing import List, Optional, Tuple, Type

 import pytest
-from transformers import AutoTokenizer
+from transformers import AutoConfig, AutoTokenizer

 from vllm.multimodal.utils import rescale_image_size
 from vllm.sequence import SampleLogprobs
@@ -20,8 +20,6 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
     "What is in the picture?",
 })

-IMAGE_TOKEN_ID = 257152
-
 models = ["google/paligemma-3b-mix-224"]

 # ROCm Triton FA can run into compilation issues with these models due to,
@@ -37,12 +35,15 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
     """Sanitize vllm output to be comparable with hf output."""
     output_ids, output_str, out_logprobs = vllm_output

+    config = AutoConfig.from_pretrained(model)
+    image_token_id = config.image_token_index
+
     tokenizer = AutoTokenizer.from_pretrained(model)
     eos_token_id = tokenizer.eos_token_id

     hf_output_ids = [
         token_id for idx, token_id in enumerate(output_ids)
-        if token_id != IMAGE_TOKEN_ID or output_ids[idx - 1] != IMAGE_TOKEN_ID
+        if token_id != image_token_id or output_ids[idx - 1] != image_token_id
     ]

     hf_output_str = output_str
@@ -6,4 +6,4 @@ from vllm.model_executor.models import _MODELS, ModelRegistry
 @pytest.mark.parametrize("model_cls", _MODELS)
 def test_registry_imports(model_cls):
     # Ensure all model classes can be imported successfully
-    ModelRegistry.load_model_cls(model_cls)
+    ModelRegistry.resolve_model_cls([model_cls])
@@ -16,7 +16,7 @@ import numpy as np
 import torch
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, PretrainedConfig

 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
                          LoRAConfig, ModelConfig, MultiModalConfig,
@@ -143,6 +143,22 @@ def _get_model_initialization_kwargs(
     return extra_kwargs


+def build_model(model_class: Type[nn.Module], hf_config: PretrainedConfig,
+                cache_config: Optional[CacheConfig],
+                quant_config: Optional[QuantizationConfig], *,
+                lora_config: Optional[LoRAConfig],
+                multimodal_config: Optional[MultiModalConfig],
+                scheduler_config: Optional[SchedulerConfig]) -> nn.Module:
+    extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config,
+                                                    multimodal_config,
+                                                    scheduler_config)
+
+    return model_class(config=hf_config,
+                       cache_config=cache_config,
+                       quant_config=quant_config,
+                       **extra_kwargs)
+
+
 def _initialize_model(
     model_config: ModelConfig,
     load_config: LoadConfig,
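Note: `build_model` centralizes the keyword plumbing that each call site used to repeat, so vision-language wrappers can construct their inner decoder the same way the loader does. A hypothetical call, assuming the config objects have already been constructed elsewhere:

    model = build_model(
        model_class,            # any class returned by the model registry
        hf_config,              # transformers PretrainedConfig for the model
        cache_config,
        quant_config,
        lora_config=None,
        multimodal_config=multimodal_config,
        scheduler_config=None,
    )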
@@ -151,15 +167,17 @@ def _initialize_model(
         cache_config: CacheConfig,
         scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module:
     """Initialize a model with the given configurations."""
-    model_class = get_model_architecture(model_config)[0]
-    quant_config = _get_quantization_config(model_config, load_config)
+    model_class, _ = get_model_architecture(model_config)

-    return model_class(config=model_config.hf_config,
-                       cache_config=cache_config,
-                       quant_config=quant_config,
-                       **_get_model_initialization_kwargs(
-                           model_class, lora_config, multimodal_config,
-                           scheduler_config))
+    return build_model(
+        model_class,
+        model_config.hf_config,
+        quant_config=_get_quantization_config(model_config, load_config),
+        lora_config=lora_config,
+        multimodal_config=multimodal_config,
+        cache_config=cache_config,
+        scheduler_config=scheduler_config,
+    )


 class BaseModelLoader(ABC):
@@ -28,13 +28,7 @@ def get_model_architecture(
             and "MixtralForCausalLM" in architectures):
         architectures = ["QuantMixtralForCausalLM"]

-    for arch in architectures:
-        model_cls = ModelRegistry.load_model_cls(arch)
-        if model_cls is not None:
-            return (model_cls, arch)
-    raise ValueError(
-        f"Model architectures {architectures} are not supported for now. "
-        f"Supported architectures: {ModelRegistry.get_supported_archs()}")
+    return ModelRegistry.resolve_model_cls(architectures)


 def get_architecture_class_name(model_config: ModelConfig) -> str:
@@ -1,6 +1,6 @@
 import functools
 import importlib
-from typing import Dict, List, Optional, Type
+from typing import Dict, List, Optional, Tuple, Type

 import torch.nn as nn

@@ -126,7 +126,7 @@ class ModelRegistry:
         return getattr(module, model_cls_name, None)

     @staticmethod
-    def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
+    def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
         if model_arch in _OOT_MODELS:
             return _OOT_MODELS[model_arch]
         if model_arch not in _MODELS:
@@ -143,6 +143,18 @@ class ModelRegistry:
         return ModelRegistry._get_model(model_arch)

+    @staticmethod
+    def resolve_model_cls(
+            architectures: List[str]) -> Tuple[Type[nn.Module], str]:
+        for arch in architectures:
+            model_cls = ModelRegistry._try_load_model_cls(arch)
+            if model_cls is not None:
+                return (model_cls, arch)
+
+        raise ValueError(
+            f"Model architectures {architectures} are not supported for now. "
+            f"Supported architectures: {ModelRegistry.get_supported_archs()}")
+
     @staticmethod
     def get_supported_archs() -> List[str]:
         return list(_MODELS.keys())

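Note: `resolve_model_cls` walks the architecture list in order and returns the first entry that maps to a loadable class together with its name, raising the same `ValueError` as the old loader-side loop when nothing matches. For example:

    model_cls, arch = ModelRegistry.resolve_model_cls(
        ["LlavaForConditionalGeneration"])
    assert arch == "LlavaForConditionalGeneration"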
@@ -1,6 +1,6 @@
 """Minimal implementation of CLIPVisionModel intended to be only used
 within a vision language model."""
-from typing import Optional
+from typing import Iterable, Optional, Tuple

 import torch
 import torch.nn as nn
@@ -14,6 +14,7 @@ from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal.image import (cached_get_tokenizer,
                                    repeat_and_pad_image_tokens)
 from vllm.sequence import SequenceData
@@ -32,7 +33,7 @@ def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:

 def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
     return get_clip_num_patches(image_size=hf_config.image_size,
-                                patch_size=hf_config.patch_size)
+                                patch_size=hf_config.patch_size) + 1


 def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
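Note: the `+ 1` accounts for the class (CLS) embedding that CLIPVisionModel prepends to the patch tokens, so the feature size is patches plus one. As a worked example for a 336-pixel, patch-14 CLIP tower (the size commonly paired with llava-1.5, which is an assumption about the checkpoint rather than something stated in this diff):

    image_size, patch_size = 336, 14
    num_patches = (image_size // patch_size) ** 2  # 24 * 24 = 576
    feature_size = num_patches + 1                 # 577, including the CLS token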
@@ -291,3 +292,22 @@ class CLIPVisionModel(nn.Module):
     @property
     def device(self):
         return next(self.parameters()).device
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters())
+        layer_count = len(self.vision_model.encoder.layers)
+
+        for name, loaded_weight in weights:
+            # post_layernorm is not needed in CLIPVisionModel
+            if "vision_model.post_layernorm" in name:
+                continue
+            # omit layers when num_hidden_layers_override is set
+            if "vision_model.encoder.layers." in name:
+                layer_idx = int(name.split(".")[3])
+                if layer_idx >= layer_count:
+                    continue
+
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
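Note: the layer filter assumes checkpoint keys of the form `vision_model.encoder.layers.<idx>....`, so `name.split(".")[3]` is the layer index. In isolation:

    name = "vision_model.encoder.layers.23.self_attn.q_proj.weight"
    layer_idx = int(name.split(".")[3])  # 23
    keep = layer_idx < layer_count       # skipped when the tower was truncated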
@@ -18,7 +18,6 @@ from vllm.config import CacheConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.models.intern_vit import InternVisionModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -29,7 +28,8 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                    get_clip_num_patches)
 from .interfaces import SupportsVision
-from .utils import merge_vision_embeddings
+from .utils import (filter_weights, init_vllm_registered_model,
+                    merge_vision_embeddings)

 IMG_START = '<img>'
 IMG_END = '</img>'
@@ -283,10 +283,8 @@ class InternVLChatModel(nn.Module, SupportsVision):
         self.vision_model = InternVisionModel(
             config.vision_config, num_hidden_layers_override=num_hidden_layers)

-        llm_class = ModelRegistry.load_model_cls(
-            config.text_config.architectures[0])
-        self.language_model = llm_class(config.text_config, cache_config,
-                                        quant_config)
+        self.language_model = init_vllm_registered_model(
+            config.text_config, cache_config, quant_config)

         vit_hidden_size = config.vision_config.hidden_size
         llm_hidden_size = config.text_config.hidden_size
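Note: `init_vllm_registered_model` comes from the new shared `.utils` helpers, which are not shown in this excerpt; presumably it resolves the decoder architecture recorded in the text config and builds it through `build_model`. A sketch under that assumption, not the actual implementation:

    def init_vllm_registered_model(hf_config, cache_config, quant_config):
        # Look up the registered vLLM class for the inner decoder
        # (e.g. LlamaForCausalLM) and construct it with the shared configs.
        model_cls, _ = ModelRegistry.resolve_model_cls(hf_config.architectures)
        return build_model(
            model_cls,
            hf_config,
            cache_config,
            quant_config,
            lora_config=None,
            multimodal_config=None,
            scheduler_config=None,
        )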
@@ -415,24 +413,16 @@ class InternVLChatModel(nn.Module, SupportsVision):
     ) -> Optional[SamplerOutput]:
         return self.language_model.sample(logits, sampling_metadata)

-    def _filter_weights(self, weights: Iterable[Tuple[str, torch.Tensor]],
-                        prefix: str):
-        for name, loaded_weight in weights:
-            name = name.split(".")
-            if prefix == name.pop(0):
-                name = ".".join(name)
-                yield name, loaded_weight
-
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
         vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)

         # load vision encoder
-        vit_weights = self._filter_weights(vit_weights, "vision_model")
+        vit_weights = filter_weights(vit_weights, "vision_model")
         self.vision_model.load_weights(vit_weights)

         # load mlp projector
-        mlp_weights = self._filter_weights(mlp_weights, "mlp1")
+        mlp_weights = filter_weights(mlp_weights, "mlp1")
         mlp_params_dict = dict(self.mlp1.named_parameters())
         for name, loaded_weight in mlp_weights:
             param = mlp_params_dict[name]
@@ -441,5 +431,5 @@ class InternVLChatModel(nn.Module, SupportsVision):
             weight_loader(param, loaded_weight)

         # load llm backbone
-        llm_weights = self._filter_weights(llm_weights, "language_model")
+        llm_weights = filter_weights(llm_weights, "language_model")
         self.language_model.load_weights(llm_weights)
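Note: `filter_weights` replaces the per-model `_filter_weights` method removed above; judging from that method, the shared helper strips a component prefix and yields only the matching entries, roughly:

    def filter_weights(weights, prefix):
        # Sketch based on the removed InternVL helper; the shared version in
        # .utils may differ in detail.
        for name, loaded_weight in weights:
            parts = name.split(".")
            if parts[0] == prefix:
                yield ".".join(parts[1:]), loaded_weight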
@@ -1,34 +1,30 @@
-from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
+import itertools
+from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union

 import torch
 import torch.nn as nn
-from transformers import CLIPVisionConfig, LlavaConfig
+from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig

 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.clip import CLIPVisionModel
-from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors, SamplerOutput

-from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
-                   get_max_clip_image_tokens, input_processor_for_clip)
+from .clip import (CLIPVisionModel, dummy_image_for_clip,
+                   dummy_seq_data_for_clip, get_max_clip_image_tokens,
+                   input_processor_for_clip)
 from .interfaces import SupportsVision
-from .utils import merge_vision_embeddings
+from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
+                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
+                     input_processor_for_siglip)
+from .utils import (filter_weights, init_vllm_registered_model,
+                    merge_vision_embeddings)

-_KEYS_TO_MODIFY_MAPPING = {
-    "language_model.lm_head": "lm_head",
-    "language_model.model": "language_model",
-}
-
 # TODO(xwjiang): Run benchmark and decide if TP.
@@ -67,25 +63,48 @@ def get_max_llava_image_tokens(ctx: InputContext):
     vision_config = hf_config.vision_config

     if isinstance(vision_config, CLIPVisionConfig):
-        return get_max_clip_image_tokens(vision_config)
+        num_image_tokens = get_max_clip_image_tokens(vision_config)
+    elif isinstance(vision_config, SiglipVisionConfig):
+        num_image_tokens = get_max_siglip_image_tokens(vision_config)
+    else:
+        msg = f"Unsupported vision config: {type(vision_config)}"
+        raise NotImplementedError(msg)

-    msg = f"Unsupported vision config: {type(vision_config)}"
-    raise NotImplementedError(msg)
+    strategy = hf_config.vision_feature_select_strategy
+    if strategy == "default":
+        return num_image_tokens - 1
+    elif strategy == "full":
+        return num_image_tokens
+    else:
+        raise ValueError(f"Unexpected select feature strategy: {strategy}")


 def dummy_data_for_llava(ctx: InputContext, seq_len: int):
     hf_config = ctx.get_hf_config(LlavaConfig)
     vision_config = hf_config.vision_config

+    image_feature_size = get_max_llava_image_tokens(ctx)
+
     if isinstance(vision_config, CLIPVisionConfig):
         seq_data = dummy_seq_data_for_clip(
             vision_config,
             seq_len,
             image_token_id=hf_config.image_token_index,
+            image_feature_size_override=image_feature_size,
         )

         mm_data = dummy_image_for_clip(vision_config)
         return seq_data, mm_data
+    elif isinstance(vision_config, SiglipVisionConfig):
+        seq_data = dummy_seq_data_for_siglip(
+            vision_config,
+            seq_len,
+            image_token_id=hf_config.image_token_index,
+            image_feature_size_override=image_feature_size,
+        )
+
+        mm_data = dummy_image_for_siglip(vision_config)
+        return seq_data, mm_data

     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
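Note: with `vision_feature_select_strategy == "default"` the CLS token is excluded, so the prompt reserves exactly one placeholder per patch; `"full"` keeps it. Continuing the 336/14 CLIP example above, that is 576 image tokens under "default" and 577 under "full".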
@@ -100,12 +119,49 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
     hf_config = ctx.get_hf_config(LlavaConfig)
     vision_config = hf_config.vision_config

+    image_feature_size = get_max_llava_image_tokens(ctx)
+
     if isinstance(vision_config, CLIPVisionConfig):
         return input_processor_for_clip(
             model_config,
             vision_config,
             llm_inputs,
             image_token_id=hf_config.image_token_index,
+            image_feature_size_override=image_feature_size,
+        )
+    elif isinstance(vision_config, SiglipVisionConfig):
+        return input_processor_for_siglip(
+            model_config,
+            vision_config,
+            llm_inputs,
+            image_token_id=hf_config.image_token_index,
+            image_feature_size_override=image_feature_size,
+        )
+
+    msg = f"Unsupported vision config: {type(vision_config)}"
+    raise NotImplementedError(msg)
+
+
+def _init_vision_tower(hf_config: LlavaConfig):
+    vision_config = hf_config.vision_config
+
+    # Initialize the vision tower only up to the required feature layer
+    vision_feature_layer = hf_config.vision_feature_layer
+    if vision_feature_layer < 0:
+        num_hidden_layers = hf_config.vision_config.num_hidden_layers \
+            + vision_feature_layer + 1
+    else:
+        num_hidden_layers = vision_feature_layer + 1
+
+    if isinstance(vision_config, CLIPVisionConfig):
+        return CLIPVisionModel(
+            vision_config,
+            num_hidden_layers_override=num_hidden_layers,
+        )
+    elif isinstance(vision_config, SiglipVisionConfig):
+        return SiglipVisionModel(
+            vision_config,
+            num_hidden_layers_override=num_hidden_layers,
         )

     msg = f"Unsupported vision config: {type(vision_config)}"
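Note: LLaVA configs typically set `vision_feature_layer = -2` (the second-to-last encoder layer), so `_init_vision_tower` instantiates only the layers the model actually reads. For a 24-layer tower (hypothetical numbers):

    num_layers_in_config = 24
    vision_feature_layer = -2
    num_hidden_layers = num_layers_in_config + vision_feature_layer + 1  # 23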
@@ -128,36 +184,15 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
         self.config = config
         self.multimodal_config = multimodal_config

-        # Initialize the vision tower only up to the required feature layer
-        vision_feature_layer = config.vision_feature_layer
-        if vision_feature_layer < 0:
-            num_hidden_layers = config.vision_config.num_hidden_layers \
-                + vision_feature_layer + 1
-        else:
-            num_hidden_layers = vision_feature_layer + 1
-
         # TODO: Optionally initializes this for supporting embeddings.
-        self.vision_tower = CLIPVisionModel(
-            config.vision_config, num_hidden_layers_override=num_hidden_layers)
+        self.vision_tower = _init_vision_tower(config)
         self.multi_modal_projector = LlavaMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
             text_hidden_size=config.text_config.hidden_size,
             projector_hidden_act=config.projector_hidden_act)

-        self.quant_config = quant_config
-        self.language_model = LlamaModel(config.text_config, cache_config,
-                                         quant_config)
-        self.unpadded_vocab_size = config.text_config.vocab_size
-        self.lm_head = ParallelLMHead(
-            self.unpadded_vocab_size,
-            config.text_config.hidden_size,
-            org_num_embeddings=self.language_model.org_vocab_size,
-            quant_config=quant_config)
-        logit_scale = getattr(config, "logit_scale", 1.0)
-        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.text_config.vocab_size,
-                                                logit_scale)
-        self.sampler = Sampler()
+        self.language_model = init_vllm_registered_model(
+            config.text_config, cache_config, quant_config)

     def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
         h = w = self.config.vision_config.image_size
@@ -198,8 +233,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):

         raise ValueError(f"Unexpected select feature strategy: {strategy}")

-    def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
-                                  pixel_values: torch.Tensor) -> torch.Tensor:
+    def _image_pixels_to_features(
+        self,
+        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
+        pixel_values: torch.Tensor,
+    ) -> torch.Tensor:

         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
@@ -272,7 +310,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):

         if image_input is not None:
             vision_embeddings = self._process_image_input(image_input)
-            inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+            inputs_embeds = self.language_model.model.get_input_embeddings(
+                input_ids)

             inputs_embeds = merge_vision_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
@@ -282,68 +321,44 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
         else:
             inputs_embeds = None

-        hidden_states = self.language_model(input_ids,
-                                            positions,
-                                            kv_caches,
-                                            attn_metadata,
-                                            None,
-                                            inputs_embeds=inputs_embeds)
+        hidden_states = self.language_model.model(input_ids,
+                                                  positions,
+                                                  kv_caches,
+                                                  attn_metadata,
+                                                  None,
+                                                  inputs_embeds=inputs_embeds)

         return hidden_states

     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
+        return self.language_model.compute_logits(hidden_states,
+                                                  sampling_metadata)

     def sample(
         self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
+        return self.language_model.sample(logits, sampling_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # only doing this for language model part for now.
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            # post_layernorm is not needed in CLIPVisionModel
-            if "vision_model.post_layernorm" in name:
-                continue
-            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
-                if key_to_modify in name:
-                    name = name.replace(key_to_modify, new_key)
-            use_default_weight_loading = False
-            if "vision" in name:
-                if self.vision_tower is not None:
-                    # We only do sharding for language model and
-                    # not vision model for now.
-                    use_default_weight_loading = True
-            else:
-                for (param_name, weight_name,
-                     shard_id) in stacked_params_mapping:
-                    if weight_name not in name:
-                        continue
-                    param = params_dict[name.replace(weight_name, param_name)]
-                    weight_loader = param.weight_loader
-                    weight_loader(param, loaded_weight, shard_id)
-                    break
-                else:
-                    use_default_weight_loading = True
-            if use_default_weight_loading and name in params_dict:
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+        # prepare weight iterators for components
+        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
+
+        # load vision encoder
+        vit_weights = filter_weights(vit_weights, "vision_tower")
+        self.vision_tower.load_weights(vit_weights)
+
+        # load mlp projector
+        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
+        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
+        for name, loaded_weight in mlp_weights:
+            param = mlp_params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        # load llm backbone
+        llm_weights = filter_weights(llm_weights, "language_model")
+        self.language_model.load_weights(llm_weights)
@@ -1,9 +1,10 @@
+import itertools
 from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union

 import torch
 import torch.nn as nn
 from PIL import Image
-from transformers import CLIPVisionConfig, LlavaNextConfig
+from transformers import CLIPVisionConfig, LlavaNextConfig, SiglipVisionConfig
 from transformers.models.llava_next.modeling_llava_next import (
     get_anyres_image_grid_shape, unpad_image)
 from typing_extensions import NotRequired
@@ -12,23 +13,23 @@ from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.logger import init_logger
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.clip import CLIPVisionModel
-from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors, SamplerOutput

-from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
+from .clip import (CLIPVisionModel, dummy_image_for_clip,
+                   dummy_seq_data_for_clip, get_clip_image_feature_size,
                    get_clip_patch_grid_length, input_processor_for_clip)
 from .interfaces import SupportsVision
 from .llava import LlavaMultiModalProjector
-from .utils import merge_vision_embeddings
+from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
+                     dummy_seq_data_for_siglip, get_siglip_image_feature_size,
+                     get_siglip_patch_grid_length, input_processor_for_siglip)
+from .utils import (filter_weights, init_vllm_registered_model,
+                    merge_vision_embeddings)

 logger = init_logger(__name__)

@@ -104,30 +105,42 @@ def get_llava_next_image_feature_size(
             image_size=vision_config.image_size,
             patch_size=vision_config.patch_size,
         )
-        base_feature_size = num_patches * num_patches
+        base_feature_size = get_clip_image_feature_size(vision_config)
+    elif isinstance(vision_config, SiglipVisionConfig):
+        num_patches = get_siglip_patch_grid_length(
+            image_size=vision_config.image_size,
+            patch_size=vision_config.patch_size,
+        )
+        base_feature_size = get_siglip_image_feature_size(vision_config)
+    else:
+        msg = f"Unsupported vision config: {type(vision_config)}"
+        raise NotImplementedError(msg)

-        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
-            image_size=(input_height, input_width),
-            grid_pinpoints=hf_config.image_grid_pinpoints,
-            patch_size=vision_config.image_size,
-        )
+    strategy = hf_config.vision_feature_select_strategy
+    if strategy == "default":
+        base_feature_size -= 1
+    elif strategy == "full":
+        pass
+    else:
+        raise ValueError(f"Unexpected select feature strategy: {strategy}")

-        (
-            unpadded_feature_size,
-            newline_feature_size,
-        ) = _get_llava_next_num_unpadded_features(input_height, input_width,
-                                                  num_patches,
-                                                  num_patch_height,
-                                                  num_patch_width)
+    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+        image_size=(input_height, input_width),
+        grid_pinpoints=hf_config.image_grid_pinpoints,
+        patch_size=vision_config.image_size,
+    )

-        return unpadded_feature_size + newline_feature_size + base_feature_size
+    (
+        unpadded_feature_size,
+        newline_feature_size,
+    ) = _get_llava_next_num_unpadded_features(input_height, input_width,
+                                              num_patches, num_patch_height,
+                                              num_patch_width)

-    msg = f"Unsupported vision config: {type(vision_config)}"
-    raise NotImplementedError(msg)
+    return unpadded_feature_size + newline_feature_size + base_feature_size


 def get_max_llava_next_image_tokens(ctx: InputContext):

     return get_llava_next_image_feature_size(
         ctx.get_hf_config(LlavaNextConfig),
         input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
@@ -155,6 +168,21 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
             image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
         )

+        return seq_data, mm_data
+    elif isinstance(vision_config, SiglipVisionConfig):
+        seq_data = dummy_seq_data_for_siglip(
+            vision_config,
+            seq_len,
+            image_token_id=hf_config.image_token_index,
+            image_feature_size_override=image_feature_size,
+        )
+
+        mm_data = dummy_image_for_siglip(
+            vision_config,
+            image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
+            image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
+        )
+
         return seq_data, mm_data

     msg = f"Unsupported vision config: {type(vision_config)}"
@@ -194,6 +222,40 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
             image_token_id=hf_config.image_token_index,
             image_feature_size_override=image_feature_size,
         )
+    elif isinstance(vision_config, SiglipVisionConfig):
+        return input_processor_for_siglip(
+            model_config,
+            vision_config,
+            llm_inputs,
+            image_token_id=hf_config.image_token_index,
+            image_feature_size_override=image_feature_size,
+        )
+
+    msg = f"Unsupported vision config: {type(vision_config)}"
+    raise NotImplementedError(msg)
+
+
+def _init_vision_tower(hf_config: LlavaNextConfig):
+    vision_config = hf_config.vision_config
+
+    # Initialize the vision tower only up to the required feature layer
+    vision_feature_layer = hf_config.vision_feature_layer
+    if vision_feature_layer < 0:
+        num_hidden_layers = hf_config.vision_config.num_hidden_layers \
+            + vision_feature_layer + 1
+    else:
+        num_hidden_layers = vision_feature_layer + 1
+
+    if isinstance(vision_config, CLIPVisionConfig):
+        return CLIPVisionModel(
+            vision_config,
+            num_hidden_layers_override=num_hidden_layers,
+        )
+    elif isinstance(vision_config, SiglipVisionConfig):
+        return SiglipVisionModel(
+            vision_config,
+            num_hidden_layers_override=num_hidden_layers,
+        )

     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
@@ -215,36 +277,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         self.config = config
         self.multimodal_config = multimodal_config

-        # Initialize the vision tower only up to the required feature layer
-        vision_feature_layer = config.vision_feature_layer
-        if vision_feature_layer < 0:
-            num_hidden_layers = config.vision_config.num_hidden_layers \
-                + vision_feature_layer + 1
-        else:
-            num_hidden_layers = vision_feature_layer + 1
-
         # TODO: Optionally initializes this for supporting embeddings.
-        self.vision_tower = CLIPVisionModel(
-            config.vision_config, num_hidden_layers_override=num_hidden_layers)
+        self.vision_tower = _init_vision_tower(config)
         self.multi_modal_projector = LlavaMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
             text_hidden_size=config.text_config.hidden_size,
             projector_hidden_act=config.projector_hidden_act)

-        self.quant_config = quant_config
-        self.language_model = LlamaModel(config.text_config, cache_config,
-                                         quant_config)
-        self.unpadded_vocab_size = config.text_config.vocab_size
-        self.lm_head = ParallelLMHead(
-            self.unpadded_vocab_size,
-            config.text_config.hidden_size,
-            org_num_embeddings=self.language_model.org_vocab_size,
-            quant_config=quant_config)
-        logit_scale = getattr(config, "logit_scale", 1.0)
-        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.text_config.vocab_size,
-                                                logit_scale)
-        self.sampler = Sampler()
+        self.language_model = init_vllm_registered_model(
+            config.text_config, cache_config, quant_config)

         self.image_newline = nn.Parameter(
             torch.empty(config.text_config.hidden_size))
@@ -310,8 +351,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):

         raise ValueError(f"Unexpected select feature strategy: {strategy}")

-    def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
-                                  pixel_values: torch.Tensor) -> torch.Tensor:
+    def _image_pixels_to_features(
+        self,
+        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
+        pixel_values: torch.Tensor,
+    ) -> torch.Tensor:

         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
@@ -496,7 +540,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):

         if image_input is not None:
             vision_embeddings = self._process_image_input(image_input)
-            inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+            inputs_embeds = self.language_model.model.get_input_embeddings(
+                input_ids)

             inputs_embeds = merge_vision_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
@@ -506,68 +551,54 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         else:
             inputs_embeds = None

-        hidden_states = self.language_model(input_ids,
-                                            positions,
-                                            kv_caches,
-                                            attn_metadata,
-                                            None,
-                                            inputs_embeds=inputs_embeds)
+        hidden_states = self.language_model.model(input_ids,
+                                                  positions,
+                                                  kv_caches,
+                                                  attn_metadata,
+                                                  None,
+                                                  inputs_embeds=inputs_embeds)

         return hidden_states

     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
+        return self.language_model.compute_logits(hidden_states,
+                                                  sampling_metadata)

     def sample(
         self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
+        return self.language_model.sample(logits, sampling_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # only doing this for language model part for now.
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            # post_layernorm is not needed in CLIPVisionModel
-            if "vision_model.post_layernorm" in name:
-                continue
-            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
-                if key_to_modify in name:
-                    name = name.replace(key_to_modify, new_key)
-            use_default_weight_loading = False
-            if "vision" in name:
-                if self.vision_tower is not None:
-                    # We only do sharding for language model and
-                    # not vision model for now.
-                    use_default_weight_loading = True
-            else:
-                for (param_name, weight_name,
-                     shard_id) in stacked_params_mapping:
-                    if weight_name not in name:
-                        continue
-                    param = params_dict[name.replace(weight_name, param_name)]
-                    weight_loader = param.weight_loader
-                    weight_loader(param, loaded_weight, shard_id)
-                    break
-                else:
-                    use_default_weight_loading = True
-            if use_default_weight_loading and name in params_dict:
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+        # prepare weight iterators for components
+        vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
+            weights, 4)
+
+        # load vision encoder
+        vit_weights = filter_weights(vit_weights, "vision_tower")
+        self.vision_tower.load_weights(vit_weights)
+
+        # load mlp projector
+        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
+        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
+        for name, loaded_weight in mlp_weights:
+            param = mlp_params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        # load newline
+        newline_weights = filter_weights(newline_weights, "image_newline")
+        for name, loaded_weight in newline_weights:
+            assert name == ""
+            param = self.image_newline
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        # load llm backbone
+        llm_weights = filter_weights(llm_weights, "language_model")
+        self.language_model.load_weights(llm_weights)
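Note: `image_newline` is a bare `nn.Parameter` rather than a submodule, so once `filter_weights` strips the `image_newline` prefix the remaining name is empty; that is what the `assert name == ""` checks before loading the tensor directly onto the parameter.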
@@ -2,12 +2,12 @@
 within a vision language model."""

 import math
-from typing import Optional, Tuple
+from typing import Iterable, Optional, Tuple

 import torch
 from PIL import Image
 from torch import nn
-from transformers import SiglipConfig, SiglipVisionConfig
+from transformers import SiglipVisionConfig
 from transformers.models.siglip.modeling_siglip import SiglipAttention
 from vllm_flash_attn import flash_attn_func
 from xformers.ops import memory_efficient_attention
@@ -22,13 +22,15 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal.image import (cached_get_tokenizer,
                                    repeat_and_pad_image_tokens)
 from vllm.sequence import SequenceData
 
 
 def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
-    assert image_size % patch_size == 0
+    # Since interpolation is applied, the image size need not be divisible
+    # assert image_size % patch_size == 0
     return image_size // patch_size
 
 
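A quick numeric sketch of why the divisibility assert can be relaxed once positional-embedding interpolation is applied (the sizes below are illustrative, not taken from this commit):

image_size, patch_size = 384, 14  # hypothetical SigLIP-style sizes
grid_length = image_size // patch_size
print(grid_length)  # 27: 384 / 14 is about 27.4, floored; interpolation absorbs the remainder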
@@ -454,7 +456,7 @@ class SiglipEncoderLayer(nn.Module):
 
     def __init__(
        self,
-        config: SiglipConfig,
+        config: SiglipVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -474,7 +476,7 @@ class SiglipEncoderLayer(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-    ) -> Tuple[torch.Tensor]:
+    ) -> Tuple[torch.Tensor, None]:
         residual = hidden_states
 
         hidden_states = self.layer_norm1(hidden_states)
@@ -493,22 +495,27 @@ class SiglipEncoder(nn.Module):
 
     def __init__(
         self,
-        config: SiglipConfig,
+        config: SiglipVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        num_hidden_layers_override: Optional[int] = None,
     ):
         super().__init__()
         self.config = config
+
+        if num_hidden_layers_override is None:
+            num_hidden_layers = config.num_hidden_layers
+        else:
+            num_hidden_layers = num_hidden_layers_override
+
         self.layers = nn.ModuleList([
-            SiglipEncoderLayer(
-                config,
-                quant_config=quant_config,
-            ) for _ in range(config.num_hidden_layers)
+            SiglipEncoderLayer(config, quant_config=quant_config)
+            for _ in range(num_hidden_layers)
         ])
 
     def forward(
         self,
         inputs_embeds: torch.Tensor,
-    ) -> Tuple:
+    ) -> torch.Tensor:
         hidden_states = inputs_embeds
         for encoder_layer in self.layers:
             hidden_states, _ = encoder_layer(hidden_states)
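The num_hidden_layers_override hook exists because LLaVA-style models usually take vision features from an intermediate encoder layer, so trailing layers never contribute to the output and need not be built. A hedged sketch of that layer-count arithmetic (the helper name and values are illustrative, not from this commit):

def resolve_layer_count(num_hidden_layers: int, vision_feature_layer: int) -> int:
    # A negative feature index counts from the end of the encoder; anything
    # after the selected layer can be dropped at construction time.
    if vision_feature_layer < 0:
        return num_hidden_layers + vision_feature_layer + 1
    return vision_feature_layer + 1


print(resolve_layer_count(27, -2))  # 26 -> build one fewer SiglipEncoderLayer
print(resolve_layer_count(27, -1))  # 27 -> keep the full encoder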
@@ -553,6 +560,7 @@ class SiglipVisionTransformer(nn.Module):
         self,
         config: SiglipVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        num_hidden_layers_override: Optional[int] = None,
     ):
         super().__init__()
         self.config = config
@@ -562,6 +570,7 @@ class SiglipVisionTransformer(nn.Module):
         self.encoder = SiglipEncoder(
             config,
             quant_config=quant_config,
+            num_hidden_layers_override=num_hidden_layers_override,
         )
         self.post_layernorm = nn.LayerNorm(embed_dim,
                                            eps=config.layer_norm_eps)
@@ -600,11 +609,13 @@ class SiglipVisionModel(nn.Module):
         self,
         config: SiglipVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        num_hidden_layers_override: Optional[int] = None,
     ):
         super().__init__()
         self.vision_model = SiglipVisionTransformer(
             config,
             quant_config,
+            num_hidden_layers_override=num_hidden_layers_override,
         )
 
     def get_input_embeddings(self) -> nn.Module:
@@ -619,3 +630,19 @@ class SiglipVisionModel(nn.Module):
             pixel_values=pixel_values,
             interpolate_pos_encoding=interpolate_pos_encoding,
         )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters())
+        layer_count = len(self.vision_model.encoder.layers)
+
+        for name, loaded_weight in weights:
+            # omit layers when num_hidden_layers_override is set
+            if "vision_model.encoder.layers." in name:
+                layer_idx = int(name.split(".")[3])
+                if layer_idx >= layer_count:
+                    continue
+
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
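To illustrate the skipping logic in the new SiglipVisionModel.load_weights, a standalone sketch that parses the layer index out of a weight name (names follow the usual vision_model.encoder.layers.<idx>. pattern; the layer count is made up):

layer_count = 26  # e.g. after dropping one trailing layer via the override


def should_load(name: str) -> bool:
    # Checkpoint weights for layers beyond layer_count are silently skipped.
    if "vision_model.encoder.layers." in name:
        layer_idx = int(name.split(".")[3])
        return layer_idx < layer_count
    return True


print(should_load("vision_model.encoder.layers.25.mlp.fc1.weight"))  # True
print(should_load("vision_model.encoder.layers.26.mlp.fc1.weight"))  # False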
@@ -1,22 +1,70 @@
-from typing import Dict, List, Protocol, Tuple
+from typing import Dict, Iterable, List, Optional, Protocol, Tuple
 
 import torch
+import torch.nn as nn
 from torch.func import functional_call
+from transformers import PretrainedConfig
 
+from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
+                         SchedulerConfig)
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.model_loader.loader import build_model
+from vllm.model_executor.models import ModelRegistry
 from vllm.multimodal import BatchedTensors
 from vllm.utils import is_pin_memory_available
 
 
+def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
+    """
+    Helper function to load weights for inner vLLM models.
+
+    See also:
+        :ref:`init_vllm_registered_model`
+    """
+    for name, loaded_weight in weights:
+        name = name.split(".")
+        if prefix == name.pop(0):
+            name = ".".join(name)
+            yield name, loaded_weight
+
+
+def init_vllm_registered_model(
+    hf_config: PretrainedConfig,
+    cache_config: Optional[CacheConfig],
+    quant_config: Optional[QuantizationConfig],
+    *,
+    lora_config: Optional[LoRAConfig] = None,
+    multimodal_config: Optional[MultiModalConfig] = None,
+    scheduler_config: Optional[SchedulerConfig] = None,
+) -> nn.Module:
+    """
+    Helper function to initialize an inner model registered to vLLM,
+    based on the arguments passed to the outer vLLM model.
+    """
+    model_class, _ = ModelRegistry.resolve_model_cls(hf_config.architectures)
+
+    return build_model(
+        model_class,
+        hf_config,
+        cache_config,
+        quant_config,
+        lora_config=lora_config,
+        multimodal_config=multimodal_config,
+        scheduler_config=scheduler_config,
+    )
+
+
 def merge_vision_embeddings(input_ids: torch.Tensor,
                             inputs_embeds: torch.Tensor,
                             vision_embeddings: BatchedTensors,
                             image_token_id: int) -> torch.Tensor:
     """
-    Merge `vision_embeddings` into `inputs_embeds` by overwriting the positions
-    in `inputs_embeds` corresponding to placeholder image tokens in `input_ids`.
+    Merge ``vision_embeddings`` into ``inputs_embeds`` by overwriting the
+    positions in ``inputs_embeds`` corresponding to placeholder image tokens in
+    ``input_ids``.
 
     Note:
-        This updates `inputs_embeds` in place.
+        This updates ``inputs_embeds`` in place.
     """
     mask = (input_ids == image_token_id)
     num_expected_tokens = mask.sum()
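A standalone sketch of the placeholder-overwrite behaviour documented in merge_vision_embeddings, reduced to a single flattened sequence with a hypothetical image-token id (not vLLM code):

import torch

IMAGE_TOKEN_ID = 32000  # hypothetical placeholder id

input_ids = torch.tensor([1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 42, 7])
inputs_embeds = torch.zeros(5, 4)
vision_embeddings = torch.ones(2, 4)  # one embedding per placeholder token

mask = input_ids == IMAGE_TOKEN_ID
assert int(mask.sum()) == vision_embeddings.shape[0]
inputs_embeds[mask] = vision_embeddings  # in-place overwrite, as the docstring notes
print(inputs_embeds)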