[Bugfix] Fix BNB name match (#24735)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

parent 64d90c3e4f
commit 60a0951924
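
Why the change: the loader used to match a weight to its owning module with str.startswith, so a module name that is a strict prefix of another could claim the wrong weight. The new check_match strips the trailing ".weight" and requires exact equality. Below is a minimal sketch of the difference, using made-up module names (the real candidates come from lists such as self.unsharded_weights_modules); note that str.removesuffix needs Python 3.9+:

    # check_match as introduced by this commit
    check_match = (lambda weight_name, module_name: weight_name.
                   removesuffix(".weight") == module_name)

    modules = ["model.layers.1", "model.layers.11"]  # hypothetical names
    weight = "model.layers.11.weight"

    # Old prefix match: "model.layers.1" also matches layers.11's weight.
    print([m for m in modules if weight.startswith(m)])
    # -> ['model.layers.1', 'model.layers.11']

    # New exact match: only the intended module matches.
    print([m for m in modules if check_match(weight, m)])
    # -> ['model.layers.11']
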
@@ -326,7 +326,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
         global_tp_size = get_tensor_model_parallel_world_size()
         global_tp_rank = get_tensor_model_parallel_rank()
-
+        check_match = (lambda weight_name, module_name: weight_name.
+                       removesuffix(".weight") == module_name)
         for (
                 org_weight_name,
                 mapped_weight_name,
@@ -347,12 +348,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             ) and mapped_weight_name.endswith(".weight"):
                 # Without sharding
                 if any(
-                        mapped_weight_name.startswith(module)
+                        check_match(mapped_weight_name, module)
                         for module in self.unsharded_weights_modules):
                     weight_sub_tensor = weight_tensor
                 # Shard by column
                 elif any(
-                        mapped_weight_name.startswith(module)
+                        check_match(mapped_weight_name, module)
                         for module in self.column_sharded_weights_modules):
                     total_size = weight_tensor.size(-1)
                     start_index = total_size // tp_size * tp_rank
@@ -362,14 +363,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 # Weights have fused on disk. In this case, we assume that the
                 # weight and module use same name.
                 elif any(
-                        mapped_weight_name.startswith(module)
+                        check_match(mapped_weight_name, module)
                         for module in self.maybe_fused_weights_modules):
                     # special case for fused weights
                     # get the size of each shard weight tensor
                     total_shard_sizes = next(
                         (sizes for module, sizes in
                          self.maybe_fused_weights_modules.items()
-                         if mapped_weight_name.startswith(module)))
+                         if check_match(mapped_weight_name, module)))
                     total_size = weight_tensor.size(0)
                     assert total_size == sum(total_shard_sizes)
                     # get the start/end index of each shard weight tensor
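
In the fused-weights branch above, the matched entry of self.maybe_fused_weights_modules supplies the per-shard sizes, and the loader then derives each shard's start/end index along dim 0 (the hunk ends right before that computation). A rough sketch of that indexing, assuming the shards are concatenated in order; the shard sizes here are made up and the loader's actual code may differ:

    import itertools

    # Hypothetical fused projection split into three shards along dim 0.
    total_shard_sizes = [1024, 256, 256]  # made-up q/k/v sizes
    total_size = sum(total_shard_sizes)   # 1536, matches dim 0 of the tensor

    # Running sums give each shard's [start, end) slice of dim 0.
    ends = list(itertools.accumulate(total_shard_sizes))
    starts = [0] + ends[:-1]
    print(list(zip(starts, ends)))
    # -> [(0, 1024), (1024, 1280), (1280, 1536)]
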