mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-28 05:33:44 +08:00
[Misc] Add BNB quantization for Whisper (#12381)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
c36ac98d01
commit
96b23621c1
@ -803,9 +803,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
iterator = safetensors_weights_iterator(hf_weights_files)
|
||||
else:
|
||||
iterator = pt_weights_iterator(hf_weights_files)
|
||||
for name, param in iterator:
|
||||
# mapping weight names from transformers to vllm.
|
||||
yield self.weight_mapper(name), param
|
||||
for org_name, param in iterator:
|
||||
# mapping weight names from transformers to vllm while preserving
|
||||
# original names.
|
||||
mapped_name = self.weight_mapper(org_name)
|
||||
yield org_name, mapped_name, param
|
||||
|
||||
def _get_quantized_weights_iterator(
|
||||
self,
|
||||
@ -866,24 +868,30 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
|
||||
def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
|
||||
quant_state_dict) -> Generator:
|
||||
for weight_name, weight_tensor in self._hf_weight_iter(
|
||||
hf_weights_files, use_safetensors):
|
||||
if not weight_name.lower().endswith(".scb"):
|
||||
for (
|
||||
org_weight_name,
|
||||
mapped_weight_name,
|
||||
weight_tensor,
|
||||
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
|
||||
if not mapped_weight_name.lower().endswith(".scb"):
|
||||
continue
|
||||
|
||||
weight_key = weight_name.lower().replace(".scb", ".weight")
|
||||
weight_key = mapped_weight_name.lower().replace(".scb", ".weight")
|
||||
quant_state_dict[weight_key] = weight_tensor
|
||||
|
||||
for weight_name, weight_tensor in self._hf_weight_iter(
|
||||
hf_weights_files, use_safetensors):
|
||||
if self._is_8bit_weight_name(weight_name):
|
||||
for (
|
||||
org_weight_name,
|
||||
mapped_weight_name,
|
||||
weight_tensor,
|
||||
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
|
||||
if self._is_8bit_weight_name(mapped_weight_name):
|
||||
continue
|
||||
|
||||
if weight_name in quant_state_dict:
|
||||
if mapped_weight_name in quant_state_dict:
|
||||
set_weight_attrs(weight_tensor, {"load_in_8bit": True})
|
||||
yield weight_name, weight_tensor
|
||||
yield org_weight_name, weight_tensor
|
||||
else:
|
||||
yield weight_name, weight_tensor
|
||||
yield org_weight_name, weight_tensor
|
||||
|
||||
def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
|
||||
quant_state_dict) -> Generator:
|
||||
@ -893,15 +901,19 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
weight_iterator = self._hf_weight_iter(hf_weights_files,
|
||||
use_safetensors)
|
||||
temp_state_dict = {}
|
||||
for weight_name, weight_tensor in weight_iterator:
|
||||
if not self._is_4bit_weight_name(weight_name):
|
||||
for (
|
||||
org_weight_name,
|
||||
mapped_weight_name,
|
||||
weight_tensor,
|
||||
) in weight_iterator:
|
||||
if not self._is_4bit_weight_name(mapped_weight_name):
|
||||
continue
|
||||
# bitsandbytes library requires
|
||||
# weight.quant_state.bitsandbytes__* in CPU
|
||||
if "quant_state.bitsandbytes" in weight_name:
|
||||
temp_state_dict[weight_name] = weight_tensor.cpu().data
|
||||
if "quant_state.bitsandbytes" in mapped_weight_name:
|
||||
temp_state_dict[mapped_weight_name] = weight_tensor.cpu().data
|
||||
else:
|
||||
temp_state_dict[weight_name] = weight_tensor
|
||||
temp_state_dict[mapped_weight_name] = weight_tensor
|
||||
|
||||
# Closure to parse quant_state for each prequant weight
|
||||
def _parse_quant_state(param_name: str,
|
||||
@ -915,20 +927,24 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
|
||||
# Second iterate over all prequant and normal weights
|
||||
# pre quantized weights would have a quant_state
|
||||
for weight_name, weight_tensor in self._hf_weight_iter(
|
||||
hf_weights_files, use_safetensors):
|
||||
if self._is_4bit_weight_name(weight_name):
|
||||
for (
|
||||
org_weight_name,
|
||||
mapped_weight_name,
|
||||
weight_tensor,
|
||||
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
|
||||
if self._is_4bit_weight_name(mapped_weight_name):
|
||||
continue
|
||||
|
||||
if (f"{weight_name}.quant_state.bitsandbytes__nf4"
|
||||
if (f"{mapped_weight_name}.quant_state.bitsandbytes__nf4"
|
||||
in temp_state_dict) or (
|
||||
f"{weight_name}.quant_state.bitsandbytes__fp4"
|
||||
f"{mapped_weight_name}.quant_state.bitsandbytes__fp4"
|
||||
in temp_state_dict):
|
||||
quant_state = _parse_quant_state(weight_name, temp_state_dict)
|
||||
quant_state_dict[weight_name] = quant_state
|
||||
yield weight_name, weight_tensor
|
||||
quant_state = _parse_quant_state(mapped_weight_name,
|
||||
temp_state_dict)
|
||||
quant_state_dict[mapped_weight_name] = quant_state
|
||||
yield org_weight_name, weight_tensor
|
||||
else:
|
||||
yield weight_name, weight_tensor
|
||||
yield org_weight_name, weight_tensor
|
||||
|
||||
def _unquantized_generator(self, hf_weights_files, use_safetensors,
|
||||
quant_state_dict) -> Generator:
|
||||
@ -937,18 +953,22 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
for weight_name, weight_tensor in self._hf_weight_iter(
|
||||
hf_weights_files, use_safetensors):
|
||||
if any(target_module in weight_name for target_module in
|
||||
self.target_modules) and weight_name.endswith(".weight"):
|
||||
for (
|
||||
org_weight_name,
|
||||
mapped_weight_name,
|
||||
weight_tensor,
|
||||
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
|
||||
if any(target_module in mapped_weight_name
|
||||
for target_module in self.target_modules
|
||||
) and mapped_weight_name.endswith(".weight"):
|
||||
# Without sharding
|
||||
if any(
|
||||
weight_name.startswith(module)
|
||||
mapped_weight_name.startswith(module)
|
||||
for module in self.unsharded_weights_modules):
|
||||
weight_sub_tensor = weight_tensor
|
||||
# Shard by column
|
||||
elif any(
|
||||
weight_name.startswith(module)
|
||||
mapped_weight_name.startswith(module)
|
||||
for module in self.column_sharded_weights_modules):
|
||||
total_size = weight_tensor.size(-1)
|
||||
start_index = total_size // tp_size * tp_rank
|
||||
@ -958,14 +978,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
# Weights have fused on disk. In this case, we assume that the
|
||||
# weight and module use same name.
|
||||
elif any(
|
||||
weight_name.startswith(module)
|
||||
mapped_weight_name.startswith(module)
|
||||
for module in self.maybe_fused_weights_modules):
|
||||
# special case for fused weights
|
||||
# get the size of each shard weight tensor
|
||||
total_shard_sizes = next(
|
||||
(sizes for module, sizes in
|
||||
self.maybe_fused_weights_modules.items()
|
||||
if weight_name.startswith(module)))
|
||||
if mapped_weight_name.startswith(module)))
|
||||
total_size = weight_tensor.size(0)
|
||||
assert total_size == sum(total_shard_sizes)
|
||||
# get the start/end index of each shard weight tensor
|
||||
@ -1008,23 +1028,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
quant_type="nf4",
|
||||
)
|
||||
|
||||
quant_state_dict[weight_name] = quant_state
|
||||
quant_state_dict[mapped_weight_name] = quant_state
|
||||
else:
|
||||
processed_weight = weight_tensor
|
||||
|
||||
yield weight_name, processed_weight
|
||||
yield org_weight_name, processed_weight
|
||||
|
||||
def _get_bnb_target_modules(self, model: nn.Module) -> None:
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, (LinearBase, )):
|
||||
last_name = name.split(".")[-1]
|
||||
if sub_modules := self.modules_mapping.packed_mapping.get(
|
||||
last_name, []):
|
||||
if modules_info := self.modules_mapping.get_sub_modules(name):
|
||||
# Map vllm's names to transformers's names.
|
||||
rep_name, sub_modules = modules_info
|
||||
for sub_name in sub_modules:
|
||||
self.target_modules.append(
|
||||
name.replace(last_name, sub_name))
|
||||
name.replace(rep_name, sub_name))
|
||||
# Add original module name even if the module has stacked map,
|
||||
# in case model has a mixture of disk-merged and disk-splitted
|
||||
# weights with same last name.
|
||||
|
||||
@ -131,3 +131,10 @@ class ParamMapping:
|
||||
packed_name,
|
||||
index,
|
||||
)
|
||||
|
||||
def get_sub_modules(self,
|
||||
module_name: str) -> Optional[Tuple[str, List[str]]]:
|
||||
for key, value in self.packed_mapping.items():
|
||||
if module_name.endswith(key):
|
||||
return key, value
|
||||
return None
|
||||
|
||||
@ -638,6 +638,19 @@ def input_mapper_for_whisper(
|
||||
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
|
||||
"audio", get_max_whisper_audio_tokens)
|
||||
class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
packed_modules_mapping = {
|
||||
"self_attn.qkv_proj": [
|
||||
"self_attn.q_proj",
|
||||
"self_attn.k_proj",
|
||||
"self_attn.v_proj",
|
||||
],
|
||||
"encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
|
||||
}
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={
|
||||
".fc1.": ".mlp.fc1.",
|
||||
".fc2.": ".mlp.fc2."
|
||||
})
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
@ -731,10 +744,10 @@ class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
|
||||
mapper = WeightsMapper({".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."})
|
||||
|
||||
# add fake zeros bias for k_proj to state_dict
|
||||
weights = _create_fake_bias_for_k_proj(weights)
|
||||
return loader.load_weights(weights, mapper=mapper)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
|
||||
def _create_fake_bias_for_k_proj(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user