mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-11 15:55:43 +08:00
[Model] Refactor BLIP/BLIP-2 to support composite model loading (#8407)
This commit is contained in:
parent
0e40ac9b7b
commit
06ed2815e2
@ -1,6 +1,6 @@
|
|||||||
"""Minimal implementation of BlipVisionModel intended to be only used
|
"""Minimal implementation of BlipVisionModel intended to be only used
|
||||||
within a vision language model."""
|
within a vision language model."""
|
||||||
from typing import Optional, Union
|
from typing import Iterable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -16,6 +16,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.multimodal.utils import (cached_get_tokenizer,
|
from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||||
repeat_and_pad_placeholder_tokens)
|
repeat_and_pad_placeholder_tokens)
|
||||||
from vllm.sequence import SequenceData
|
from vllm.sequence import SequenceData
|
||||||
@ -342,6 +343,10 @@ class BlipVisionModel(nn.Module):
|
|||||||
num_hidden_layers_override: Optional[int] = None):
|
num_hidden_layers_override: Optional[int] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
num_heads = config.num_attention_heads
|
||||||
|
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
self.embeddings = BlipVisionEmbeddings(config)
|
self.embeddings = BlipVisionEmbeddings(config)
|
||||||
@ -350,11 +355,61 @@ class BlipVisionModel(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
num_hidden_layers_override=num_hidden_layers_override,
|
num_hidden_layers_override=num_hidden_layers_override,
|
||||||
)
|
)
|
||||||
self.post_layernorm = nn.LayerNorm(config.hidden_size,
|
|
||||||
eps=config.layer_norm_eps)
|
if len(self.encoder.layers) > config.num_hidden_layers:
|
||||||
|
raise ValueError(
|
||||||
|
f"The original encoder only has {config.num_hidden_layers} "
|
||||||
|
f"layers, but you requested {len(self.encoder.layers)} layers."
|
||||||
|
)
|
||||||
|
elif len(self.encoder.layers) == config.num_hidden_layers:
|
||||||
|
self.post_layernorm = nn.LayerNorm(config.hidden_size,
|
||||||
|
eps=config.layer_norm_eps)
|
||||||
|
else:
|
||||||
|
# post_layernorm is unused when we extract intermediate features
|
||||||
|
# In this case, we can skip it to conserve memory
|
||||||
|
self.post_layernorm = None
|
||||||
|
|
||||||
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.embeddings(pixel_values)
|
hidden_states = self.embeddings(pixel_values)
|
||||||
hidden_states = self.encoder(inputs_embeds=hidden_states)
|
hidden_states = self.encoder(inputs_embeds=hidden_states)
|
||||||
|
|
||||||
|
if self.post_layernorm is None:
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
return self.post_layernorm(hidden_states)
|
return self.post_layernorm(hidden_states)
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
("qkv_proj", "q_proj", "q"),
|
||||||
|
("qkv_proj", "k_proj", "k"),
|
||||||
|
("qkv_proj", "v_proj", "v"),
|
||||||
|
] if self.shard_weight else []
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
layer_count = len(self.encoder.layers)
|
||||||
|
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
# post_layernorm is not needed in BlipVisionModel
|
||||||
|
if (name.startswith("post_layernorm")
|
||||||
|
and self.post_layernorm is None):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# omit layers when num_hidden_layers_override is set
|
||||||
|
if name.startswith("encoder.layers"):
|
||||||
|
layer_idx = int(name.split(".")[2])
|
||||||
|
if layer_idx >= layer_count:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
param = params_dict[name.replace(weight_name, param_name)]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
|||||||
@ -10,11 +10,9 @@ from vllm.attention import AttentionMetadata
|
|||||||
from vllm.config import CacheConfig, MultiModalConfig
|
from vllm.config import CacheConfig, MultiModalConfig
|
||||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.opt import OPTModel
|
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.sequence import IntermediateTensors, SequenceData
|
from vllm.sequence import IntermediateTensors, SequenceData
|
||||||
@ -22,12 +20,8 @@ from vllm.sequence import IntermediateTensors, SequenceData
|
|||||||
from .blip import (BlipVisionModel, dummy_image_for_blip,
|
from .blip import (BlipVisionModel, dummy_image_for_blip,
|
||||||
get_max_blip_image_tokens)
|
get_max_blip_image_tokens)
|
||||||
from .interfaces import SupportsMultiModal
|
from .interfaces import SupportsMultiModal
|
||||||
from .utils import merge_multimodal_embeddings
|
from .utils import (group_weights_with_prefix, init_vllm_registered_model,
|
||||||
|
merge_multimodal_embeddings)
|
||||||
_KEYS_TO_MODIFY_MAPPING = {
|
|
||||||
"language_model.lm_head": "lm_head",
|
|
||||||
"language_model.model": "language_model",
|
|
||||||
}
|
|
||||||
|
|
||||||
# We use this internally as placeholders since there is no image token
|
# We use this internally as placeholders since there is no image token
|
||||||
# defined on the HuggingFace repo
|
# defined on the HuggingFace repo
|
||||||
@ -491,9 +485,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
|
|||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# currently all existing BLIP-2 models have `tie_word_embeddings`
|
|
||||||
# enabled
|
|
||||||
assert config.tie_word_embeddings
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.multimodal_config = multimodal_config
|
self.multimodal_config = multimodal_config
|
||||||
|
|
||||||
@ -514,17 +505,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
|
|||||||
bias=True,
|
bias=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.quant_config = quant_config
|
self.language_model = init_vllm_registered_model(
|
||||||
|
config.text_config, cache_config, quant_config)
|
||||||
self.language_model = OPTModel(config.text_config, cache_config,
|
|
||||||
quant_config)
|
|
||||||
|
|
||||||
self.unpadded_vocab_size = config.text_config.vocab_size
|
|
||||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size)
|
|
||||||
self.sampler = Sampler()
|
|
||||||
|
|
||||||
def get_lm_head(self):
|
|
||||||
return self.language_model.decoder.embed_tokens
|
|
||||||
|
|
||||||
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
|
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
|
||||||
h = w = self.config.vision_config.image_size
|
h = w = self.config.vision_config.image_size
|
||||||
@ -653,7 +635,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
|
|||||||
|
|
||||||
if image_input is not None:
|
if image_input is not None:
|
||||||
vision_embeddings = self._process_image_input(image_input)
|
vision_embeddings = self._process_image_input(image_input)
|
||||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
inputs_embeds = self.language_model.model.get_input_embeddings(
|
||||||
|
input_ids)
|
||||||
|
|
||||||
inputs_embeds = merge_multimodal_embeddings(
|
inputs_embeds = merge_multimodal_embeddings(
|
||||||
input_ids, inputs_embeds, vision_embeddings,
|
input_ids, inputs_embeds, vision_embeddings,
|
||||||
@ -663,11 +646,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
|
|||||||
else:
|
else:
|
||||||
inputs_embeds = None
|
inputs_embeds = None
|
||||||
|
|
||||||
hidden_states = self.language_model(input_ids,
|
hidden_states = self.language_model.model(input_ids,
|
||||||
positions,
|
positions,
|
||||||
kv_caches,
|
kv_caches,
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
inputs_embeds=inputs_embeds)
|
inputs_embeds=inputs_embeds)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@ -676,56 +659,46 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[torch.Tensor]:
|
) -> Optional[torch.Tensor]:
|
||||||
logits = self.logits_processor(self.get_lm_head(), hidden_states,
|
return self.language_model.compute_logits(hidden_states,
|
||||||
sampling_metadata)
|
sampling_metadata)
|
||||||
return logits
|
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
next_tokens = self.sampler(logits, sampling_metadata)
|
return self.language_model.sample(logits, sampling_metadata)
|
||||||
return next_tokens
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
# only doing this for language model part for now.
|
# prepare weight iterators for components
|
||||||
stacked_params_mapping = [
|
weights_group = group_weights_with_prefix(weights)
|
||||||
# (param_name, shard_name, shard_id)
|
|
||||||
("qkv_proj", "q_proj", "q"),
|
|
||||||
("qkv_proj", "k_proj", "k"),
|
|
||||||
("qkv_proj", "v_proj", "v"),
|
|
||||||
("gate_up_proj", "gate_proj", 0),
|
|
||||||
("gate_up_proj", "up_proj", 1),
|
|
||||||
]
|
|
||||||
params_dict = dict(self.named_parameters())
|
|
||||||
|
|
||||||
for name, loaded_weight in weights:
|
# load vision encoder
|
||||||
if "lm_head.weight" in name:
|
self.vision_model.load_weights(weights_group["vision_model"])
|
||||||
continue
|
|
||||||
if "rotary_emb.inv_freq" in name:
|
# load query tokens
|
||||||
continue
|
for name, loaded_weight in weights_group["query_tokens"]:
|
||||||
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
|
assert name == ""
|
||||||
if key_to_modify in name:
|
param = self.query_tokens
|
||||||
name = name.replace(key_to_modify, new_key)
|
weight_loader = getattr(param, "weight_loader",
|
||||||
use_default_weight_loading = False
|
default_weight_loader)
|
||||||
if "vision" in name:
|
weight_loader(param, loaded_weight)
|
||||||
if self.vision_model is not None:
|
|
||||||
# BlipVisionModel does not need sharding
|
# load qformer
|
||||||
use_default_weight_loading = True
|
qformer_params_dict = dict(self.qformer.named_parameters())
|
||||||
else:
|
for name, loaded_weight in weights_group["qformer"]:
|
||||||
for (param_name, weight_name,
|
param = qformer_params_dict[name]
|
||||||
shard_id) in stacked_params_mapping:
|
weight_loader = getattr(param, "weight_loader",
|
||||||
if weight_name not in name:
|
default_weight_loader)
|
||||||
continue
|
weight_loader(param, loaded_weight)
|
||||||
param = params_dict[name.replace(weight_name, param_name)]
|
|
||||||
weight_loader = param.weight_loader
|
# load mlp projector
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
mlp_params_dict = dict(self.language_projection.named_parameters())
|
||||||
break
|
for name, loaded_weight in weights_group["language_projection"]:
|
||||||
else:
|
param = mlp_params_dict[name]
|
||||||
use_default_weight_loading = True
|
weight_loader = getattr(param, "weight_loader",
|
||||||
if use_default_weight_loading:
|
default_weight_loader)
|
||||||
param = params_dict[name]
|
weight_loader(param, loaded_weight)
|
||||||
weight_loader = getattr(param, "weight_loader",
|
|
||||||
default_weight_loader)
|
# load llm backbone
|
||||||
weight_loader(param, loaded_weight)
|
self.language_model.load_weights(weights_group["language_model"])
|
||||||
|
|||||||
@ -12,7 +12,6 @@ from vllm.attention import Attention, AttentionMetadata
|
|||||||
from vllm.config import CacheConfig, MultiModalConfig
|
from vllm.config import CacheConfig, MultiModalConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||||
@ -36,8 +35,6 @@ from vllm.utils import print_warning_once
|
|||||||
|
|
||||||
from .interfaces import SupportsMultiModal
|
from .interfaces import SupportsMultiModal
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
# These configs are not part of the model config but the preprocessor
|
# These configs are not part of the model config but the preprocessor
|
||||||
# and processor files, so we hardcode them in the model file for now.
|
# and processor files, so we hardcode them in the model file for now.
|
||||||
CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512
|
CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512
|
||||||
|
|||||||
@ -391,6 +391,7 @@ class CLIPVisionModel(nn.Module):
|
|||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
num_hidden_layers_override: Optional[int] = None):
|
num_hidden_layers_override: Optional[int] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
num_heads = config.num_attention_heads
|
num_heads = config.num_attention_heads
|
||||||
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
|
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
|
||||||
@ -400,10 +401,6 @@ class CLIPVisionModel(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
num_hidden_layers_override=num_hidden_layers_override)
|
num_hidden_layers_override=num_hidden_layers_override)
|
||||||
|
|
||||||
@property
|
|
||||||
def _require_post_layernorm(self) -> bool:
|
|
||||||
return self.vision_model.post_layernorm is not None
|
|
||||||
|
|
||||||
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||||
return self.vision_model(pixel_values)
|
return self.vision_model(pixel_values)
|
||||||
|
|
||||||
@ -425,12 +422,12 @@ class CLIPVisionModel(nn.Module):
|
|||||||
|
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
# post_layernorm is not needed in CLIPVisionModel
|
# post_layernorm is not needed in CLIPVisionModel
|
||||||
if ("vision_model.post_layernorm" in name
|
if (name.startswith("vision_model.post_layernorm")
|
||||||
and not self._require_post_layernorm):
|
and self.vision_model.post_layernorm is None):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# omit layers when num_hidden_layers_override is set
|
# omit layers when num_hidden_layers_override is set
|
||||||
if "vision_model.encoder.layers." in name:
|
if name.startswith("vision_model.encoder.layers"):
|
||||||
layer_idx = int(name.split(".")[3])
|
layer_idx = int(name.split(".")[3])
|
||||||
if layer_idx >= layer_count:
|
if layer_idx >= layer_count:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@ -28,7 +28,6 @@ from transformers import FuyuConfig, FuyuImageProcessor
|
|||||||
from vllm.attention import AttentionMetadata
|
from vllm.attention import AttentionMetadata
|
||||||
from vllm.config import CacheConfig, MultiModalConfig
|
from vllm.config import CacheConfig, MultiModalConfig
|
||||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||||
@ -45,8 +44,6 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
|||||||
from .interfaces import SupportsMultiModal
|
from .interfaces import SupportsMultiModal
|
||||||
from .utils import merge_multimodal_embeddings
|
from .utils import merge_multimodal_embeddings
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
# Cannot find the following 2 numbers from hf config.
|
# Cannot find the following 2 numbers from hf config.
|
||||||
_IMAGE_TOKEN_ID = 71011
|
_IMAGE_TOKEN_ID = 71011
|
||||||
_NEWLINE_TOKEN_ID = 71019
|
_NEWLINE_TOKEN_ID = 71019
|
||||||
|
|||||||
@ -12,7 +12,6 @@ from typing_extensions import NotRequired
|
|||||||
from vllm.attention import AttentionMetadata
|
from vllm.attention import AttentionMetadata
|
||||||
from vllm.config import CacheConfig, MultiModalConfig
|
from vllm.config import CacheConfig, MultiModalConfig
|
||||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
@ -32,13 +31,6 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
|
|||||||
from .utils import (flatten_bn, group_weights_with_prefix,
|
from .utils import (flatten_bn, group_weights_with_prefix,
|
||||||
init_vllm_registered_model, merge_multimodal_embeddings)
|
init_vllm_registered_model, merge_multimodal_embeddings)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
_KEYS_TO_MODIFY_MAPPING = {
|
|
||||||
"language_model.lm_head": "lm_head",
|
|
||||||
"language_model.model": "language_model",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
|
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
|
||||||
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
|
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
|
||||||
|
|
||||||
|
|||||||
@ -11,7 +11,6 @@ from transformers import (CLIPVisionConfig, LlavaNextVideoConfig,
|
|||||||
from vllm.attention import AttentionMetadata
|
from vllm.attention import AttentionMetadata
|
||||||
from vllm.config import CacheConfig, MultiModalConfig
|
from vllm.config import CacheConfig, MultiModalConfig
|
||||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
@ -32,8 +31,6 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
|
|||||||
from .utils import (group_weights_with_prefix, init_vllm_registered_model,
|
from .utils import (group_weights_with_prefix, init_vllm_registered_model,
|
||||||
merge_multimodal_embeddings)
|
merge_multimodal_embeddings)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
# For profile run
|
# For profile run
|
||||||
_MAX_FRAMES_PER_VIDEO = 32
|
_MAX_FRAMES_PER_VIDEO = 32
|
||||||
_MAX_NUM_VIDEOS = 1
|
_MAX_NUM_VIDEOS = 1
|
||||||
|
|||||||
@ -37,7 +37,6 @@ from transformers import PretrainedConfig
|
|||||||
from vllm.attention import AttentionMetadata
|
from vllm.attention import AttentionMetadata
|
||||||
from vllm.config import CacheConfig, MultiModalConfig
|
from vllm.config import CacheConfig, MultiModalConfig
|
||||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
@ -59,8 +58,6 @@ from vllm.sequence import IntermediateTensors, SequenceData
|
|||||||
|
|
||||||
from .idefics2_vision_model import Idefics2VisionTransformer
|
from .idefics2_vision_model import Idefics2VisionTransformer
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
_KEYS_TO_MODIFY_MAPPING = {
|
_KEYS_TO_MODIFY_MAPPING = {
|
||||||
"llm.lm_head": "lm_head",
|
"llm.lm_head": "lm_head",
|
||||||
"llm.model": "llm",
|
"llm.model": "llm",
|
||||||
|
|||||||
@ -501,6 +501,7 @@ class SiglipVisionModel(nn.Module):
|
|||||||
num_hidden_layers_override: Optional[int] = None,
|
num_hidden_layers_override: Optional[int] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
num_heads = config.num_attention_heads
|
num_heads = config.num_attention_heads
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
|
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
|
||||||
@ -511,10 +512,6 @@ class SiglipVisionModel(nn.Module):
|
|||||||
num_hidden_layers_override=num_hidden_layers_override,
|
num_hidden_layers_override=num_hidden_layers_override,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
|
||||||
def _require_post_layernorm(self) -> bool:
|
|
||||||
return self.vision_model.post_layernorm is not None
|
|
||||||
|
|
||||||
def get_input_embeddings(self) -> nn.Module:
|
def get_input_embeddings(self) -> nn.Module:
|
||||||
return self.vision_model.embeddings.patch_embedding
|
return self.vision_model.embeddings.patch_embedding
|
||||||
|
|
||||||
@ -540,12 +537,12 @@ class SiglipVisionModel(nn.Module):
|
|||||||
|
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
# post_layernorm is optional in SiglipVisionModel
|
# post_layernorm is optional in SiglipVisionModel
|
||||||
if ("vision_model.post_layernorm" in name
|
if (name.startswith("vision_model.post_layernorm")
|
||||||
and not self._require_post_layernorm):
|
and self.vision_model.post_layernorm is None):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# omit layers when num_hidden_layers_override is set
|
# omit layers when num_hidden_layers_override is set
|
||||||
if "vision_model.encoder.layers." in name:
|
if name.startswith("vision_model.encoder.layers"):
|
||||||
layer_idx = int(name.split(".")[3])
|
layer_idx = int(name.split(".")[3])
|
||||||
if layer_idx >= layer_count:
|
if layer_idx >= layer_count:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@ -20,7 +20,6 @@ from vllm.config import CacheConfig, MultiModalConfig
|
|||||||
from vllm.inputs import INPUT_REGISTRY
|
from vllm.inputs import INPUT_REGISTRY
|
||||||
from vllm.inputs.data import LLMInputs
|
from vllm.inputs.data import LLMInputs
|
||||||
from vllm.inputs.registry import InputContext
|
from vllm.inputs.registry import InputContext
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
|
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
@ -43,8 +42,6 @@ from vllm.transformers_utils.configs.ultravox import UltravoxConfig
|
|||||||
_AUDIO_PLACEHOLDER_TOKEN = 128002
|
_AUDIO_PLACEHOLDER_TOKEN = 128002
|
||||||
_AUDIO_TOKENS_PER_SECOND = 6.25
|
_AUDIO_TOKENS_PER_SECOND = 6.25
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class UltravoxAudioFeatureInputs(TypedDict):
|
class UltravoxAudioFeatureInputs(TypedDict):
|
||||||
type: Literal["audio_features"]
|
type: Literal["audio_features"]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user