[Bugfix] weight loading use correct tp_group with patch_tensor_parallel_group (#21024)
Signed-off-by: KevinXiong-C <kevin_xiong1997@outlook.com>
parent 4e7dfbe7b4
commit c9ba8104ed
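The pattern applied throughout the diff below is the same in every class: the tensor-parallel rank (and, where needed, the world size) is captured once in __init__ as self.tp_rank / self.tp_size, and the weight loaders stop re-querying the process group at load time. If loading happens while a different tensor-parallel group is active (as with patch_tensor_parallel_group), a fresh get_tensor_model_parallel_rank() call would return the rank of the wrong group and the layer would pull the wrong shard. The sketch below illustrates that failure mode only; the group registry, patch_tp_group() helper, and ShardedLayer class are hypothetical stand-ins, not vLLM code.

    from contextlib import contextmanager

    # Hypothetical stand-in for the process-group state that a real
    # get_tensor_model_parallel_rank() call would consult.
    _CURRENT_TP = {"rank": 0, "world_size": 4}

    def current_tp_rank() -> int:
        return _CURRENT_TP["rank"]

    @contextmanager
    def patch_tp_group(rank: int, world_size: int):
        """Temporarily swap the active tensor-parallel group (illustrative only)."""
        saved = dict(_CURRENT_TP)
        _CURRENT_TP.update(rank=rank, world_size=world_size)
        try:
            yield
        finally:
            _CURRENT_TP.update(saved)

    class ShardedLayer:
        def __init__(self):
            # Capture the rank of the group the layer is built with,
            # mirroring `self.tp_rank = get_tensor_model_parallel_rank()`.
            self.tp_rank = current_tp_rank()

        def load_cached(self) -> int:
            return self.tp_rank          # after the fix: uses the captured rank

        def load_requery(self) -> int:
            return current_tp_rank()     # before the fix: re-queries at load time

    layer = ShardedLayer()               # built under the original group (rank 0)
    with patch_tp_group(rank=2, world_size=8):
        assert layer.load_cached() == 0   # correct shard for this layer
        assert layer.load_requery() == 2  # wrong shard: picked up the patched group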
@@ -452,8 +452,10 @@ class ColumnParallelLinear(LinearBase):
         else:
             self.register_parameter("bias", None)
 
+        self.tp_rank = get_tensor_model_parallel_rank()
+
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
         output_dim = getattr(param, "output_dim", None)
 
         is_sharded_weight = getattr(param, "is_sharded_weight", False)
@@ -472,15 +474,15 @@ class ColumnParallelLinear(LinearBase):
         if is_gguf_weight and isinstance(param, UninitializedParameter):
             final_shape = list(loaded_weight.shape)
             if output_dim is not None:
-                tp_size = get_tensor_model_parallel_world_size()
-                assert final_shape[output_dim] % tp_size == 0
-                final_shape[output_dim] = final_shape[output_dim] // tp_size
+                assert final_shape[output_dim] % self.tp_size == 0
+                final_shape[output_dim] = (final_shape[output_dim] //
+                                           self.tp_size)
             param.materialize(final_shape, dtype=loaded_weight.dtype)
 
         param_data = param.data
         if output_dim is not None and not is_sharded_weight:
             shard_size = param_data.shape[output_dim]
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
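For reference, the sharding arithmetic these loaders perform along output_dim is a plain narrow of the checkpoint tensor: rank i keeps the slice [i * shard_size, (i + 1) * shard_size). In the GGUF branch the parameter is still an UninitializedParameter, so the per-rank shape is computed the same way before param.materialize() allocates storage. A minimal standalone sketch in plain PyTorch (not the vLLM loader itself):

    import torch

    def shard_along_dim(full_weight: torch.Tensor, dim: int,
                        tp_rank: int, tp_size: int) -> torch.Tensor:
        """Return this rank's slice of `full_weight` along `dim`."""
        assert full_weight.size(dim) % tp_size == 0
        shard_size = full_weight.size(dim) // tp_size
        start_idx = tp_rank * shard_size
        return full_weight.narrow(dim, start_idx, shard_size)

    # Toy weight with output_dim = 0, split across 4 ranks.
    full = torch.arange(8 * 3, dtype=torch.float32).reshape(8, 3)
    shard = shard_along_dim(full, dim=0, tp_rank=1, tp_size=4)
    assert shard.shape == (2, 3)
    assert torch.equal(shard, full[2:4])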
@@ -565,8 +567,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         return_bias: bool = True,
     ):
         self.output_sizes = output_sizes
-        tp_size = get_tensor_model_parallel_world_size()
-        assert all(output_size % tp_size == 0 for output_size in output_sizes)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+
+        assert all(output_size % self.tp_size == 0
+                   for output_size in output_sizes)
         super().__init__(input_size=input_size,
                          output_size=sum(output_sizes),
                          bias=bias,
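The divisibility assert above guarantees that every constituent output of the merged layer splits evenly, so each rank holds an equal slice of every sub-matrix. A quick worked example with hypothetical sizes:

    # Hypothetical merged projection: two outputs of 11008 columns fused into one matmul.
    output_sizes = [11008, 11008]
    tp_size = 4

    assert all(output_size % tp_size == 0 for output_size in output_sizes)

    # Each rank owns an equal slice of every sub-matrix.
    per_rank = [output_size // tp_size for output_size in output_sizes]
    assert per_rank == [2752, 2752]
    assert sum(per_rank) == sum(output_sizes) // tp_size  # 5504 columns per rank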
@@ -598,12 +603,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             return
 
         if is_gguf_weight:
-            tp_size = get_tensor_model_parallel_world_size()
-            tp_rank = get_tensor_model_parallel_rank()
-
             output_dim = getattr(param, "output_dim", None)
-            shard_size = loaded_weight.size(output_dim) // tp_size
-            start_idx = tp_rank * shard_size
+            shard_size = loaded_weight.size(output_dim) // self.tp_size
+            start_idx = self.tp_rank * shard_size
 
             if loaded_shard_id is not None:
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx,
@@ -669,11 +672,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             return
 
         assert loaded_shard_id < len(self.output_sizes)
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
         if output_dim is not None:
-            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
-            shard_size = self.output_sizes[loaded_shard_id] // tp_size
+            shard_offset = (sum(self.output_sizes[:loaded_shard_id]) //
+                            self.tp_size)
+            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
             # Special case for quantization.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
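Continuing that example, shard_offset and shard_size locate where a given sub-weight lands inside this rank's slice of the fused parameter, and start_idx is where the rank's slice begins in the full checkpoint tensor. The numbers below assume the same hypothetical output_sizes and tp_size:

    output_sizes = [11008, 11008]   # same hypothetical sizes as above
    tp_size, tp_rank = 4, 1

    shards = []
    for loaded_shard_id in range(len(output_sizes)):
        # Offset of this sub-weight inside the rank-local fused parameter.
        shard_offset = sum(output_sizes[:loaded_shard_id]) // tp_size
        # Rank-local width of this sub-weight.
        shard_size = output_sizes[loaded_shard_id] // tp_size
        # Where this rank's slice starts inside the full checkpoint tensor.
        start_idx = tp_rank * shard_size
        shards.append((shard_offset, shard_size, start_idx))

    assert shards == [(0, 2752, 2752), (2752, 2752, 2752)]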
@@ -701,7 +703,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
 
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             if not is_sharded_weight:
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                      shard_size)
@@ -991,12 +993,9 @@ class QKVParallelLinear(ColumnParallelLinear):
             return
 
         if is_gguf_weight:
-            tp_size = get_tensor_model_parallel_world_size()
-            tp_rank = get_tensor_model_parallel_rank()
-
             output_dim = getattr(param, "output_dim", None)
-            shard_size = loaded_weight.size(output_dim) // tp_size
-            start_idx = tp_rank * shard_size
+            shard_size = loaded_weight.size(output_dim) // self.tp_size
+            start_idx = self.tp_rank * shard_size
 
             if loaded_shard_id is not None:
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx,
@@ -1071,7 +1070,6 @@ class QKVParallelLinear(ColumnParallelLinear):
                 self.weight_loader(param, loaded_weight_shard, shard_id)
             return
 
-        tp_rank = get_tensor_model_parallel_rank()
         assert loaded_shard_id in ["q", "k", "v"]
 
         # If output dim is defined, use the default loading process.
@@ -1123,9 +1121,9 @@ class QKVParallelLinear(ColumnParallelLinear):
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             if loaded_shard_id == "q":
-                shard_id = tp_rank
+                shard_id = self.tp_rank
             else:
-                shard_id = tp_rank // self.num_kv_head_replicas
+                shard_id = self.tp_rank // self.num_kv_head_replicas
             start_idx = shard_id * shard_size
 
             if not is_sharded_weight:
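The q/k/v branch differs only in how the rank maps to a shard index: query heads split one-to-one across ranks, while key/value heads may be replicated when there are fewer KV heads than tensor-parallel ranks, so the rank is divided by num_kv_head_replicas. A worked example with hypothetical head counts; the replica-count formula here is an assumption for illustration:

    # Hypothetical GQA setup: 8 KV heads shared across 16 tensor-parallel ranks.
    total_num_kv_heads = 8
    tp_size = 16
    # Assumed replica count: how many ranks share one KV head (1 when tp_size <= KV heads).
    num_kv_head_replicas = max(1, tp_size // total_num_kv_heads)

    for tp_rank in range(tp_size):
        q_shard_id = tp_rank                           # each rank loads its own Q slice
        kv_shard_id = tp_rank // num_kv_head_replicas  # pairs of ranks reuse one KV slice
        assert q_shard_id == tp_rank
        assert 0 <= kv_shard_id < total_num_kv_heads

    # Ranks 0 and 1 both read KV shard 0, ranks 2 and 3 read KV shard 1, and so on.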
@@ -1245,8 +1243,6 @@ class RowParallelLinear(LinearBase):
             self.register_parameter("bias", None)
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
         input_dim = getattr(param, "input_dim", None)
         use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
         is_sharded_weight = getattr(param, "is_sharded_weight", False)
@@ -1264,13 +1260,14 @@ class RowParallelLinear(LinearBase):
         if is_gguf_weight and isinstance(param, UninitializedParameter):
             weight_shape = list(loaded_weight.shape)
             if input_dim:
-                weight_shape[input_dim] = weight_shape[input_dim] // tp_size
+                weight_shape[input_dim] = (weight_shape[input_dim] //
+                                           self.tp_size)
             param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
 
         param_data = param.data
         if input_dim is not None and not is_sharded_weight:
             shard_size = param_data.shape[input_dim]
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                  shard_size)
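RowParallelLinear gets the mirror-image treatment: its weight is partitioned along the input dimension, so the same cached rank selects a block of columns of the checkpoint tensor rather than rows. A small standalone sketch of the input-dim slice, again plain PyTorch rather than the vLLM loader:

    import torch

    tp_rank, tp_size = 1, 4
    input_dim = 1                      # row-parallel layers shard the input dimension
    full_weight = torch.randn(16, 8)   # (output_size, input_size)

    shard_size = full_weight.size(input_dim) // tp_size
    start_idx = tp_rank * shard_size
    local_weight = full_weight.narrow(input_dim, start_idx, shard_size)

    assert local_weight.shape == (16, 2)
    assert torch.equal(local_weight, full_weight[:, 2:4])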