[Misc]Add BNB quantization for MolmoForCausalLM (#11551)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2024-12-28 02:48:24 +08:00 committed by GitHub
parent 55509c2114
commit 0240402c46
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 83 additions and 33 deletions

View File

@ -11,7 +11,8 @@ import os
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional,
Tuple, cast)
import gguf
import huggingface_hub
@ -706,6 +707,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# Store all module names (from transformers) that support
# BNB quantization.
self.target_modules: List[str] = []
# mapping weight names from transformers to vllm.
self.weight_mapper: Callable = lambda name: name
def _get_weight_files(
self,
@ -763,9 +766,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
if use_safetensors:
return safetensors_weights_iterator(hf_weights_files)
iterator = safetensors_weights_iterator(hf_weights_files)
else:
return pt_weights_iterator(hf_weights_files)
iterator = pt_weights_iterator(hf_weights_files)
for name, param in iterator:
# mapping weight names from transformers to vllm.
yield self.weight_mapper(name), param
def _get_quantized_weights_iterator(
self,
@ -782,12 +788,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
try:
import bitsandbytes
if bitsandbytes.__version__ < "0.44.0":
if bitsandbytes.__version__ < "0.45.0":
raise ImportError("bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.44.0.")
"install bitsandbytes>=0.45.0.")
except ImportError as err:
raise ImportError("Please install bitsandbytes>=0.44.0 via "
"`pip install bitsandbytes>=0.44.0` to use "
raise ImportError("Please install bitsandbytes>=0.45.0 via "
"`pip install bitsandbytes>=0.45.0` to use "
"bitsandbytes quantizer.") from err
hf_weights_files, use_safetensors = self._prepare_weights(
@ -991,7 +997,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
if isinstance(module, (LinearBase, )):
last_name = name.split(".")[-1]
if sub_modules := inverse_stacked_mapping.get(last_name, []):
# Map vllm's names to transformers' names.
# Map vllm's names to transformers's names.
for sub_name in sub_modules:
self.target_modules.append(
name.replace(last_name, sub_name))
@ -1013,6 +1019,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
f"Model {type(model).__name__} does not support BitsAndBytes "
"quantization yet.")
# For some models like Molmo, we need to use hf_to_vllm_mapper
# to ensure correct loading of weights.
if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
# Modules whose weights might have fused on disk
# we need their output_sizes to make shard in flight correctly with TP
self.maybe_fused_weights_modules: Dict[str, List[int]] = {}

View File

@ -461,30 +461,71 @@ class MolmoAttention(nn.Module):
return output
class MolmoMLP(nn.Module):
class SwiGLU(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, gate = x.chunk(2, dim=-1)
# Note that the order is reversed compared to
# SiluAndMul.
return x * F.silu(gate)
class LanuageModelMLP(nn.Module):
"""Molmo's LLM mlp."""
def __init__(self,
config: PretrainedConfig,
input_dim: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
proj_name: str = "gate_up_proj") -> None:
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size // 2
# Molmo's LLM proj weights are already merged into the disk, while
# image_projector proj is separate. If the same proj_name were used, it
# would create ambiguity and make it difficult to support BNB and LoRA.
self.proj_name = proj_name
setattr(
self, proj_name,
MergedColumnParallelLinear(
input_dim or self.hidden_size,
[self.intermediate_size] * 2,
bias=False,
quant_config=quant_config,
))
self.gate_up_proj = MergedColumnParallelLinear(
input_dim or self.hidden_size,
[self.intermediate_size] * 2,
bias=False,
quant_config=quant_config,
)
# Activation function.
self.act_fn = SwiGLU()
# Feed-forward output projection.
self.down_proj = RowParallelLinear(
self.intermediate_size,
self.hidden_size,
bias=False,
quant_config=quant_config,
)
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class ImageProjectorMLP(nn.Module):
"""Molmo's image_projector mlp."""
def __init__(
self,
config: PretrainedConfig,
input_dim: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size // 2
self.merged_linear = MergedColumnParallelLinear(
input_dim or self.hidden_size,
[self.intermediate_size] * 2,
bias=False,
quant_config=quant_config,
)
# Activation function.
self.act_fn = SiluAndMul()
@ -500,7 +541,7 @@ class MolmoMLP(nn.Module):
self,
x: torch.Tensor,
) -> torch.Tensor:
gate_up, _ = getattr(self, self.proj_name)(x)
gate_up, _ = self.merged_linear(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
@ -523,9 +564,7 @@ class MolmoDecoderLayer(nn.Module):
prefix=f"{prefix}.self_attn")
# MLP block.
self.mlp = MolmoMLP(config,
quant_config=quant_config,
proj_name="gate_up_proj")
self.mlp = LanuageModelMLP(config, quant_config=quant_config)
# LayerNorm
assert config.layer_norm_type == "rms"
@ -617,11 +656,10 @@ class MolmoVisionBackbone(nn.Module):
vision_config,
nlayers=len(self.vit_layers),
quant_config=quant_config)
self.image_projector = MolmoMLP(
self.image_projector = ImageProjectorMLP(
config,
input_dim=vision_config.image_emb_dim,
quant_config=quant_config,
proj_name="merged_linear",
)
image_dim = vision_config.image_emb_dim * len(self.vit_layers)
@ -842,10 +880,6 @@ class MolmoModel(nn.Module):
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "gate_up_proj" in name:
up_proj, gate_proj = loaded_weight.chunk(2, dim=0)
loaded_weight = torch.cat([gate_proj, up_proj], dim=0)
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
@ -1157,6 +1191,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
},
)
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
"gate_proj": ("merged_linear", 0),
"up_proj": ("merged_linear", 1),
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config