[Model] Add BNB quantization support for Idefics3 (#10310)

Signed-off-by: B-201 <Joy25810@foxmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
B-201 2024-11-14 14:31:44 +08:00 committed by GitHub
parent 52b48c1ead
commit 294bf467ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -22,6 +22,7 @@ import torch.utils.checkpoint
from PIL import Image from PIL import Image
from torch import nn from torch import nn
# Temporary solution for transformers below 4.46.0. # Temporary solution for transformers below 4.46.0.
from transformers import PretrainedConfig as Idefics3Config
from transformers import ProcessorMixin as Idefics3ImageProcessor from transformers import ProcessorMixin as Idefics3ImageProcessor
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
@ -31,6 +32,7 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
@ -374,12 +376,23 @@ def dummy_data_for_idefics3(
class Idefics3SimpleMLP(nn.Module): class Idefics3SimpleMLP(nn.Module):
def __init__(self, config): def __init__(
self,
config: Idefics3Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__() super().__init__()
input_size = config.vision_config.hidden_size * (config.scale_factor** input_size = config.vision_config.hidden_size * (config.scale_factor**
2) 2)
output_size = config.text_config.hidden_size output_size = config.text_config.hidden_size
self.proj = ReplicatedLinear(input_size, output_size, bias=False) self.proj = ReplicatedLinear(
input_size,
output_size,
bias=False,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "proj"),
)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
out, _ = self.proj(x) out, _ = self.proj(x)
@ -388,10 +401,19 @@ class Idefics3SimpleMLP(nn.Module):
class Idefics3Connector(nn.Module): class Idefics3Connector(nn.Module):
def __init__(self, config): def __init__(
self,
config: Idefics3Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__() super().__init__()
self.scale_factor = config.scale_factor self.scale_factor = config.scale_factor
self.modality_projection = Idefics3SimpleMLP(config) self.modality_projection = Idefics3SimpleMLP(
config,
quant_config,
prefix=maybe_prefix(prefix, "modality_projection"),
)
def pixel_shuffle(self, def pixel_shuffle(self,
x: torch.Tensor, x: torch.Tensor,
@ -431,9 +453,15 @@ class Idefics3Model(nn.Module):
self.config = config self.config = config
self.padding_idx = self.config.text_config.pad_token_id self.padding_idx = self.config.text_config.pad_token_id
self.vocab_size = self.config.text_config.vocab_size self.vocab_size = self.config.text_config.vocab_size
self.vision_model = Idefics3VisionTransformer(config.vision_config, self.vision_model = Idefics3VisionTransformer(
quant_config) config.vision_config,
self.connector = Idefics3Connector(config) quant_config=quant_config,
prefix=maybe_prefix(prefix, "vision_model"))
self.connector = Idefics3Connector(
config,
quant_config,
prefix=maybe_prefix(prefix, "connector"),
)
self.text_model = LlamaModel( self.text_model = LlamaModel(
vllm_config=vllm_config.with_hf_config(config.text_config), vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "text_model"), prefix=maybe_prefix(prefix, "text_model"),
@ -637,6 +665,32 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
"gate_up_proj", "gate_up_proj",
"down_proj", "down_proj",
] ]
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
# vision_model
".fc1.",
".fc2.",
".out_proj.",
# connector
".proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
embedding_modules = {} embedding_modules = {}
embedding_padding_modules = [] embedding_padding_modules = []