commit b9c64c0ca7 (parent d2e80332a7)

[Misc] Modify BNB parameter name (#9997)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@@ -203,8 +203,9 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
             qweight = create_qweight_for_8bit()
         else:
             qweight = create_qweight_for_4bit()
-
-        layer.register_parameter("qweight", qweight)
+        # Enable parameters to have the same name as in the BNB
+        # checkpoint format.
+        layer.register_parameter("weight", qweight)
         set_weight_attrs(qweight, extra_weight_attrs)

     def apply(self,
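Note on the hunk above: registering the packed tensor under "weight" instead of "qweight" makes the module's parameter name line up with the keys a bitsandbytes checkpoint actually stores. A minimal sketch of that naming contract, using a hypothetical QuantLinear stand-in rather than vLLM's real classes:

import torch
import torch.nn as nn

class QuantLinear(nn.Module):
    """Hypothetical stand-in for a BNB-quantized linear layer."""

    def __init__(self, packed: torch.Tensor):
        super().__init__()
        # Register under "weight" so named_parameters() keys match
        # checkpoint keys such as "...mlp.down_proj.weight".
        self.register_parameter(
            "weight", nn.Parameter(packed, requires_grad=False))

layer = QuantLinear(torch.empty(128, 1, dtype=torch.uint8))
assert "weight" in dict(layer.named_parameters())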
@@ -234,7 +235,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
             reshape_after_matmul = True
         bf_x = x.to(torch.bfloat16)

-        qweight = layer.qweight
+        qweight = layer.weight
         offsets = qweight.bnb_shard_offsets
         quant_states = qweight.bnb_quant_state
         matmul_states = qweight.matmul_state
@@ -313,7 +314,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
             reshape_after_matmul = True
         bf_x = x.to(torch.bfloat16)

-        qweight = layer.qweight
+        qweight = layer.weight
         quant_states = qweight.bnb_quant_state
         offsets = qweight.bnb_shard_offsets

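Both apply() paths above now fetch the packed weight from layer.weight; the quantization metadata still travels as ad-hoc attributes on the parameter object itself, which is what set_weight_attrs arranges. A rough illustration of that convention (attribute names taken from the hunks, values made up):

import torch
import torch.nn as nn

weight = nn.Parameter(torch.empty(256, 1, dtype=torch.uint8),
                      requires_grad=False)
# set_weight_attrs(weight, {...}) boils down to setattr per key:
for key, value in {"bnb_shard_offsets": [0, 128, 256],
                   "bnb_quant_state": {}}.items():
    setattr(weight, key, value)

qweight = weight                     # what apply() reads as layer.weight
offsets = qweight.bnb_shard_offsets  # shard boundaries for fused layers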
@@ -177,7 +177,7 @@ class BaseResampler(nn.Module):
                 embed_dim,
                 bias=False,
                 quant_config=quant_config,
-                prefix=prefix)
+                prefix=f"{prefix}.kv_proj")
         else:
             # Maintain the same return value with ReplicatedLinear.forward
             self.kv_proj = lambda *args, **kwargs: (  # type: ignore # noqa
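The resampler change is a companion fix: quantization configs identify layers by their dotted module path, so kv_proj must report a prefix ending in ".kv_proj" rather than inheriting the bare parent prefix. A trivial sketch with an assumed parent prefix:

prefix = "resampler"                  # assumed parent prefix
kv_proj_prefix = f"{prefix}.kv_proj"  # what the hunk now passes down
# Targeting logic can now match the kv_proj module boundary instead
# of seeing only the parent's name.
assert kv_proj_prefix == "resampler.kv_proj"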
@@ -892,7 +892,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             if not weight_name.lower().endswith(".scb"):
                 continue

-            weight_key = weight_name.lower().replace(".scb", ".qweight")
+            weight_key = weight_name.lower().replace(".scb", ".weight")
             quant_state_dict[weight_key] = weight_tensor

         for weight_name, weight_tensor in self._hf_weight_iter(
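For 8-bit checkpoints, bitsandbytes stores one SCB scaling tensor per quantized weight under the key stem "<name>.SCB"; with the rename, the loader indexes those scales under the plain ".weight" key. A small sketch of the mapping (paths invented):

def scb_to_weight_key(weight_name: str) -> str:
    """Map an SCB scale key to the key of its packed weight."""
    name = weight_name.lower()
    assert name.endswith(".scb")
    return name.replace(".scb", ".weight")

print(scb_to_weight_key("model.layers.0.mlp.down_proj.SCB"))
# -> model.layers.0.mlp.down_proj.weight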
@@ -901,11 +901,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             if self._is_8bit_weight_name(weight_name):
                 continue

-            qweight_name = weight_name.replace(".weight", ".qweight")
-
-            if qweight_name in quant_state_dict:
+            if weight_name in quant_state_dict:
                 set_weight_attrs(weight_tensor, {"load_in_8bit": True})
-                yield qweight_name, weight_tensor
+                yield weight_name, weight_tensor
             else:
                 yield weight_name, weight_tensor

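Because quant_state_dict is now keyed by the untouched ".weight" names, the 8-bit iterator no longer needs a ".qweight" rewrite; both branches yield the original name. A condensed sketch of the resulting control flow, with stand-ins for the loader's internals:

import torch

def iter_8bit_weights(weights, quant_state_dict):
    """Yield (name, tensor), tagging tensors that have an SCB scale."""
    for name, tensor in weights:
        if name in quant_state_dict:     # key already ends in ".weight"
            tensor.load_in_8bit = True   # stand-in for set_weight_attrs
        yield name, tensor               # same name either way

weights = [("layer.weight", torch.empty(4, dtype=torch.int8))]
states = {"layer.weight": "scb-scale"}
for name, _ in iter_8bit_weights(weights, states):
    print(name)                          # -> layer.weight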
@@ -950,9 +948,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                     (f"{weight_name}.quant_state.bitsandbytes__fp4" \
                         in temp_state_dict):
                 quant_state = _parse_quant_state(weight_name, temp_state_dict)
-                weight_name = weight_name.replace(".weight", ".qweight")
                 quant_state_dict[weight_name] = quant_state
-                yield weight_name.replace(".weight", ".qweight"), weight_tensor
+                yield weight_name, weight_tensor
             else:
                 yield weight_name, weight_tensor

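In the 4-bit path, a weight is recognized as quantized by the presence of a serialized quant state beside it; as the hunk shows, those keys follow the pattern "<weight key>.quant_state.bitsandbytes__nf4" (or __fp4). With the rewrite line removed, the ".weight" key flows straight through. A hedged sketch with an invented state dict:

# Invented checkpoint contents for illustration only.
temp_state_dict = {
    "model.layers.0.mlp.down_proj.weight":
        "packed uint8 tensor",
    "model.layers.0.mlp.down_proj.weight.quant_state.bitsandbytes__nf4":
        "absmax / blocksize / dtype metadata",
}

weight_name = "model.layers.0.mlp.down_proj.weight"
is_4bit = (
    f"{weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict
    or f"{weight_name}.quant_state.bitsandbytes__fp4" in temp_state_dict)
assert is_4bit  # the loader now keeps this ".weight" key unchanged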
@@ -967,7 +964,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):

             if any(target_module in weight_name for target_module in
                    self.target_modules) and weight_name.endswith(".weight"):
-                weight_name = weight_name.replace(".weight", ".qweight")
                 # Without sharding
                 if any(
                         weight_name.startswith(module)
@@ -1093,7 +1089,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             # Some models, such as MiniCPM V2.5/2.6, contain both
             # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
             # from being incorrectly identified as being present in
-            # 'vpm.encoder.layers.0.self_attn.qkv_proj.qweight
+            # 'vpm.encoder.layers.0.self_attn.qkv_proj.weight
             if shard_pos > 0 and quant_param_name[shard_pos - 1] == ".":
                 shard_index = index
                 quant_param_name = quant_param_name.replace(
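The updated comment describes a real pitfall: a bare substring test would find "kv_proj" inside "qkv_proj". Requiring the character before the match to be "." anchors the check to a module-name boundary, as in this standalone sketch:

def module_matches(param_name: str, shard_name: str) -> bool:
    """True only when shard_name starts at a dotted module boundary."""
    pos = param_name.find(shard_name)
    return pos > 0 and param_name[pos - 1] == "."

name = "vpm.encoder.layers.0.self_attn.qkv_proj.weight"
assert not module_matches(name, "kv_proj")  # inside "qkv_proj": rejected
assert module_matches(name, "qkv_proj")     # real module boundary: kept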