[Model] Refactor BLIP/BLIP-2 to support composite model loading (#8407)

This commit is contained in:
Cyrus Leung 2024-09-22 20:24:21 +08:00 committed by GitHub
parent 0e40ac9b7b
commit 06ed2815e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 112 additions and 113 deletions

View File

@ -1,6 +1,6 @@
"""Minimal implementation of BlipVisionModel intended to be only used """Minimal implementation of BlipVisionModel intended to be only used
within a vision language model.""" within a vision language model."""
from typing import Optional, Union from typing import Iterable, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -16,6 +16,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import (cached_get_tokenizer, from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
@ -342,6 +343,10 @@ class BlipVisionModel(nn.Module):
num_hidden_layers_override: Optional[int] = None): num_hidden_layers_override: Optional[int] = None):
super().__init__() super().__init__()
tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
self.config = config self.config = config
self.embeddings = BlipVisionEmbeddings(config) self.embeddings = BlipVisionEmbeddings(config)
@ -350,11 +355,61 @@ class BlipVisionModel(nn.Module):
quant_config=quant_config, quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
) )
self.post_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) if len(self.encoder.layers) > config.num_hidden_layers:
raise ValueError(
f"The original encoder only has {config.num_hidden_layers} "
f"layers, but you requested {len(self.encoder.layers)} layers."
)
elif len(self.encoder.layers) == config.num_hidden_layers:
self.post_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
else:
# post_layernorm is unused when we extract intermediate features
# In this case, we can skip it to conserve memory
self.post_layernorm = None
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
hidden_states = self.embeddings(pixel_values) hidden_states = self.embeddings(pixel_values)
hidden_states = self.encoder(inputs_embeds=hidden_states) hidden_states = self.encoder(inputs_embeds=hidden_states)
if self.post_layernorm is None:
return hidden_states
return self.post_layernorm(hidden_states) return self.post_layernorm(hidden_states)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
] if self.shard_weight else []
params_dict = dict(self.named_parameters())
layer_count = len(self.encoder.layers)
for name, loaded_weight in weights:
# post_layernorm is not needed in BlipVisionModel
if (name.startswith("post_layernorm")
and self.post_layernorm is None):
continue
# omit layers when num_hidden_layers_override is set
if name.startswith("encoder.layers"):
layer_idx = int(name.split(".")[2])
if layer_idx >= layer_count:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@ -10,11 +10,9 @@ from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.opt import OPTModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors, SequenceData from vllm.sequence import IntermediateTensors, SequenceData
@ -22,12 +20,8 @@ from vllm.sequence import IntermediateTensors, SequenceData
from .blip import (BlipVisionModel, dummy_image_for_blip, from .blip import (BlipVisionModel, dummy_image_for_blip,
get_max_blip_image_tokens) get_max_blip_image_tokens)
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings from .utils import (group_weights_with_prefix, init_vllm_registered_model,
merge_multimodal_embeddings)
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
}
# We use this internally as placeholders since there is no image token # We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo # defined on the HuggingFace repo
@ -491,9 +485,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
super().__init__() super().__init__()
# currently all existing BLIP-2 models have `tie_word_embeddings`
# enabled
assert config.tie_word_embeddings
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
@ -514,17 +505,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
bias=True, bias=True,
) )
self.quant_config = quant_config self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)
self.language_model = OPTModel(config.text_config, cache_config,
quant_config)
self.unpadded_vocab_size = config.text_config.vocab_size
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size)
self.sampler = Sampler()
def get_lm_head(self):
return self.language_model.decoder.embed_tokens
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size h = w = self.config.vision_config.image_size
@ -653,7 +635,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
if image_input is not None: if image_input is not None:
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings, input_ids, inputs_embeds, vision_embeddings,
@ -663,11 +646,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
else: else:
inputs_embeds = None inputs_embeds = None
hidden_states = self.language_model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
@ -676,56 +659,46 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.get_lm_head(), hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
return logits
def sample( def sample(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# only doing this for language model part for now. # prepare weight iterators for components
stacked_params_mapping = [ weights_group = group_weights_with_prefix(weights)
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights: # load vision encoder
if "lm_head.weight" in name: self.vision_model.load_weights(weights_group["vision_model"])
continue
if "rotary_emb.inv_freq" in name: # load query tokens
continue for name, loaded_weight in weights_group["query_tokens"]:
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): assert name == ""
if key_to_modify in name: param = self.query_tokens
name = name.replace(key_to_modify, new_key) weight_loader = getattr(param, "weight_loader",
use_default_weight_loading = False default_weight_loader)
if "vision" in name: weight_loader(param, loaded_weight)
if self.vision_model is not None:
# BlipVisionModel does not need sharding # load qformer
use_default_weight_loading = True qformer_params_dict = dict(self.qformer.named_parameters())
else: for name, loaded_weight in weights_group["qformer"]:
for (param_name, weight_name, param = qformer_params_dict[name]
shard_id) in stacked_params_mapping: weight_loader = getattr(param, "weight_loader",
if weight_name not in name: default_weight_loader)
continue weight_loader(param, loaded_weight)
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader # load mlp projector
weight_loader(param, loaded_weight, shard_id) mlp_params_dict = dict(self.language_projection.named_parameters())
break for name, loaded_weight in weights_group["language_projection"]:
else: param = mlp_params_dict[name]
use_default_weight_loading = True weight_loader = getattr(param, "weight_loader",
if use_default_weight_loading: default_weight_loader)
param = params_dict[name] weight_loader(param, loaded_weight)
weight_loader = getattr(param, "weight_loader",
default_weight_loader) # load llm backbone
weight_loader(param, loaded_weight) self.language_model.load_weights(weights_group["language_model"])

View File

@ -12,7 +12,6 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@ -36,8 +35,6 @@ from vllm.utils import print_warning_once
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal
logger = init_logger(__name__)
# These configs are not part of the model config but the preprocessor # These configs are not part of the model config but the preprocessor
# and processor files, so we hardcode them in the model file for now. # and processor files, so we hardcode them in the model file for now.
CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512 CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512

View File

@ -391,6 +391,7 @@ class CLIPVisionModel(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None): num_hidden_layers_override: Optional[int] = None):
super().__init__() super().__init__()
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads num_heads = config.num_attention_heads
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0 self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
@ -400,10 +401,6 @@ class CLIPVisionModel(nn.Module):
quant_config=quant_config, quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override) num_hidden_layers_override=num_hidden_layers_override)
@property
def _require_post_layernorm(self) -> bool:
return self.vision_model.post_layernorm is not None
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return self.vision_model(pixel_values) return self.vision_model(pixel_values)
@ -425,12 +422,12 @@ class CLIPVisionModel(nn.Module):
for name, loaded_weight in weights: for name, loaded_weight in weights:
# post_layernorm is not needed in CLIPVisionModel # post_layernorm is not needed in CLIPVisionModel
if ("vision_model.post_layernorm" in name if (name.startswith("vision_model.post_layernorm")
and not self._require_post_layernorm): and self.vision_model.post_layernorm is None):
continue continue
# omit layers when num_hidden_layers_override is set # omit layers when num_hidden_layers_override is set
if "vision_model.encoder.layers." in name: if name.startswith("vision_model.encoder.layers"):
layer_idx = int(name.split(".")[3]) layer_idx = int(name.split(".")[3])
if layer_idx >= layer_count: if layer_idx >= layer_count:
continue continue

View File

@ -28,7 +28,6 @@ from transformers import FuyuConfig, FuyuImageProcessor
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
@ -45,8 +44,6 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings from .utils import merge_multimodal_embeddings
logger = init_logger(__name__)
# Cannot find the following 2 numbers from hf config. # Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011 _IMAGE_TOKEN_ID = 71011
_NEWLINE_TOKEN_ID = 71019 _NEWLINE_TOKEN_ID = 71019

View File

@ -12,7 +12,6 @@ from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@ -32,13 +31,6 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
from .utils import (flatten_bn, group_weights_with_prefix, from .utils import (flatten_bn, group_weights_with_prefix,
init_vllm_registered_model, merge_multimodal_embeddings) init_vllm_registered_model, merge_multimodal_embeddings)
logger = init_logger(__name__)
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
}
# Result in the max possible feature size (2x2 grid of 336x336px tiles) # Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448 MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448

View File

@ -11,7 +11,6 @@ from transformers import (CLIPVisionConfig, LlavaNextVideoConfig,
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
@ -32,8 +31,6 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
from .utils import (group_weights_with_prefix, init_vllm_registered_model, from .utils import (group_weights_with_prefix, init_vllm_registered_model,
merge_multimodal_embeddings) merge_multimodal_embeddings)
logger = init_logger(__name__)
# For profile run # For profile run
_MAX_FRAMES_PER_VIDEO = 32 _MAX_FRAMES_PER_VIDEO = 32
_MAX_NUM_VIDEOS = 1 _MAX_NUM_VIDEOS = 1

View File

@ -37,7 +37,6 @@ from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
@ -59,8 +58,6 @@ from vllm.sequence import IntermediateTensors, SequenceData
from .idefics2_vision_model import Idefics2VisionTransformer from .idefics2_vision_model import Idefics2VisionTransformer
logger = init_logger(__name__)
_KEYS_TO_MODIFY_MAPPING = { _KEYS_TO_MODIFY_MAPPING = {
"llm.lm_head": "lm_head", "llm.lm_head": "lm_head",
"llm.model": "llm", "llm.model": "llm",

View File

@ -501,6 +501,7 @@ class SiglipVisionModel(nn.Module):
num_hidden_layers_override: Optional[int] = None, num_hidden_layers_override: Optional[int] = None,
): ):
super().__init__() super().__init__()
num_heads = config.num_attention_heads num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0 self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
@ -511,10 +512,6 @@ class SiglipVisionModel(nn.Module):
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
) )
@property
def _require_post_layernorm(self) -> bool:
return self.vision_model.post_layernorm is not None
def get_input_embeddings(self) -> nn.Module: def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding return self.vision_model.embeddings.patch_embedding
@ -540,12 +537,12 @@ class SiglipVisionModel(nn.Module):
for name, loaded_weight in weights: for name, loaded_weight in weights:
# post_layernorm is optional in SiglipVisionModel # post_layernorm is optional in SiglipVisionModel
if ("vision_model.post_layernorm" in name if (name.startswith("vision_model.post_layernorm")
and not self._require_post_layernorm): and self.vision_model.post_layernorm is None):
continue continue
# omit layers when num_hidden_layers_override is set # omit layers when num_hidden_layers_override is set
if "vision_model.encoder.layers." in name: if name.startswith("vision_model.encoder.layers"):
layer_idx = int(name.split(".")[3]) layer_idx = int(name.split(".")[3])
if layer_idx >= layer_count: if layer_idx >= layer_count:
continue continue

View File

@ -20,7 +20,6 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY from vllm.inputs import INPUT_REGISTRY
from vllm.inputs.data import LLMInputs from vllm.inputs.data import LLMInputs
from vllm.inputs.registry import InputContext from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
@ -43,8 +42,6 @@ from vllm.transformers_utils.configs.ultravox import UltravoxConfig
_AUDIO_PLACEHOLDER_TOKEN = 128002 _AUDIO_PLACEHOLDER_TOKEN = 128002
_AUDIO_TOKENS_PER_SECOND = 6.25 _AUDIO_TOKENS_PER_SECOND = 6.25
logger = init_logger(__name__)
class UltravoxAudioFeatureInputs(TypedDict): class UltravoxAudioFeatureInputs(TypedDict):
type: Literal["audio_features"] type: Literal["audio_features"]