[Misc]Add BNB quantization for PaliGemmaForConditionalGeneration (#12237)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2025-01-21 15:49:08 +08:00 committed by GitHub
parent 96912550c8
commit 1f1542afa9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 22 additions and 5 deletions

View File

@ -136,7 +136,18 @@ class PaliGemmaMultiModalProjector(nn.Module):
@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config

View File

@ -344,10 +344,16 @@ class SiglipMLP(nn.Module):
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
# For quantization, we require the hidden size to be a multiple of 64
quantizable = (config.hidden_size % 64 == 0
and config.intermediate_size % 64 == 0)
# Special handling for BNB quantization
if quant_config and quant_config.get_name() == "bitsandbytes":
quantizable = True
else:
# For other quantization, we require the hidden size to be a
# multiple of 64
quantizable = (
config.hidden_size % 64 == 0
and config.intermediate_size % 64 == 0
)
self.fc1 = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,