[Model] Add TP and BNB quantization support to LlavaMultiModalProjector (#10834)
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
commit 4c05edb33a
parent 9b14d978aa
@@ -1120,7 +1120,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                                                  model_config.revision,
                                                  pre_quant, load_8bit))

-        model.load_weights(qweight_iterator)
+        weights_to_load = {name for name, _ in model.named_parameters()}
+        loaded_weights = model.load_weights(qweight_iterator)
+        # Some models may have weights loading tracker unimplemented.
+        if loaded_weights is not None:
+            weights_not_loaded = weights_to_load - loaded_weights
+            if weights_not_loaded:
+                raise ValueError("Following weights were not initialized from "
+                                 f"checkpoint: {weights_not_loaded}")

         torch.cuda.empty_cache()
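The new bookkeeping relies on models opting into a loading tracker: load_weights returns the set of parameter names it actually initialized, or None when a model has not implemented the tracker yet. A minimal sketch of that contract with a hypothetical ToyModel (names and shapes are made up, not vLLM code):

from typing import Iterable, Optional, Set, Tuple

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical model implementing the loading-tracker contract."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def load_weights(
            self, weights: Iterable[Tuple[str, torch.Tensor]]
    ) -> Optional[Set[str]]:
        params = dict(self.named_parameters())
        loaded: Set[str] = set()
        for name, tensor in weights:
            if name in params:
                params[name].data.copy_(tensor)
                loaded.add(name)
        # A model without the tracker would return None here instead.
        return loaded


model = ToyModel()
weights_to_load = {name for name, _ in model.named_parameters()}
# Deliberately omit proj.bias so the loader-side check fires.
loaded = model.load_weights([("proj.weight", torch.zeros(4, 4))])
if loaded is not None:
    missing = weights_to_load - loaded
    if missing:
        raise ValueError(f"Following weights were not initialized "
                         f"from checkpoint: {missing}")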
@@ -1152,9 +1159,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                         shard_name, weight_name)
                     break

+            # Models like Clip/Siglip may skip some layers in initialization,
+            # causing unused quant_param_name in state_dict.
             if quant_param_name not in param_dict:
-                raise ValueError(
-                    f"Parameter {quant_param_name} not found in the model.")
+                continue

             if quant_param_name not in stacked_quant_state_dict:
                 stacked_quant_state_dict[quant_param_name] = {}
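In toy form, the skip above (all names hypothetical): a vision tower instantiated with fewer layers than the checkpoint carries leaves some quantized parameter names with no matching module parameter, and those entries are now dropped instead of raising:

param_dict = {"vision_tower.layers.0.fc1.weight": "param"}
quant_states = {
    "vision_tower.layers.0.fc1.weight": {"blocksize": 64},
    "vision_tower.layers.26.fc1.weight": {"blocksize": 64},  # pruned at init
}

stacked_quant_state_dict = {}
for quant_param_name, quant_state in quant_states.items():
    if quant_param_name not in param_dict:
        continue  # layer was never instantiated; nothing to attach to
    stacked_quant_state_dict.setdefault(quant_param_name, {})[0] = quant_state

print(sorted(stacked_quant_state_dict))  # only the instantiated layer survives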
@@ -13,6 +13,8 @@ from vllm.config import VllmConfig
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext)
 from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -59,25 +61,32 @@ class LlavaImageEmbeddingInputs(TypedDict):
 LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]


-# TODO(xwjiang): Run benchmark and decide if TP.
 class LlavaMultiModalProjector(nn.Module):

-    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
-                 projector_hidden_act: str):
+    def __init__(self,
+                 vision_hidden_size: int,
+                 text_hidden_size: int,
+                 projector_hidden_act: str,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__()

-        self.linear_1 = nn.Linear(vision_hidden_size,
-                                  text_hidden_size,
-                                  bias=True)
+        self.linear_1 = ColumnParallelLinear(vision_hidden_size,
+                                             text_hidden_size,
+                                             bias=True,
+                                             quant_config=quant_config,
+                                             prefix=f"{prefix}.linear_1")
         self.act = get_act_fn(projector_hidden_act)
-        self.linear_2 = nn.Linear(text_hidden_size,
-                                  text_hidden_size,
-                                  bias=True)
+        self.linear_2 = RowParallelLinear(text_hidden_size,
+                                          text_hidden_size,
+                                          bias=True,
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.linear_2")

     def forward(self, image_features: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.linear_1(image_features)
+        hidden_states, _ = self.linear_1(image_features)
         hidden_states = self.act(hidden_states)
-        hidden_states = self.linear_2(hidden_states)
+        hidden_states, _ = self.linear_2(hidden_states)
         return hidden_states
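The column-then-row pairing is the standard Megatron-style MLP sharding: linear_1 splits its output features across tensor-parallel ranks, the activation runs elementwise on each local shard, and linear_2 splits its input features and all-reduces, so the projector pays for a single communication per forward pass. Both parallel layers also return an (output, output_bias) tuple rather than a bare tensor, which is why forward now unpacks. A toy stand-in for that return convention (TupleReturnLinear is hypothetical, not vLLM code):

import torch
import torch.nn as nn


class TupleReturnLinear(nn.Module):
    """Mimics vLLM's parallel linear layers: forward returns a 2-tuple."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.inner = nn.Linear(in_features, out_features, bias=True)

    def forward(self, x: torch.Tensor):
        # The bias is already folded into the output here, so the second
        # slot is None (vLLM returns it separately only with skip_bias_add).
        return self.inner(x), None


layer = TupleReturnLinear(8, 16)
hidden_states, _ = layer(torch.randn(2, 8))  # the unpacking pattern above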
@@ -325,7 +334,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         self.multi_modal_projector = LlavaMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
             text_hidden_size=config.text_config.hidden_size,
-            projector_hidden_act=config.projector_hidden_act)
+            projector_hidden_act=config.projector_hidden_act,
+            quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "multi_modal_projector"))

         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
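The prefix matters because quantization configs and BNB quant states are keyed by fully qualified parameter names (e.g. multi_modal_projector.linear_1). A sketch of the helper's expected join-with-a-dot behavior (an assumption about its semantics, not necessarily vLLM's exact implementation):

def maybe_prefix(prefix: str, name: str) -> str:
    # Join under the parent prefix when one exists, else use the bare name.
    return f"{prefix}.{name}" if prefix else name


assert maybe_prefix("", "multi_modal_projector") == "multi_modal_projector"
assert (maybe_prefix("language_model", "multi_modal_projector")
        == "language_model.multi_modal_projector")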