[Model] Port over CLIPVisionModel for VLMs (#5591)

Author: Roger Wang, 2024-06-20 04:52:09 -07:00 (committed by GitHub)
parent 111af1fa2c
commit ad137cd111
9 changed files with 269 additions and 21 deletions


@@ -135,6 +135,12 @@ __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
return ((T)0.5) * x * (((T)1.0) + t);
}
template <typename T>
__device__ __forceinline__ T gelu_quick_kernel(const T& x) {
// x * sigmoid(1.702 * x)
return (T)(((float)x) / (1.0f + expf(-1.702f * (float)x)));
}
} // namespace vllm
void gelu_new(torch::Tensor& out, // [..., d]
@@ -148,3 +154,9 @@ void gelu_fast(torch::Tensor& out, // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}
void gelu_quick(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel);
}
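
The new quick GELU kernel evaluates x * sigmoid(1.702 * x) through the algebraically identical form x / (1 + exp(-1.702 * x)). A minimal PyTorch sketch of that equivalence (illustrative only, not taken from the diff):

import torch

def quick_gelu_reference(x: torch.Tensor) -> torch.Tensor:
    # Definition from the kernel comment: x * sigmoid(1.702 * x)
    return x * torch.sigmoid(1.702 * x)

def quick_gelu_expanded(x: torch.Tensor) -> torch.Tensor:
    # Form used by the CUDA kernel: sigmoid(z) = 1 / (1 + exp(-z)),
    # so x * sigmoid(1.702 * x) == x / (1 + exp(-1.702 * x))
    return x / (1.0 + torch.exp(-1.702 * x))

x = torch.randn(1024)
assert torch.allclose(quick_gelu_reference(x), quick_gelu_expanded(x), atol=1e-6)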


@@ -49,6 +49,8 @@ void gelu_new(torch::Tensor& out, torch::Tensor& input);
void gelu_fast(torch::Tensor& out, torch::Tensor& input);
void gelu_quick(torch::Tensor& out, torch::Tensor& input);
#ifndef USE_ROCM
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& codebooks,


@@ -68,6 +68,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);
// Quick GELU implementation.
ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(


@@ -66,6 +66,10 @@ def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_new(out, x)
def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_quick(out, x)
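
The schema registered above declares `Tensor! out`, meaning the op writes its result into a caller-allocated output tensor and returns nothing; the thin Python wrapper mirrors that. A hedged usage sketch, assuming a CUDA build of vLLM with the `_C` extension loaded and a GPU tensor:

import torch
from vllm import _custom_ops as ops

x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
out = torch.empty_like(x)   # caller allocates the output buffer
ops.gelu_quick(out, x)      # dispatches to torch.ops._C.gelu_quick (CUDA kernel)
assert torch.allclose(out, x * torch.sigmoid(1.702 * x), atol=1e-2, rtol=1e-2)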
# page attention ops
def paged_attention_v1(
out: torch.Tensor,


@@ -141,6 +141,21 @@ class FastGELU(CustomOp):
return out
class QuickGELU(CustomOp):
# https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
return x * torch.sigmoid(1.702 * x)
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
from vllm import _custom_ops as ops
out = torch.empty_like(x)
ops.gelu_quick(out, x)
return out
class ScaledActivation(nn.Module):
"""An activation function with post-scale parameters.
@@ -189,6 +204,7 @@ _ACTIVATION_REGISTRY = {
"gelu_new": NewGELU(),
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
"relu": nn.ReLU(),
"quick_gelu": QuickGELU(),
}
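
CLIP's `hidden_act` is "quick_gelu", so the registry entry above is what `get_act_fn(config.hidden_act)` resolves to inside the new CLIP MLP. A small sanity sketch of the native path (CPU-friendly; on CUDA tensors the class calls the `gelu_quick` kernel added above):

import torch
from vllm.model_executor.layers.activation import QuickGELU

act = QuickGELU()
x = torch.randn(4, 16)
# forward_native is the pure-PyTorch fallback; forward_cuda uses ops.gelu_quick.
assert torch.allclose(act.forward_native(x), x * torch.sigmoid(1.702 * x))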


@@ -0,0 +1,203 @@
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
from typing import Optional, Tuple
import torch
import torch.nn as nn
from transformers import CLIPVisionConfig
from transformers.models.clip.modeling_clip import CLIPAttention
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
def get_clip_num_patches(image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0
return (image_size // patch_size)**2
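
For orientation, with the CLIP ViT-L/14-336 tower commonly paired with LLaVA-1.5 (an assumed example configuration, not something fixed by this file), the arithmetic works out to:

# image_size=336, patch_size=14 -> (336 // 14) ** 2 = 24 ** 2 = 576 patches,
# plus one class token -> 577 position embeddings.
assert get_clip_num_patches(image_size=336, patch_size=14) == 576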
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
class CLIPVisionEmbeddings(nn.Module):
def __init__(self, config: CLIPVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)
self.num_patches = get_clip_num_patches(self.image_size,
self.patch_size)
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Embedding(self.num_positions,
self.embed_dim)
self.register_buffer("position_ids",
torch.arange(self.num_positions).expand((1, -1)),
persistent=False)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(
dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
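
Shape walkthrough for the forward pass above (illustrative, assuming the same ViT-L/14-336 configuration with embed_dim=1024 and batch size B):

# pixel_values                   [B, 3, 336, 336]
# patch_embedding(pixel_values)  [B, 1024, 24, 24]
# flatten(2).transpose(1, 2)     [B, 576, 1024]
# cat with class embedding       [B, 577, 1024]
# + position_embedding           [B, 577, 1024]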
class CLIPMLP(nn.Module):
def __init__(self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config)
self.fc2 = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
return hidden_states
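
fc1/fc2 follow the usual Megatron-style pairing: the column-parallel layer keeps its output sharded, and the row-parallel layer consumes that shard and all-reduces the result, so the activation in between needs no communication. A shape sketch under an assumed tensor-parallel size tp:

# hidden_size = h, intermediate_size = m, tensor-parallel size = tp
# fc1 (ColumnParallelLinear): per-rank weight [m // tp, h] -> local activations [*, m // tp]
# fc2 (RowParallelLinear):    per-rank weight [h, m // tp] -> partial sums [*, h], all-reduced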
class CLIPEncoderLayer(nn.Module):
def __init__(self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.self_attn = CLIPAttention(config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config, quant_config=quant_config)
self.layer_norm2 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class CLIPEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self
attention layers. Each layer is a [`CLIPEncoderLayer`].
Args:
config: CLIPConfig
"""
def __init__(self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.layers = nn.ModuleList([
CLIPEncoderLayer(config=config, quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])
def forward(self,
inputs_embeds: torch.Tensor,
vision_feature_layer: int = -1):
# Encoder forward pass only up to the required layer
num_layer = len(self.layers) + vision_feature_layer + 1
hidden_states = inputs_embeds
for encoder_layer in self.layers[:num_layer]:
hidden_states = encoder_layer(hidden_states)
return hidden_states
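
This early stop is what lets the callers below drop their explicit hidden-state selection: with a negative vision_feature_layer the encoder simply does not run the trailing layers. A worked example, assuming the 24-layer ViT-L/14 tower and LLaVA's default vision_feature_layer = -2:

# num_layer = len(self.layers) + vision_feature_layer + 1
#           = 24 + (-2) + 1 = 23
# so self.layers[:23] run, and the returned tensor equals hidden_states[-2]
# of a full Hugging Face forward pass with output_hidden_states=True.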
class CLIPVisionTransformer(nn.Module):
def __init__(self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = CLIPVisionEmbeddings(config)
# NOTE: the "layrnorm" typo is kept on purpose to match the original
# transformers code and the names of the model weights.
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.encoder = CLIPEncoder(config=config, quant_config=quant_config)
def forward(
self,
pixel_values: torch.Tensor,
vision_feature_layer: int = -1,
) -> torch.Tensor:
hidden_states = self.embeddings(pixel_values)
hidden_states = self.pre_layrnorm(hidden_states)
hidden_states = self.encoder(inputs_embeds=hidden_states,
vision_feature_layer=vision_feature_layer)
return hidden_states
class CLIPVisionModel(nn.Module):
config_class = CLIPVisionConfig
main_input_name = "pixel_values"
def __init__(self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.vision_model = CLIPVisionTransformer(config=config,
quant_config=quant_config)
def forward(self,
pixel_values: Optional[torch.Tensor] = None,
vision_feature_layer: int = -1):
return self.vision_model(pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer)
@property
def device(self):
return next(self.parameters()).device
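
A rough end-to-end sketch of the ported tower. Hedged: the parallel linear layers require vLLM's distributed state to be initialized, which normally happens inside the engine, so standalone use would first need a world-size-1 model-parallel setup; the config values below assume ViT-L/14-336:

import torch
from transformers import CLIPVisionConfig

# In a real deployment the engine builds this as part of a VLM such as LLaVA;
# here we only sketch the tensor flow.
config = CLIPVisionConfig(hidden_size=1024, intermediate_size=4096,
                          num_hidden_layers=24, num_attention_heads=16,
                          image_size=336, patch_size=14)
model = CLIPVisionModel(config)   # assumes model-parallel groups already exist
pixel_values = torch.randn(1, 3, 336, 336)
features = model(pixel_values, vision_feature_layer=-2)
print(features.shape)             # expected: torch.Size([1, 577, 1024])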


@@ -2,9 +2,7 @@ from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
import torch
import torch.nn as nn
# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
# transformers' impl.
from transformers import CLIPVisionModel, LlavaConfig
from transformers import LlavaConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VisionLanguageConfig
@@ -15,6 +13,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -189,12 +188,11 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
# TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
image_outputs = vision_tower(pixel_values.to(vision_tower.device),
output_hidden_states=True)
image_features = image_outputs.hidden_states[
self.config.vision_feature_layer]
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values.to(vision_tower.device),
self.config.vision_feature_layer)
return self._select_image_features(
image_features,
@@ -317,6 +315,9 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
# post_layernorm is not needed in CLIPVisionModel
if "vision_model.post_layernorm" in name:
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
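
The post_layernorm skip exists because the ported tower has no post_layernorm module: the selected feature layer (-2 by default in these models) is taken before the final layer norm, so those checkpoint tensors have nowhere to load. A quick way to see which Hugging Face checkpoint keys are being dropped (illustrative; this downloads the reference weights):

from transformers import CLIPVisionModel as HfCLIPVisionModel

hf_tower = HfCLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14-336")
print([k for k in hf_tower.state_dict() if "post_layernorm" in k])
# ['vision_model.post_layernorm.weight', 'vision_model.post_layernorm.bias']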


@@ -4,9 +4,7 @@ from typing import (Dict, Iterable, List, Literal, Optional, Tuple, TypedDict,
import torch
import torch.nn as nn
from PIL import Image
# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
# transformers' impl.
from transformers import CLIPVisionModel, LlavaNextConfig
from transformers import LlavaNextConfig
from transformers.models.llava_next.modeling_llava_next import (
get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired
@@ -20,6 +18,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
@@ -121,7 +120,7 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
if self.vision_language_config.image_input_type == (
VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
self.vision_tower = CLIPVisionModel(config.vision_config)
self.vision_tower = CLIPVisionModel(config=config.vision_config)
else:
raise TypeError("Image features are not supported by LLaVA-NeXT")
@@ -219,12 +218,11 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
# TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
image_outputs = vision_tower(pixel_values.to(vision_tower.device),
output_hidden_states=True)
image_features = image_outputs.hidden_states[
self.config.vision_feature_layer]
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values.to(vision_tower.device),
self.config.vision_feature_layer)
return self._select_image_features(
image_features,
@@ -430,6 +428,9 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
# post_layernorm is not needed in CLIPVisionModel
if "vision_model.post_layernorm" in name:
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)


@@ -17,7 +17,7 @@ from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
import torch
import torch.nn as nn
from transformers import CLIPVisionConfig, CLIPVisionModel, PretrainedConfig
from transformers import CLIPVisionConfig, PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VisionLanguageConfig
@@ -27,6 +27,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -70,9 +71,10 @@ class Phi3ImageEmbeddingBase(nn.Module):
LAYER_IDX = self.layer_idx
TYPE_FEATURE = self.type_feature
img_processor_output = self.img_processor(img_embeds,
output_hidden_states=True)
img_feature = img_processor_output.hidden_states[LAYER_IDX]
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the img_processor
img_feature = self.img_processor(img_embeds,
vision_feature_layer=LAYER_IDX)
if TYPE_FEATURE == "patch":
patch_feature = img_feature[:, 1:]
@@ -352,6 +354,9 @@ class Phi3VForCausalLM(VisionLanguageModelBase):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
# post_layernorm is not needed in CLIPVisionModel
if "vision_model.post_layernorm" in name:
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)