from array import array
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import torch
import torch.nn as nn
from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
apply_chunking_to_forward)
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.opt import OPTModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SamplerOutput, SequenceData)
from .blip import (BlipVisionModel, dummy_image_for_blip,
get_max_blip_image_tokens)
from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
}
# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
BLIP2_IMAGE_TOKEN = "<image>"
BLIP2_IMAGE_TOKEN_ID = 50265
class Blip2ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channels, height, width)"""
class Blip2ImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size, image_feature_size, hidden_size)`
    `hidden_size` must match the hidden size of the language model backbone.
"""
Blip2ImageInputs = Union[Blip2ImagePixelInputs, Blip2ImageEmbeddingInputs]
class Blip2QFormerMultiHeadAttention(nn.Module):
def __init__(
self,
config: Blip2QFormerConfig,
*,
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
is_cross_attention: bool = False,
) -> None:
super().__init__()
self.config = config
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of "
f"the number of attention heads ({config.num_attention_heads})"
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = (config.hidden_size //
config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.scaling = self.attention_head_size**-0.5
self.query = nn.Linear(config.hidden_size, self.all_head_size)
if is_cross_attention:
kv_hidden_size = config.encoder_hidden_size
else:
kv_hidden_size = config.hidden_size
self.key = nn.Linear(kv_hidden_size, self.all_head_size)
self.value = nn.Linear(kv_hidden_size, self.all_head_size)
self.position_embedding_type = getattr(config,
"position_embedding_type",
"absolute")
if self.position_embedding_type != "absolute":
raise NotImplementedError("Unsupported position_embedding_type: "
f"{self.position_embedding_type}")
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
x = x.view(*x.size()[:-1], self.num_attention_heads,
self.attention_head_size)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
):
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention:
key_layer = self.transpose_for_scores(
self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(
self.value(encoder_hidden_states))
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
mixed_query_layer = self.query(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
attention_scores = torch.matmul(query_layer,
key_layer.transpose(-1, -2))
attention_probs = torch.softmax(attention_scores * self.scaling,
dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs_dropped = self.dropout(attention_probs)
context_layer = torch.matmul(attention_probs_dropped, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
context_layer = context_layer.view(*context_layer.size()[:-2],
self.all_head_size)
return context_layer
class Blip2QFormerSelfOutput(nn.Module):
def __init__(self, config: Blip2QFormerConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(
self,
hidden_states: torch.Tensor,
input_tensor: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class Blip2QFormerAttention(nn.Module):
def __init__(
self,
config: Blip2QFormerConfig,
*,
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
is_cross_attention: bool = False,
) -> None:
super().__init__()
self.attention = Blip2QFormerMultiHeadAttention(
config,
quant_config=quant_config,
cache_config=cache_config,
is_cross_attention=is_cross_attention,
)
self.output = Blip2QFormerSelfOutput(config)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
self_output = self.attention(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
)
attention_output = self.output(self_output, hidden_states)
return attention_output
class Blip2QFormerIntermediate(nn.Module):
def __init__(self, config: Blip2QFormerConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
self.intermediate_act_fn = get_act_fn(config.hidden_act)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class Blip2QFormerOutput(nn.Module):
def __init__(self, config: Blip2QFormerConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(
self,
hidden_states: torch.Tensor,
input_tensor: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class Blip2QFormerLayer(nn.Module):
def __init__(
self,
config: Blip2QFormerConfig,
*,
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
layer_idx: int,
) -> None:
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = Blip2QFormerAttention(config,
quant_config=quant_config,
cache_config=cache_config)
self.layer_idx = layer_idx
if layer_idx % config.cross_attention_frequency == 0:
self.crossattention = Blip2QFormerAttention(
config,
quant_config=quant_config,
cache_config=cache_config,
is_cross_attention=True)
self.has_cross_attention = True
else:
self.has_cross_attention = False
self.intermediate_query = Blip2QFormerIntermediate(config)
self.output_query = Blip2QFormerOutput(config)
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
query_length: int,
):
attention_output = self.attention(hidden_states)
if query_length > 0:
query_attention_output = attention_output[:, :query_length, :]
if self.has_cross_attention:
query_attention_output = self.crossattention(
query_attention_output,
encoder_hidden_states=encoder_hidden_states,
)
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk_query,
self.chunk_size_feed_forward,
self.seq_len_dim,
query_attention_output,
)
if attention_output.shape[1] > query_length:
layer_output_text = apply_chunking_to_forward(
self.feed_forward_chunk,
self.chunk_size_feed_forward,
self.seq_len_dim,
attention_output[:, query_length:, :],
)
layer_output = torch.cat([layer_output, layer_output_text],
dim=1)
else:
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk,
self.chunk_size_feed_forward,
self.seq_len_dim,
attention_output,
)
return layer_output
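    # NOTE: `feed_forward_chunk` corresponds to the text branch of the HF
    # implementation. In vLLM the Q-Former only receives query embeddings
    # (query_length equals the sequence length), so this branch is not
    # expected to be taken; the HF `intermediate`/`output` modules it refers
    # to are not instantiated here.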
def feed_forward_chunk(self,
attention_output: torch.Tensor) -> torch.Tensor:
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
def feed_forward_chunk_query(
self, attention_output: torch.Tensor) -> torch.Tensor:
intermediate_output = self.intermediate_query(attention_output)
layer_output = self.output_query(intermediate_output, attention_output)
return layer_output
class Blip2QFormerEncoder(nn.Module):
def __init__(
self,
config: Blip2QFormerConfig,
*,
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
) -> None:
super().__init__()
self.config = config
self.layer = nn.ModuleList([
Blip2QFormerLayer(config,
quant_config=quant_config,
cache_config=cache_config,
layer_idx=layer_idx)
for layer_idx in range(config.num_hidden_layers)
])
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
query_length: int,
) -> torch.Tensor:
for i in range(self.config.num_hidden_layers):
layer_module = self.layer[i]
hidden_states = layer_module(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
query_length=query_length,
)
return hidden_states
# Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1025
class Blip2QFormerModel(nn.Module):
def __init__(
self,
config: Blip2QFormerConfig,
*,
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
) -> None:
super().__init__()
self.config = config
self.layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.encoder = Blip2QFormerEncoder(config,
quant_config=quant_config,
cache_config=cache_config)
def forward(
self,
query_embeds: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
) -> torch.Tensor:
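        # query_embeds: (batch_size, num_query_tokens, qformer_hidden_size)
        # encoder_hidden_states: (batch_size, vision_seq_len, encoder_hidden_size)
        # The output keeps the query shape:
        # (batch_size, num_query_tokens, qformer_hidden_size).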
query_length = query_embeds.shape[1]
embedding_output = self.layernorm(query_embeds)
embedding_output = self.dropout(embedding_output)
sequence_output = self.encoder(
embedding_output,
encoder_hidden_states=encoder_hidden_states,
query_length=query_length,
)
return sequence_output
def get_blip2_image_feature_size(hf_config: Blip2Config) -> int:
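    # Each image is expanded to `num_query_tokens` placeholder tokens
    # (32 for the released BLIP-2 checkpoints).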
return hf_config.num_query_tokens
def get_max_blip2_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(Blip2Config)
vision_config = hf_config.vision_config
if isinstance(vision_config, Blip2VisionConfig):
return get_max_blip_image_tokens(vision_config)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def dummy_seq_data_for_blip2(
hf_config: Blip2Config,
seq_len: int,
num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
if image_feature_size_override is None:
image_feature_size = get_blip2_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[image_token_id]) * image_feature_size * num_images
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - image_feature_size * num_images)
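    # e.g. with num_query_tokens=32, seq_len=256 and num_images=1, this yields
    # 32 image placeholder tokens followed by 224 padding (id 0) tokens.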
return SequenceData(token_ids)
def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(Blip2Config)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
seq_data = dummy_seq_data_for_blip2(
hf_config,
seq_len,
num_images,
image_token_id=BLIP2_IMAGE_TOKEN_ID,
)
if isinstance(vision_config, Blip2VisionConfig):
mm_data = dummy_image_for_blip(vision_config, num_images)
return seq_data, mm_data
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
hf_config = ctx.get_hf_config(Blip2Config)
image_feature_size = get_blip2_image_feature_size(hf_config)
# The original model places image tokens at the front
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514
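    # Illustrative example: with 32 query tokens, a prompt
    # "Question: ... Answer:" becomes "<image>" * 32 + "Question: ... Answer:"
    # and its token ids are prefixed with 32 copies of BLIP2_IMAGE_TOKEN_ID.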
new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
new_token_ids += llm_inputs["prompt_token_ids"]
new_prompt = llm_inputs.get("prompt")
if new_prompt is not None:
new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt
return LLMInputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2)
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self,
config: Blip2Config,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.config = config
self.multimodal_config = multimodal_config
        # TODO: Optionally initialize this to support embedding inputs.
self.vision_model = BlipVisionModel(config.vision_config)
self.query_tokens = nn.Parameter(
torch.zeros(1, config.num_query_tokens,
config.qformer_config.hidden_size))
self.qformer = Blip2QFormerModel(config.qformer_config,
cache_config=cache_config,
quant_config=quant_config)
self.language_projection = nn.Linear(
config.qformer_config.hidden_size,
config.text_config.hidden_size,
bias=True,
)
self.quant_config = quant_config
self.language_model = OPTModel(config.text_config, cache_config,
quant_config)
self.unpadded_vocab_size = config.text_config.vocab_size
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size)
self.sampler = Sampler()
def get_lm_head(self):
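        # OPT ties the output projection to the input embeddings, so the
        # embedding table also serves as the LM head (note that
        # "lm_head.weight" is skipped in `load_weights`).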
return self.language_model.decoder.embed_tokens
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
actual_dims = tuple(data.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("batch_size", *map(str, expected_dims))
raise ValueError(
f"The expected shape of pixel values is {expected_expr}. "
f"You supplied {tuple(data.shape)}.")
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Blip2ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None:
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return Blip2ImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return Blip2ImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")
def _image_pixels_to_features(self, vision_model: BlipVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_model(pixel_values)
return image_features
def _process_image_pixels(self,
inputs: Blip2ImagePixelInputs) -> torch.Tensor:
assert self.vision_model is not None
pixel_values = inputs["data"]
return self._image_pixels_to_features(self.vision_model, pixel_values)
def _process_image_input(self,
image_input: Blip2ImageInputs) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_model is not None
image_features = self._process_image_pixels(image_input)
query_tokens = self.query_tokens.expand(image_features.shape[0], -1,
-1)
query_output = self.qformer(
query_embeds=query_tokens,
encoder_hidden_states=image_features,
)
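        # query_output: (batch_size, num_query_tokens, qformer_hidden_size);
        # the projection maps it to the language model hidden size so it can
        # be merged with the text token embeddings.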
return self.language_projection(query_output)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> SamplerOutput:
"""Run forward pass for BLIP-2.
        One key thing to understand is that `input_ids` already accounts for
        the positions of the to-be-inserted image embeddings.
Concretely, consider a text prompt:
`"Question: What's the content of the image? Answer:"`.
Tokenizer outputs:
`[2, 45641, 35, 653, 18, 5, 1383, 9, 5, 2274, 116, 31652, 35]`.
        To reserve space in the KV cache, we have to insert placeholder tokens
        before the prompt is fed to the model, so the input processor prepends
        dummy image tokens (denoted as `50265`), resulting in:
        `[50265, ..., 50265, 2, 45641, 35, ..., 31652, 35]`.
        We insert 32 placeholder tokens since that is the number of query
        embeddings produced by the Q-Former and fed into the language model.
This way, the `positions` and `attn_metadata` are consistent
with the `input_ids`.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values: The pixels in each input image.
See also:
:class:`Blip2ImageInputs`
"""
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
BLIP2_IMAGE_TOKEN_ID)
input_ids = None
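            # The projected Q-Former outputs replace the embeddings at the
            # BLIP2_IMAGE_TOKEN_ID positions; `input_ids` is dropped since the
            # language model now consumes `inputs_embeds` directly.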
else:
inputs_embeds = None
hidden_states = self.language_model(input_ids,
positions,
kv_caches,
attn_metadata,
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.get_lm_head(), hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # Only apply the stacked-parameter mapping to the language model
        # for now.
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "lm_head.weight" in name:
continue
if "rotary_emb.inv_freq" in name:
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
use_default_weight_loading = False
if "vision" in name:
if self.vision_model is not None:
# We only do sharding for language model and
# not vision model for now.
use_default_weight_loading = True
else:
for (param_name, weight_name,
shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
use_default_weight_loading = True
if use_default_weight_loading:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
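# Minimal usage sketch (illustrative, not part of the model definition),
# assuming vLLM's offline `LLM` API and the `Salesforce/blip2-opt-2.7b`
# checkpoint; the input processor above prepends the image placeholder
# tokens automatically:
#
#     from vllm import LLM, SamplingParams
#     from PIL import Image
#
#     llm = LLM(model="Salesforce/blip2-opt-2.7b")
#     image = Image.open("example.jpg").convert("RGB")
#     outputs = llm.generate(
#         {
#             "prompt": "Question: What is shown in the image? Answer:",
#             "multi_modal_data": {"image": image},
#         },
#         SamplingParams(temperature=0.0, max_tokens=64),
#     )
#     print(outputs[0].outputs[0].text)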