[Bugfix] weight loading use correct tp_group with patch_tensor_parallel_group (#21024)

Signed-off-by: KevinXiong-C <kevin_xiong1997@outlook.com>
This commit is contained in:
Kevin_Xiong 2025-07-17 10:36:36 +08:00 committed by GitHub
parent 4e7dfbe7b4
commit c9ba8104ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -452,8 +452,10 @@ class ColumnParallelLinear(LinearBase):
else:
self.register_parameter("bias", None)
self.tp_rank = get_tensor_model_parallel_rank()
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
output_dim = getattr(param, "output_dim", None)
is_sharded_weight = getattr(param, "is_sharded_weight", False)
@ -472,15 +474,15 @@ class ColumnParallelLinear(LinearBase):
if is_gguf_weight and isinstance(param, UninitializedParameter):
final_shape = list(loaded_weight.shape)
if output_dim is not None:
tp_size = get_tensor_model_parallel_world_size()
assert final_shape[output_dim] % tp_size == 0
final_shape[output_dim] = final_shape[output_dim] // tp_size
assert final_shape[output_dim] % self.tp_size == 0
final_shape[output_dim] = (final_shape[output_dim] //
self.tp_size)
param.materialize(final_shape, dtype=loaded_weight.dtype)
param_data = param.data
if output_dim is not None and not is_sharded_weight:
shard_size = param_data.shape[output_dim]
start_idx = tp_rank * shard_size
start_idx = self.tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
@ -565,8 +567,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
return_bias: bool = True,
):
self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
assert all(output_size % self.tp_size == 0
for output_size in output_sizes)
super().__init__(input_size=input_size,
output_size=sum(output_sizes),
bias=bias,
@ -598,12 +603,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
return
if is_gguf_weight:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
output_dim = getattr(param, "output_dim", None)
shard_size = loaded_weight.size(output_dim) // tp_size
start_idx = tp_rank * shard_size
shard_size = loaded_weight.size(output_dim) // self.tp_size
start_idx = self.tp_rank * shard_size
if loaded_shard_id is not None:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
@ -669,11 +672,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
return
assert loaded_shard_id < len(self.output_sizes)
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
if output_dim is not None:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
shard_offset = (sum(self.output_sizes[:loaded_shard_id]) //
self.tp_size)
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
# Special case for quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
@ -701,7 +703,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
start_idx = tp_rank * shard_size
start_idx = self.tp_rank * shard_size
if not is_sharded_weight:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
@ -991,12 +993,9 @@ class QKVParallelLinear(ColumnParallelLinear):
return
if is_gguf_weight:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
output_dim = getattr(param, "output_dim", None)
shard_size = loaded_weight.size(output_dim) // tp_size
start_idx = tp_rank * shard_size
shard_size = loaded_weight.size(output_dim) // self.tp_size
start_idx = self.tp_rank * shard_size
if loaded_shard_id is not None:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
@ -1071,7 +1070,6 @@ class QKVParallelLinear(ColumnParallelLinear):
self.weight_loader(param, loaded_weight_shard, shard_id)
return
tp_rank = get_tensor_model_parallel_rank()
assert loaded_shard_id in ["q", "k", "v"]
# If output dim is defined, use the default loading process.
@ -1123,9 +1121,9 @@ class QKVParallelLinear(ColumnParallelLinear):
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
if loaded_shard_id == "q":
shard_id = tp_rank
shard_id = self.tp_rank
else:
shard_id = tp_rank // self.num_kv_head_replicas
shard_id = self.tp_rank // self.num_kv_head_replicas
start_idx = shard_id * shard_size
if not is_sharded_weight:
@ -1245,8 +1243,6 @@ class RowParallelLinear(LinearBase):
self.register_parameter("bias", None)
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
input_dim = getattr(param, "input_dim", None)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
is_sharded_weight = getattr(param, "is_sharded_weight", False)
@ -1264,13 +1260,14 @@ class RowParallelLinear(LinearBase):
if is_gguf_weight and isinstance(param, UninitializedParameter):
weight_shape = list(loaded_weight.shape)
if input_dim:
weight_shape[input_dim] = weight_shape[input_dim] // tp_size
weight_shape[input_dim] = (weight_shape[input_dim] //
self.tp_size)
param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
param_data = param.data
if input_dim is not None and not is_sharded_weight:
shard_size = param_data.shape[input_dim]
start_idx = tp_rank * shard_size
start_idx = self.tp_rank * shard_size
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
shard_size)