mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 23:45:54 +08:00
[Bugfix] Fix BNB name match (#24735)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
64d90c3e4f
commit
60a0951924
@@ -326,7 +326,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
|
||||
global_tp_size = get_tensor_model_parallel_world_size()
|
||||
global_tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
check_match = (lambda weight_name, module_name: weight_name.
|
||||
removesuffix(".weight") == module_name)
|
||||
for (
|
||||
org_weight_name,
|
||||
mapped_weight_name,
|
||||
@@ -347,12 +348,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
) and mapped_weight_name.endswith(".weight"):
|
||||
# Without sharding
|
||||
if any(
|
||||
mapped_weight_name.startswith(module)
|
||||
check_match(mapped_weight_name, module)
|
||||
for module in self.unsharded_weights_modules):
|
||||
weight_sub_tensor = weight_tensor
|
||||
# Shard by column
|
||||
elif any(
|
||||
mapped_weight_name.startswith(module)
|
||||
check_match(mapped_weight_name, module)
|
||||
for module in self.column_sharded_weights_modules):
|
||||
total_size = weight_tensor.size(-1)
|
||||
start_index = total_size // tp_size * tp_rank
|
||||
@@ -362,14 +363,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
# Weights have fused on disk. In this case, we assume that the
|
||||
# weight and module use same name.
|
||||
elif any(
|
||||
mapped_weight_name.startswith(module)
|
||||
check_match(mapped_weight_name, module)
|
||||
for module in self.maybe_fused_weights_modules):
|
||||
# special case for fused weights
|
||||
# get the size of each shard weight tensor
|
||||
total_shard_sizes = next(
|
||||
(sizes for module, sizes in
|
||||
self.maybe_fused_weights_modules.items()
|
||||
if mapped_weight_name.startswith(module)))
|
||||
if check_match(mapped_weight_name, module)))
|
||||
total_size = weight_tensor.size(0)
|
||||
assert total_size == sum(total_shard_sizes)
|
||||
# get the start/end index of each shard weight tensor
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user