[Bugfix] Fix BNB name match (#24735)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li 2025-09-12 19:12:01 +08:00 committed by GitHub
parent 64d90c3e4f
commit 60a0951924
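
Context for the fix: the loader previously bucketed weights into its unsharded / column-sharded / fused groups with mapped_weight_name.startswith(module), so a module whose name is a string prefix of another module's name could capture the wrong weights. The new check_match helper strips the ".weight" suffix and requires an exact name match. A minimal sketch of the failure mode, with hypothetical module names (not taken from this patch):

# Hypothetical module names, only to illustrate prefix vs. exact matching.
def old_match(weight_name: str, module_name: str) -> bool:
    # Pre-patch behavior: prefix match.
    return weight_name.startswith(module_name)

def check_match(weight_name: str, module_name: str) -> bool:
    # Post-patch behavior: exact match after dropping the ".weight" suffix.
    return weight_name.removesuffix(".weight") == module_name

weight = "model.embed_tokens_extra.weight"  # belongs to a different module
print(old_match(weight, "model.embed_tokens"))    # True: wrongly matched
print(check_match(weight, "model.embed_tokens"))  # False: correctly skipped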


@@ -326,7 +326,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         global_tp_size = get_tensor_model_parallel_world_size()
         global_tp_rank = get_tensor_model_parallel_rank()
+        check_match = (lambda weight_name, module_name: weight_name.
+                       removesuffix(".weight") == module_name)
         for (
             org_weight_name,
             mapped_weight_name,
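
Note that str.removesuffix (Python 3.9+) returns the string unchanged when the suffix is absent, so the helper degrades to a plain equality test for names that do not end in ".weight". A quick illustration:

assert "model.lm_head.weight".removesuffix(".weight") == "model.lm_head"
assert "model.lm_head".removesuffix(".weight") == "model.lm_head"   # no-op
assert "model.lm_head_2.weight".removesuffix(".weight") != "model.lm_head"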
@@ -347,12 +348,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                    ) and mapped_weight_name.endswith(".weight"):
                 # Without sharding
                 if any(
-                        mapped_weight_name.startswith(module)
+                        check_match(mapped_weight_name, module)
                         for module in self.unsharded_weights_modules):
                     weight_sub_tensor = weight_tensor
                 # Shard by column
                 elif any(
-                        mapped_weight_name.startswith(module)
+                        check_match(mapped_weight_name, module)
                         for module in self.column_sharded_weights_modules):
                     total_size = weight_tensor.size(-1)
                     start_index = total_size // tp_size * tp_rank
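
The column-sharded branch slices an equal chunk of the last dimension per tensor-parallel rank. The same arithmetic in isolation, with made-up sizes (the end_index line is implied by the start_index computation shown in the hunk):

import torch

tp_size, tp_rank = 4, 1                 # hypothetical TP configuration
weight_tensor = torch.randn(32, 1024)   # last dim is split by column

total_size = weight_tensor.size(-1)
start_index = total_size // tp_size * tp_rank        # 256
end_index = total_size // tp_size * (tp_rank + 1)    # 512
weight_sub_tensor = weight_tensor[..., start_index:end_index]
assert weight_sub_tensor.shape == (32, 256)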
@@ -362,14 +363,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 # Weights have fused on disk. In this case, we assume that the
                 # weight and module use same name.
                 elif any(
-                        mapped_weight_name.startswith(module)
+                        check_match(mapped_weight_name, module)
                         for module in self.maybe_fused_weights_modules):
                     # special case for fused weights
                     # get the size of each shard weight tensor
                     total_shard_sizes = next(
                         (sizes for module, sizes in
                          self.maybe_fused_weights_modules.items()
-                         if mapped_weight_name.startswith(module)))
+                         if check_match(mapped_weight_name, module)))
                     total_size = weight_tensor.size(0)
                     assert total_size == sum(total_shard_sizes)
                     # get the start/end index of each shard weight tensor
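
For fused weights (e.g. gate_proj and up_proj stored as one tensor on disk), the per-shard sizes recorded in maybe_fused_weights_modules determine where each logical shard starts and ends along dim 0. A sketch of that offset computation with illustrative sizes; this is not necessarily the exact continuation of the loader code:

from itertools import accumulate

total_shard_sizes = [11008, 11008]   # hypothetical fused shard sizes
total_size = sum(total_shard_sizes)  # must equal weight_tensor.size(0)

end_offsets = list(accumulate(total_shard_sizes))  # [11008, 22016]
start_offsets = [0] + end_offsets[:-1]             # [0, 11008]
for start, end in zip(start_offsets, end_offsets):
    # each (start, end) pair selects one fused shard along dim 0
    print(start, end)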