[Model] Add TP and BNB quantization support to LlavaMultiModalProjector (#10834)

Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Isotr0py 2024-12-03 07:06:09 +08:00 committed by GitHub
parent 9b14d978aa
commit 4c05edb33a
2 changed files with 34 additions and 15 deletions


@@ -1120,7 +1120,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                     model_config.revision,
                     pre_quant, load_8bit))
 
-        model.load_weights(qweight_iterator)
+        weights_to_load = {name for name, _ in model.named_parameters()}
+        loaded_weights = model.load_weights(qweight_iterator)
+        # Some models may have weights loading tracker unimplemented.
+        if loaded_weights is not None:
+            weights_not_loaded = weights_to_load - loaded_weights
+            if weights_not_loaded:
+                raise ValueError("Following weights were not initialized from "
+                                 f"checkpoint: {weights_not_loaded}")
 
         torch.cuda.empty_cache()
 
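For context, a minimal self-contained sketch (not vLLM's actual loader) of the contract this hunk relies on: load_weights may return the set of parameter names it actually loaded, and the loader compares that set against model.named_parameters(); a model whose load_weights has no tracker returns None and the check is skipped. The ToyModel below is hypothetical.

    from typing import Iterable, Optional, Set, Tuple

    import torch
    import torch.nn as nn


    class ToyModel(nn.Module):
        """Hypothetical model whose load_weights reports what it loaded."""

        def __init__(self) -> None:
            super().__init__()
            self.proj = nn.Linear(4, 4)

        def load_weights(
                self, weights: Iterable[Tuple[str, torch.Tensor]]
        ) -> Optional[Set[str]]:
            params = dict(self.named_parameters())
            loaded: Set[str] = set()
            for name, tensor in weights:
                if name in params:
                    params[name].data.copy_(tensor)
                    loaded.add(name)
            return loaded  # a model without a tracker would return None


    model = ToyModel()
    weights_to_load = {name for name, _ in model.named_parameters()}
    loaded_weights = model.load_weights([("proj.weight", torch.zeros(4, 4)),
                                         ("proj.bias", torch.zeros(4))])
    # Mirrors the check added above: only enforce completeness when the
    # model actually reports which weights it loaded.
    if loaded_weights is not None:
        weights_not_loaded = weights_to_load - loaded_weights
        if weights_not_loaded:
            raise ValueError("Following weights were not initialized from "
                             f"checkpoint: {weights_not_loaded}")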
@@ -1152,9 +1159,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                         shard_name, weight_name)
                     break
 
+            # Models like Clip/Siglip may skip some layers in initialization,
+            # causing unused quant_param_name in state_dict.
             if quant_param_name not in param_dict:
-                raise ValueError(
-                    f"Parameter {quant_param_name} not found in the model.")
+                continue
 
             if quant_param_name not in stacked_quant_state_dict:
                 stacked_quant_state_dict[quant_param_name] = {}
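
To illustrate why the hard error becomes a continue: vision towers such as CLIP/SigLIP can be instantiated with fewer layers than the checkpoint contains, so a quantized checkpoint entry may have no matching parameter. A hypothetical sketch of the skip, with made-up parameter names:

    # Hypothetical names; the point is that quant states whose target
    # parameter was never created are skipped rather than treated as fatal.
    param_dict = {"vision_tower.layers.0.mlp.fc1.weight": None}
    stacked_quant_state_dict = {}

    checkpoint_quant_params = [
        "vision_tower.layers.0.mlp.fc1.weight",
        "vision_tower.layers.26.mlp.fc1.weight",  # layer not instantiated
    ]
    for quant_param_name in checkpoint_quant_params:
        if quant_param_name not in param_dict:
            continue  # skipped layer: its quant state is simply unused
        stacked_quant_state_dict.setdefault(quant_param_name, {})

    assert list(stacked_quant_state_dict) == [
        "vision_tower.layers.0.mlp.fc1.weight"
    ]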


@@ -13,6 +13,8 @@ from vllm.config import VllmConfig
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext)
 from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -59,25 +61,32 @@ class LlavaImageEmbeddingInputs(TypedDict):
 LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
 
 
-# TODO(xwjiang): Run benchmark and decide if TP.
 class LlavaMultiModalProjector(nn.Module):
 
-    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
-                 projector_hidden_act: str):
+    def __init__(self,
+                 vision_hidden_size: int,
+                 text_hidden_size: int,
+                 projector_hidden_act: str,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__()
 
-        self.linear_1 = nn.Linear(vision_hidden_size,
-                                  text_hidden_size,
-                                  bias=True)
+        self.linear_1 = ColumnParallelLinear(vision_hidden_size,
+                                             text_hidden_size,
+                                             bias=True,
+                                             quant_config=quant_config,
+                                             prefix=f"{prefix}.linear_1")
         self.act = get_act_fn(projector_hidden_act)
-        self.linear_2 = nn.Linear(text_hidden_size,
-                                  text_hidden_size,
-                                  bias=True)
+        self.linear_2 = RowParallelLinear(text_hidden_size,
+                                          text_hidden_size,
+                                          bias=True,
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.linear_2")
 
     def forward(self, image_features: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.linear_1(image_features)
+        hidden_states, _ = self.linear_1(image_features)
         hidden_states = self.act(hidden_states)
-        hidden_states = self.linear_2(hidden_states)
+        hidden_states, _ = self.linear_2(hidden_states)
         return hidden_states
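
Two things change in forward: the parallel linear layers return a tuple of (output, optional bias), hence the `, _` unpacking, and the column-then-row sharding needs no communication between the two projections because the elementwise activation can be applied shard-by-shard. A standalone sketch of that second point in plain PyTorch, with no distributed setup; GELU stands in for projector_hidden_act and bias is omitted, so this is an illustration rather than vLLM's implementation:

    # Each "rank" holds a column shard of W1 and the matching row shard of
    # W2; summing the per-rank outputs (the stand-in for the all-reduce done
    # by RowParallelLinear) reproduces the dense projector exactly.
    import torch

    torch.manual_seed(0)
    tp_size, vision_hidden, text_hidden = 2, 8, 6
    x = torch.randn(3, vision_hidden)             # image features
    w1 = torch.randn(text_hidden, vision_hidden)  # linear_1 weight
    w2 = torch.randn(text_hidden, text_hidden)    # linear_2 weight

    # Dense reference: projection -> activation -> projection.
    ref = torch.nn.functional.gelu(x @ w1.t()) @ w2.t()

    # Sharded computation: w1 split by output rows (column parallel), w2
    # split by input columns (row parallel).
    out = torch.zeros(3, text_hidden)
    for rank in range(tp_size):
        w1_shard = w1.chunk(tp_size, dim=0)[rank]
        w2_shard = w2.chunk(tp_size, dim=1)[rank]
        out += torch.nn.functional.gelu(x @ w1_shard.t()) @ w2_shard.t()

    assert torch.allclose(ref, out, atol=1e-5)
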
@@ -325,7 +334,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         self.multi_modal_projector = LlavaMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
             text_hidden_size=config.text_config.hidden_size,
-            projector_hidden_act=config.projector_hidden_act)
+            projector_hidden_act=config.projector_hidden_act,
+            quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "multi_modal_projector"))
 
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
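
The prefix wiring is what lets quantization see the projector: its sublayers get fully qualified names (multi_modal_projector.linear_1, .linear_2) that checkpoint entries can be matched against by name. A small sketch of the prefix chain; the maybe_prefix stand-in below is assumed to join a non-empty parent prefix with a dot, matching how it is used above:

    def maybe_prefix(prefix: str, name: str) -> str:
        # Stand-in with assumed behavior: prepend the parent prefix only
        # when one is set.
        return name if not prefix else f"{prefix}.{name}"


    projector_prefix = maybe_prefix("", "multi_modal_projector")
    assert f"{projector_prefix}.linear_1" == "multi_modal_projector.linear_1"
    assert f"{projector_prefix}.linear_2" == "multi_modal_projector.linear_2"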