mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 01:35:01 +08:00
[Misc] Modify BNB parameter name (#9997)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
d2e80332a7
commit
b9c64c0ca7
@ -203,8 +203,9 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
|
||||
qweight = create_qweight_for_8bit()
|
||||
else:
|
||||
qweight = create_qweight_for_4bit()
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
# Enable parameters to have the same name as in the BNB
|
||||
# checkpoint format.
|
||||
layer.register_parameter("weight", qweight)
|
||||
set_weight_attrs(qweight, extra_weight_attrs)
|
||||
|
||||
def apply(self,
|
||||
@ -234,7 +235,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
|
||||
reshape_after_matmul = True
|
||||
bf_x = x.to(torch.bfloat16)
|
||||
|
||||
qweight = layer.qweight
|
||||
qweight = layer.weight
|
||||
offsets = qweight.bnb_shard_offsets
|
||||
quant_states = qweight.bnb_quant_state
|
||||
matmul_states = qweight.matmul_state
|
||||
@ -313,7 +314,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
|
||||
reshape_after_matmul = True
|
||||
bf_x = x.to(torch.bfloat16)
|
||||
|
||||
qweight = layer.qweight
|
||||
qweight = layer.weight
|
||||
quant_states = qweight.bnb_quant_state
|
||||
offsets = qweight.bnb_shard_offsets
|
||||
|
||||
|
||||
@ -177,7 +177,7 @@ class BaseResampler(nn.Module):
|
||||
embed_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
prefix=f"{prefix}.kv_proj")
|
||||
else:
|
||||
# Maintain the same return value with ReplicatedLinear.forward
|
||||
self.kv_proj = lambda *args, **kwargs: ( # type: ignore # noqa
|
||||
|
||||
@ -892,7 +892,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
if not weight_name.lower().endswith(".scb"):
|
||||
continue
|
||||
|
||||
weight_key = weight_name.lower().replace(".scb", ".qweight")
|
||||
weight_key = weight_name.lower().replace(".scb", ".weight")
|
||||
quant_state_dict[weight_key] = weight_tensor
|
||||
|
||||
for weight_name, weight_tensor in self._hf_weight_iter(
|
||||
@ -901,11 +901,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
if self._is_8bit_weight_name(weight_name):
|
||||
continue
|
||||
|
||||
qweight_name = weight_name.replace(".weight", ".qweight")
|
||||
|
||||
if qweight_name in quant_state_dict:
|
||||
if weight_name in quant_state_dict:
|
||||
set_weight_attrs(weight_tensor, {"load_in_8bit": True})
|
||||
yield qweight_name, weight_tensor
|
||||
yield weight_name, weight_tensor
|
||||
else:
|
||||
yield weight_name, weight_tensor
|
||||
|
||||
@ -950,9 +948,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
(f"{weight_name}.quant_state.bitsandbytes__fp4" \
|
||||
in temp_state_dict):
|
||||
quant_state = _parse_quant_state(weight_name, temp_state_dict)
|
||||
weight_name = weight_name.replace(".weight", ".qweight")
|
||||
quant_state_dict[weight_name] = quant_state
|
||||
yield weight_name.replace(".weight", ".qweight"), weight_tensor
|
||||
yield weight_name, weight_tensor
|
||||
else:
|
||||
yield weight_name, weight_tensor
|
||||
|
||||
@ -967,7 +964,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
|
||||
if any(target_module in weight_name for target_module in
|
||||
self.target_modules) and weight_name.endswith(".weight"):
|
||||
weight_name = weight_name.replace(".weight", ".qweight")
|
||||
# Without sharding
|
||||
if any(
|
||||
weight_name.startswith(module)
|
||||
@ -1093,7 +1089,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
# Some models, such as MiniCPM V2.5/2.6, contain both
|
||||
# module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
|
||||
# from being incorrectly identified as being present in
|
||||
# 'vpm.encoder.layers.0.self_attn.qkv_proj.qweight
|
||||
# 'vpm.encoder.layers.0.self_attn.qkv_proj.weight
|
||||
if shard_pos > 0 and quant_param_name[shard_pos - 1] == ".":
|
||||
shard_index = index
|
||||
quant_param_name = quant_param_name.replace(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user