[Misc] Add BNB quantization for Whisper (#12381)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2025-02-04 16:27:36 +08:00 committed by GitHub
parent c36ac98d01
commit 96b23621c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 82 additions and 44 deletions

View File

@ -803,9 +803,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
iterator = safetensors_weights_iterator(hf_weights_files)
else:
iterator = pt_weights_iterator(hf_weights_files)
for name, param in iterator:
# mapping weight names from transformers to vllm.
yield self.weight_mapper(name), param
for org_name, param in iterator:
# mapping weight names from transformers to vllm while preserving
# original names.
mapped_name = self.weight_mapper(org_name)
yield org_name, mapped_name, param
def _get_quantized_weights_iterator(
self,
@ -866,24 +868,30 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
quant_state_dict) -> Generator:
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
if not weight_name.lower().endswith(".scb"):
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
if not mapped_weight_name.lower().endswith(".scb"):
continue
weight_key = weight_name.lower().replace(".scb", ".weight")
weight_key = mapped_weight_name.lower().replace(".scb", ".weight")
quant_state_dict[weight_key] = weight_tensor
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
if self._is_8bit_weight_name(weight_name):
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
if self._is_8bit_weight_name(mapped_weight_name):
continue
if weight_name in quant_state_dict:
if mapped_weight_name in quant_state_dict:
set_weight_attrs(weight_tensor, {"load_in_8bit": True})
yield weight_name, weight_tensor
yield org_weight_name, weight_tensor
else:
yield weight_name, weight_tensor
yield org_weight_name, weight_tensor
def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
quant_state_dict) -> Generator:
@ -893,15 +901,19 @@ class BitsAndBytesModelLoader(BaseModelLoader):
weight_iterator = self._hf_weight_iter(hf_weights_files,
use_safetensors)
temp_state_dict = {}
for weight_name, weight_tensor in weight_iterator:
if not self._is_4bit_weight_name(weight_name):
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in weight_iterator:
if not self._is_4bit_weight_name(mapped_weight_name):
continue
# bitsandbytes library requires
# weight.quant_state.bitsandbytes__* in CPU
if "quant_state.bitsandbytes" in weight_name:
temp_state_dict[weight_name] = weight_tensor.cpu().data
if "quant_state.bitsandbytes" in mapped_weight_name:
temp_state_dict[mapped_weight_name] = weight_tensor.cpu().data
else:
temp_state_dict[weight_name] = weight_tensor
temp_state_dict[mapped_weight_name] = weight_tensor
# Closure to parse quant_state for each prequant weight
def _parse_quant_state(param_name: str,
@ -915,20 +927,24 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# Second iterate over all prequant and normal weights
# pre quantized weights would have a quant_state
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
if self._is_4bit_weight_name(weight_name):
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
if self._is_4bit_weight_name(mapped_weight_name):
continue
if (f"{weight_name}.quant_state.bitsandbytes__nf4"
if (f"{mapped_weight_name}.quant_state.bitsandbytes__nf4"
in temp_state_dict) or (
f"{weight_name}.quant_state.bitsandbytes__fp4"
f"{mapped_weight_name}.quant_state.bitsandbytes__fp4"
in temp_state_dict):
quant_state = _parse_quant_state(weight_name, temp_state_dict)
quant_state_dict[weight_name] = quant_state
yield weight_name, weight_tensor
quant_state = _parse_quant_state(mapped_weight_name,
temp_state_dict)
quant_state_dict[mapped_weight_name] = quant_state
yield org_weight_name, weight_tensor
else:
yield weight_name, weight_tensor
yield org_weight_name, weight_tensor
def _unquantized_generator(self, hf_weights_files, use_safetensors,
quant_state_dict) -> Generator:
@ -937,18 +953,22 @@ class BitsAndBytesModelLoader(BaseModelLoader):
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
if any(target_module in weight_name for target_module in
self.target_modules) and weight_name.endswith(".weight"):
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
if any(target_module in mapped_weight_name
for target_module in self.target_modules
) and mapped_weight_name.endswith(".weight"):
# Without sharding
if any(
weight_name.startswith(module)
mapped_weight_name.startswith(module)
for module in self.unsharded_weights_modules):
weight_sub_tensor = weight_tensor
# Shard by column
elif any(
weight_name.startswith(module)
mapped_weight_name.startswith(module)
for module in self.column_sharded_weights_modules):
total_size = weight_tensor.size(-1)
start_index = total_size // tp_size * tp_rank
@ -958,14 +978,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# Weights have fused on disk. In this case, we assume that the
# weight and module use same name.
elif any(
weight_name.startswith(module)
mapped_weight_name.startswith(module)
for module in self.maybe_fused_weights_modules):
# special case for fused weights
# get the size of each shard weight tensor
total_shard_sizes = next(
(sizes for module, sizes in
self.maybe_fused_weights_modules.items()
if weight_name.startswith(module)))
if mapped_weight_name.startswith(module)))
total_size = weight_tensor.size(0)
assert total_size == sum(total_shard_sizes)
# get the start/end index of each shard weight tensor
@ -1008,23 +1028,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
quant_type="nf4",
)
quant_state_dict[weight_name] = quant_state
quant_state_dict[mapped_weight_name] = quant_state
else:
processed_weight = weight_tensor
yield weight_name, processed_weight
yield org_weight_name, processed_weight
def _get_bnb_target_modules(self, model: nn.Module) -> None:
for name, module in model.named_modules():
if isinstance(module, (LinearBase, )):
last_name = name.split(".")[-1]
if sub_modules := self.modules_mapping.packed_mapping.get(
last_name, []):
if modules_info := self.modules_mapping.get_sub_modules(name):
# Map vllm's names to transformers's names.
rep_name, sub_modules = modules_info
for sub_name in sub_modules:
self.target_modules.append(
name.replace(last_name, sub_name))
name.replace(rep_name, sub_name))
# Add original module name even if the module has stacked map,
# in case model has a mixture of disk-merged and disk-splitted
# weights with same last name.

View File

@ -131,3 +131,10 @@ class ParamMapping:
packed_name,
index,
)
def get_sub_modules(self,
module_name: str) -> Optional[Tuple[str, List[str]]]:
for key, value in self.packed_mapping.items():
if module_name.endswith(key):
return key, value
return None

View File

@ -638,6 +638,19 @@ def input_mapper_for_whisper(
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"audio", get_max_whisper_audio_tokens)
class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):
packed_modules_mapping = {
"self_attn.qkv_proj": [
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
],
"encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
}
hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={
".fc1.": ".mlp.fc1.",
".fc2.": ".mlp.fc2."
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@ -731,10 +744,10 @@ class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
mapper = WeightsMapper({".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."})
# add fake zeros bias for k_proj to state_dict
weights = _create_fake_bias_for_k_proj(weights)
return loader.load_weights(weights, mapper=mapper)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def _create_fake_bias_for_k_proj(