diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 809af81d707a8..19e3bc6a259e9 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -803,9 +803,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             iterator = safetensors_weights_iterator(hf_weights_files)
         else:
             iterator = pt_weights_iterator(hf_weights_files)
-        for name, param in iterator:
-            # mapping weight names from transformers to vllm.
-            yield self.weight_mapper(name), param
+        for org_name, param in iterator:
+            # mapping weight names from transformers to vllm while preserving
+            # original names.
+            mapped_name = self.weight_mapper(org_name)
+            yield org_name, mapped_name, param
 
     def _get_quantized_weights_iterator(
         self,
@@ -866,24 +868,30 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
     def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
                                   quant_state_dict) -> Generator:
-        for weight_name, weight_tensor in self._hf_weight_iter(
-                hf_weights_files, use_safetensors):
-            if not weight_name.lower().endswith(".scb"):
+        for (
+                org_weight_name,
+                mapped_weight_name,
+                weight_tensor,
+        ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
+            if not mapped_weight_name.lower().endswith(".scb"):
                 continue
 
-            weight_key = weight_name.lower().replace(".scb", ".weight")
+            weight_key = mapped_weight_name.lower().replace(".scb", ".weight")
             quant_state_dict[weight_key] = weight_tensor
 
-        for weight_name, weight_tensor in self._hf_weight_iter(
-                hf_weights_files, use_safetensors):
-            if self._is_8bit_weight_name(weight_name):
+        for (
+                org_weight_name,
+                mapped_weight_name,
+                weight_tensor,
+        ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
+            if self._is_8bit_weight_name(mapped_weight_name):
                 continue
 
-            if weight_name in quant_state_dict:
+            if mapped_weight_name in quant_state_dict:
                 set_weight_attrs(weight_tensor, {"load_in_8bit": True})
-                yield weight_name, weight_tensor
+                yield org_weight_name, weight_tensor
             else:
-                yield weight_name, weight_tensor
+                yield org_weight_name, weight_tensor
 
     def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
                                   quant_state_dict) -> Generator:
@@ -893,15 +901,19 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         weight_iterator = self._hf_weight_iter(hf_weights_files,
                                                use_safetensors)
         temp_state_dict = {}
-        for weight_name, weight_tensor in weight_iterator:
-            if not self._is_4bit_weight_name(weight_name):
+        for (
+                org_weight_name,
+                mapped_weight_name,
+                weight_tensor,
+        ) in weight_iterator:
+            if not self._is_4bit_weight_name(mapped_weight_name):
                 continue
             # bitsandbytes library requires
             # weight.quant_state.bitsandbytes__* in CPU
-            if "quant_state.bitsandbytes" in weight_name:
-                temp_state_dict[weight_name] = weight_tensor.cpu().data
+            if "quant_state.bitsandbytes" in mapped_weight_name:
+                temp_state_dict[mapped_weight_name] = weight_tensor.cpu().data
             else:
-                temp_state_dict[weight_name] = weight_tensor
+                temp_state_dict[mapped_weight_name] = weight_tensor
 
         # Closure to parse quant_state for each prequant weight
         def _parse_quant_state(param_name: str,
@@ -915,20 +927,24 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
         # Second iterate over all prequant and normal weights
         # pre quantized weights would have a quant_state
-        for weight_name, weight_tensor in self._hf_weight_iter(
-                hf_weights_files, use_safetensors):
-            if self._is_4bit_weight_name(weight_name):
+        for (
+                org_weight_name,
+                mapped_weight_name,
+                weight_tensor,
+        ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
+            if self._is_4bit_weight_name(mapped_weight_name):
                 continue
 
-            if (f"{weight_name}.quant_state.bitsandbytes__nf4"
+            if (f"{mapped_weight_name}.quant_state.bitsandbytes__nf4"
                     in temp_state_dict) or (
-                        f"{weight_name}.quant_state.bitsandbytes__fp4"
+                        f"{mapped_weight_name}.quant_state.bitsandbytes__fp4"
                         in temp_state_dict):
-                quant_state = _parse_quant_state(weight_name, temp_state_dict)
-                quant_state_dict[weight_name] = quant_state
-                yield weight_name, weight_tensor
+                quant_state = _parse_quant_state(mapped_weight_name,
+                                                 temp_state_dict)
+                quant_state_dict[mapped_weight_name] = quant_state
+                yield org_weight_name, weight_tensor
             else:
-                yield weight_name, weight_tensor
+                yield org_weight_name, weight_tensor
 
     def _unquantized_generator(self, hf_weights_files, use_safetensors,
                                quant_state_dict) -> Generator:
@@ -937,18 +953,22 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         tp_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
 
-        for weight_name, weight_tensor in self._hf_weight_iter(
-                hf_weights_files, use_safetensors):
-            if any(target_module in weight_name for target_module in
-                   self.target_modules) and weight_name.endswith(".weight"):
+        for (
+                org_weight_name,
+                mapped_weight_name,
+                weight_tensor,
+        ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
+            if any(target_module in mapped_weight_name
+                   for target_module in self.target_modules
+                   ) and mapped_weight_name.endswith(".weight"):
                 # Without sharding
                 if any(
-                        weight_name.startswith(module)
+                        mapped_weight_name.startswith(module)
                         for module in self.unsharded_weights_modules):
                     weight_sub_tensor = weight_tensor
                 # Shard by column
                 elif any(
-                        weight_name.startswith(module)
+                        mapped_weight_name.startswith(module)
                         for module in self.column_sharded_weights_modules):
                     total_size = weight_tensor.size(-1)
                     start_index = total_size // tp_size * tp_rank
@@ -958,14 +978,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 # Weights have fused on disk. In this case, we assume that the
                 # weight and module use same name.
                 elif any(
-                        weight_name.startswith(module)
+                        mapped_weight_name.startswith(module)
                         for module in self.maybe_fused_weights_modules):
                     # special case for fused weights
                     # get the size of each shard weight tensor
                     total_shard_sizes = next(
                         (sizes for module, sizes in
                          self.maybe_fused_weights_modules.items()
-                         if weight_name.startswith(module)))
+                         if mapped_weight_name.startswith(module)))
                     total_size = weight_tensor.size(0)
                     assert total_size == sum(total_shard_sizes)
                     # get the start/end index of each shard weight tensor
@@ -1008,23 +1028,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                         quant_type="nf4",
                     )
 
-                quant_state_dict[weight_name] = quant_state
+                quant_state_dict[mapped_weight_name] = quant_state
             else:
                 processed_weight = weight_tensor
-
-            yield weight_name, processed_weight
+            yield org_weight_name, processed_weight
 
     def _get_bnb_target_modules(self, model: nn.Module) -> None:
 
         for name, module in model.named_modules():
             if isinstance(module, (LinearBase, )):
-                last_name = name.split(".")[-1]
-                if sub_modules := self.modules_mapping.packed_mapping.get(
-                        last_name, []):
+                if modules_info := self.modules_mapping.get_sub_modules(name):
                     # Map vllm's names to transformers's names.
+                    rep_name, sub_modules = modules_info
                     for sub_name in sub_modules:
                         self.target_modules.append(
-                            name.replace(last_name, sub_name))
+                            name.replace(rep_name, sub_name))
                 # Add original module name even if the module has stacked map,
                 # in case model has a mixture of disk-merged and disk-splitted
                 # weights with same last name.
diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
index eb334c1fdf255..7a82a695c5070 100644
--- a/vllm/model_executor/model_loader/utils.py
+++ b/vllm/model_executor/model_loader/utils.py
@@ -131,3 +131,10 @@ class ParamMapping:
             packed_name,
             index,
         )
+
+    def get_sub_modules(self,
+                        module_name: str) -> Optional[Tuple[str, List[str]]]:
+        for key, value in self.packed_mapping.items():
+            if module_name.endswith(key):
+                return key, value
+        return None
diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py
index 2319c31609308..0a3011d361013 100644
--- a/vllm/model_executor/models/whisper.py
+++ b/vllm/model_executor/models/whisper.py
@@ -638,6 +638,19 @@ def input_mapper_for_whisper(
 @MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
     "audio", get_max_whisper_audio_tokens)
 class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):
+    packed_modules_mapping = {
+        "self_attn.qkv_proj": [
+            "self_attn.q_proj",
+            "self_attn.k_proj",
+            "self_attn.v_proj",
+        ],
+        "encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
+    }
+
+    hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={
+        ".fc1.": ".mlp.fc1.",
+        ".fc2.": ".mlp.fc2."
+    })
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -731,10 +744,10 @@ class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
-        mapper = WeightsMapper({".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."})
+
         # add fake zeros bias for k_proj to state_dict
         weights = _create_fake_bias_for_k_proj(weights)
-        return loader.load_weights(weights, mapper=mapper)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
 
 def _create_fake_bias_for_k_proj(
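
The core contract change above is that `_hf_weight_iter` now yields `(org_name, mapped_name, tensor)` triples instead of a single pre-mapped name. Below is a minimal, self-contained sketch of that idea, not vLLM code: `hf_weight_iter`, `mapper`, and `raw` are hypothetical stand-ins. It shows why both names are needed, as in the patch: quantization state is keyed by the vLLM-internal (mapped) name, while the original checkpoint name is what flows onward to `load_weights`, which applies the model's own `hf_to_vllm_mapper` itself.

```python
from typing import Callable, Iterable, Iterator, Tuple

import torch


def hf_weight_iter(
    raw_weights: Iterable[Tuple[str, torch.Tensor]],
    weight_mapper: Callable[[str], str],
) -> Iterator[Tuple[str, str, torch.Tensor]]:
    """Yield (org_name, mapped_name, tensor), preserving the original name."""
    for org_name, param in raw_weights:
        yield org_name, weight_mapper(org_name), param


def mapper(name: str) -> str:
    # Toy mapper mimicking Whisper's hf_to_vllm_mapper substring rename.
    return name.replace(".fc1.", ".mlp.fc1.")


raw = [("model.encoder.layers.0.fc1.weight", torch.zeros(4, 4))]
quant_state_dict = {}
for org_name, mapped_name, tensor in hf_weight_iter(raw, mapper):
    # Quant state is keyed by the vLLM-internal (mapped) name ...
    quant_state_dict[mapped_name] = "quant-state-placeholder"
    # ... while the original name is what gets yielded onward.
    print(org_name)     # model.encoder.layers.0.fc1.weight
    print(mapped_name)  # model.encoder.layers.0.mlp.fc1.weight
```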
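Likewise, a small standalone sketch of the `get_sub_modules` suffix matching added to `ParamMapping`: matching with `endswith` lets dotted keys such as `encoder_attn.kv_proj` from Whisper's `packed_modules_mapping` resolve against full module paths, which the previous exact lookup on the bare last name component (`name.split(".")[-1]`) could not. The free function here is a hypothetical stand-in for the method.

```python
from typing import Dict, List, Optional, Tuple


def get_sub_modules(
        packed_mapping: Dict[str, List[str]],
        module_name: str) -> Optional[Tuple[str, List[str]]]:
    # Suffix match, so dotted keys like "encoder_attn.kv_proj" resolve
    # against full module paths, not just the bare last name component.
    for key, sub_modules in packed_mapping.items():
        if module_name.endswith(key):
            return key, sub_modules
    return None


packed_mapping = {
    "self_attn.qkv_proj":
    ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
    "encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
}

name = "model.decoder.layers.0.encoder_attn.kv_proj"
if (info := get_sub_modules(packed_mapping, name)) is not None:
    rep_name, sub_modules = info
    # Expand the fused vLLM module name into the unfused checkpoint names,
    # mirroring the name.replace(rep_name, sub_name) loop in the patch.
    print([name.replace(rep_name, sub) for sub in sub_modules])
    # ['model.decoder.layers.0.encoder_attn.k_proj',
    #  'model.decoder.layers.0.encoder_attn.v_proj']
```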