[Misc]Add BNB quantization for MolmoForCausalLM (#11551)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent 55509c2114
commit 0240402c46
@@ -11,7 +11,8 @@ import os
 import warnings
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
+from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional,
+                    Tuple, cast)
 
 import gguf
 import huggingface_hub
@@ -706,6 +707,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         # Store all module names (from transformers) that support
         # BNB quantization.
         self.target_modules: List[str] = []
+        # mapping weight names from transformers to vllm.
+        self.weight_mapper: Callable = lambda name: name
 
     def _get_weight_files(
         self,
@@ -763,9 +766,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
     def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
         if use_safetensors:
-            return safetensors_weights_iterator(hf_weights_files)
+            iterator = safetensors_weights_iterator(hf_weights_files)
         else:
-            return pt_weights_iterator(hf_weights_files)
+            iterator = pt_weights_iterator(hf_weights_files)
+        for name, param in iterator:
+            # mapping weight names from transformers to vllm.
+            yield self.weight_mapper(name), param
 
     def _get_quantized_weights_iterator(
         self,
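
Note: the change above turns _hf_weight_iter from a pass-through into a generator, so every checkpoint key can be renamed lazily before the BNB loader consumes it. A minimal sketch of that pattern (the mapper lambda below is illustrative, not the actual self.weight_mapper):

    from typing import Callable, Iterable, Tuple

    import torch

    def remapping_iter(
        weights: Iterable[Tuple[str, torch.Tensor]],
        weight_mapper: Callable[[str], str] = lambda name: name,
    ):
        # Wrap any (name, tensor) iterator and rename tensors on the fly,
        # mirroring the yield-based _hf_weight_iter above.
        for name, param in weights:
            yield weight_mapper(name), param

    # Hypothetical usage: strip a "model." prefix from every checkpoint key.
    weights = [("model.embed_tokens.weight", torch.zeros(4, 4))]
    renamed = dict(remapping_iter(weights, lambda n: n.removeprefix("model.")))
    assert "embed_tokens.weight" in renamed
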
@@ -782,12 +788,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         try:
             import bitsandbytes
 
-            if bitsandbytes.__version__ < "0.44.0":
+            if bitsandbytes.__version__ < "0.45.0":
                 raise ImportError("bitsandbytes version is wrong. Please "
-                                  "install bitsandbytes>=0.44.0.")
+                                  "install bitsandbytes>=0.45.0.")
         except ImportError as err:
-            raise ImportError("Please install bitsandbytes>=0.44.0 via "
-                              "`pip install bitsandbytes>=0.44.0` to use "
+            raise ImportError("Please install bitsandbytes>=0.45.0 via "
+                              "`pip install bitsandbytes>=0.45.0` to use "
                               "bitsandbytes quantizer.") from err
 
         hf_weights_files, use_safetensors = self._prepare_weights(
@@ -991,7 +997,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             if isinstance(module, (LinearBase, )):
                 last_name = name.split(".")[-1]
                 if sub_modules := inverse_stacked_mapping.get(last_name, []):
-                    # Map vllm's names to transformers' names.
+                    # Map vllm's names to transformers's names.
                     for sub_name in sub_modules:
                         self.target_modules.append(
                             name.replace(last_name, sub_name))
@@ -1013,6 +1019,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 f"Model {type(model).__name__} does not support BitsAndBytes "
                 "quantization yet.")
 
+        # For some models like Molmo, we need to use hf_to_vllm_mapper
+        # to ensure correct loading of weights.
+        if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
+            self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
         # Modules whose weights might have fused on disk
         # we need their output_sizes to make shard in flight correctly with TP
         self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
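
For context, the only contract the loader relies on here is that the model exposes an hf_to_vllm_mapper object with a _map_name(name) method (Molmo defines one as a class attribute; vLLM's WeightsMapper helper has this shape). A toy stand-in, purely illustrative and not the real class or Molmo's actual mapping:

    class ToyWeightsMapper:
        """Illustrative stand-in for getattr(model, "hf_to_vllm_mapper")."""

        def __init__(self, orig_to_new_prefix: dict):
            self.orig_to_new_prefix = orig_to_new_prefix

        def _map_name(self, name: str) -> str:
            # Rewrite a known checkpoint prefix; pass unknown names through.
            for old, new in self.orig_to_new_prefix.items():
                if name.startswith(old):
                    return new + name[len(old):]
            return name

    # Hypothetical prefix rewrite, not Molmo's actual mapping.
    mapper = ToyWeightsMapper({"model.transformer.": "model."})
    weight_mapper = lambda name: mapper._map_name(name)
    assert weight_mapper("model.transformer.ff_out.weight") == "model.ff_out.weight"
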
@@ -461,30 +461,71 @@ class MolmoAttention(nn.Module):
         return output
 
 
-class MolmoMLP(nn.Module):
+class SwiGLU(nn.Module):
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x, gate = x.chunk(2, dim=-1)
+        # Note that the order is reversed compared to
+        # SiluAndMul.
+        return x * F.silu(gate)
+
+
+class LanuageModelMLP(nn.Module):
     """Molmo's LLM mlp."""
 
     def __init__(self,
                  config: PretrainedConfig,
                  input_dim: Optional[int] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 proj_name: str = "gate_up_proj") -> None:
+                 quant_config: Optional[QuantizationConfig] = None) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size // 2
 
-        # Molmo's LLM proj weights are already merged into the disk, while
-        # image_projector proj is separate. If the same proj_name were used, it
-        # would create ambiguity and make it difficult to support BNB and LoRA.
-        self.proj_name = proj_name
-        setattr(
-            self, proj_name,
-            MergedColumnParallelLinear(
-                input_dim or self.hidden_size,
-                [self.intermediate_size] * 2,
-                bias=False,
-                quant_config=quant_config,
-            ))
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_dim or self.hidden_size,
+            [self.intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+        )
+        # Activation function.
+        self.act_fn = SwiGLU()
+        # Feed-forward output projection.
+        self.down_proj = RowParallelLinear(
+            self.intermediate_size,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class ImageProjectorMLP(nn.Module):
+    """Molmo's image_projector mlp."""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        input_dim: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size // 2
+
+        self.merged_linear = MergedColumnParallelLinear(
+            input_dim or self.hidden_size,
+            [self.intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+        )
         # Activation function.
         self.act_fn = SiluAndMul()
 
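
The "order is reversed compared to SiluAndMul" comment is the crux of the new SwiGLU class: assuming vLLM's SiluAndMul follows its usual silu(first half) * second half convention, the gate half is expected first, while Molmo's merged gate_up_proj stores the up half first. A small self-contained check of that ordering difference (helper names here are illustrative, using only torch):

    import torch
    import torch.nn.functional as F

    def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
        # SiluAndMul-style convention: silu(gate) * up, gate stored first.
        gate, up = x.chunk(2, dim=-1)
        return F.silu(gate) * up

    def swiglu(x: torch.Tensor) -> torch.Tensor:
        # The SwiGLU added above: up * silu(gate), up stored first (Molmo's layout).
        up, gate = x.chunk(2, dim=-1)
        return up * F.silu(gate)

    x = torch.randn(2, 8)
    swapped = torch.cat([x[:, 4:], x[:, :4]], dim=-1)  # swap the two halves
    # The same result is only obtained after swapping halves, which is why the
    # old load_weights re-concatenated [gate, up] and the new code no longer does.
    assert torch.allclose(swiglu(x), silu_and_mul(swapped))
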
@@ -500,7 +541,7 @@ class MolmoMLP(nn.Module):
         self,
         x: torch.Tensor,
     ) -> torch.Tensor:
-        gate_up, _ = getattr(self, self.proj_name)(x)
+        gate_up, _ = self.merged_linear(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         return x
@@ -523,9 +564,7 @@ class MolmoDecoderLayer(nn.Module):
             prefix=f"{prefix}.self_attn")
 
         # MLP block.
-        self.mlp = MolmoMLP(config,
-                            quant_config=quant_config,
-                            proj_name="gate_up_proj")
+        self.mlp = LanuageModelMLP(config, quant_config=quant_config)
 
         # LayerNorm
         assert config.layer_norm_type == "rms"
@@ -617,11 +656,10 @@ class MolmoVisionBackbone(nn.Module):
             vision_config,
             nlayers=len(self.vit_layers),
             quant_config=quant_config)
-        self.image_projector = MolmoMLP(
+        self.image_projector = ImageProjectorMLP(
             config,
             input_dim=vision_config.image_emb_dim,
             quant_config=quant_config,
-            proj_name="merged_linear",
         )
 
         image_dim = vision_config.image_emb_dim * len(self.vit_layers)
@@ -842,10 +880,6 @@ class MolmoModel(nn.Module):
         loaded_params: Set[str] = set()
 
         for name, loaded_weight in weights:
-            if "gate_up_proj" in name:
-                up_proj, gate_proj = loaded_weight.chunk(2, dim=0)
-                loaded_weight = torch.cat([gate_proj, up_proj], dim=0)
-
             if name.endswith(".bias") and name not in params_dict:
                 continue
             if is_pp_missing_parameter(name, self):
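
This removal pairs with the SwiGLU change above: the old code reordered the merged weight's rows once at load time, whereas the new code reorders the projection's output halves at run time. The two are equivalent because swapping the row blocks of a stacked weight swaps the corresponding halves of the output. A small sketch of that equivalence (all names and sizes illustrative):

    import torch

    hidden, inter = 8, 6
    w_up = torch.randn(inter, hidden)
    w_gate = torch.randn(inter, hidden)
    x = torch.randn(2, hidden)

    # The checkpoint stores [up; gate] stacked along dim 0, per the removed
    # chunk(2, dim=0) logic; the old load path rebuilt [gate; up] instead.
    w_disk = torch.cat([w_up, w_gate], dim=0)
    w_swapped = torch.cat([w_gate, w_up], dim=0)

    # Swapping the weight's row blocks is the same as swapping the output
    # halves, which is exactly what SwiGLU's reversed chunk order absorbs.
    up1, gate1 = (x @ w_disk.t()).chunk(2, dim=-1)
    gate2, up2 = (x @ w_swapped.t()).chunk(2, dim=-1)
    assert torch.allclose(up1, up2) and torch.allclose(gate1, gate2)
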
@@ -1157,6 +1191,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         },
     )
 
+    # BitandBytes specific attributes
+    bitsandbytes_stacked_params_mapping = {
+        "gate_proj": ("merged_linear", 0),
+        "up_proj": ("merged_linear", 1),
+    }
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
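
The new class attribute tells the BNB loader how unfused checkpoint shard names relate to vLLM's fused parameter: each entry maps an unfused name to (fused parameter name, shard index), so quantized gate_proj and up_proj tensors land in the first and second halves of merged_linear. A rough illustration of the slice bookkeeping such a mapping drives (the helper below is hypothetical, not the loader's code):

    from typing import Dict, List, Tuple

    bitsandbytes_stacked_params_mapping: Dict[str, Tuple[str, int]] = {
        "gate_proj": ("merged_linear", 0),
        "up_proj": ("merged_linear", 1),
    }

    def shard_offsets(output_sizes: List[int], shard_id: int) -> Tuple[int, int]:
        # Return the (start, end) rows of one shard inside the fused output dim.
        start = sum(output_sizes[:shard_id])
        return start, start + output_sizes[shard_id]

    # Hypothetical fused layer with two equal shards of 6 rows each.
    fused_name, shard_id = bitsandbytes_stacked_params_mapping["up_proj"]
    assert fused_name == "merged_linear"
    assert shard_offsets([6, 6], shard_id) == (6, 12)
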