Fix initializing GGUF weights for ColumnParallelLinear when using tensor parallel > 1 (#13023)
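Previously, when a GGUF-backed `ColumnParallelLinear` weight arrived as an `UninitializedParameter`, it was materialized with the full `loaded_weight.shape`, ignoring tensor parallelism. With `tensor_parallel_size > 1`, each rank only holds a `1/tp_size` slice along `output_dim`, so the parameter was allocated at the wrong (unsharded) size. The fix computes the per-rank shape before materializing, and hoists the bitsandbytes `is_sharded_weight` handling above the GGUF branch.

A minimal repro sketch of the affected configuration; the checkpoint path and tokenizer name below are hypothetical, and passing the base model's HF tokenizer alongside a GGUF file follows vLLM's usual GGUF guidance:

    from vllm import LLM

    llm = LLM(
        model="/models/llama-3-8b-instruct.Q4_K_M.gguf",  # hypothetical local GGUF file
        tokenizer="meta-llama/Meta-Llama-3-8B-Instruct",  # hypothetical base-model tokenizer
        tensor_parallel_size=2,  # TP > 1 triggered the mis-sized materialization
    )
    print(llm.generate(["Hello"])[0].outputs[0].text)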

commit 2b25b7d2e1 (parent 6c4dbe23eb)
Author: Szymon Ożóg
Date:   2025-02-11 17:38:48 +01:00
Committed via GitHub


@@ -335,6 +335,12 @@ class ColumnParallelLinear(LinearBase):
         tp_rank = get_tensor_model_parallel_rank()
         output_dim = getattr(param, "output_dim", None)
+        is_sharded_weight = getattr(param, "is_sharded_weight", False)
+        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+        # bitsandbytes loads the weights of the specific portion
+        # no need to narrow
+        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
+
         # Special case for GGUF
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
         is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
@@ -343,13 +349,12 @@ class ColumnParallelLinear(LinearBase):
         # Materialize GGUF UninitializedParameter
         if is_gguf_weight and isinstance(param, UninitializedParameter):
-            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
-
-        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
-        is_sharded_weight = getattr(param, "is_sharded_weight", False)
-        # bitsandbytes loads the weights of the specific portion
-        # no need to narrow
-        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
+            final_shape = list(loaded_weight.shape)
+            if output_dim is not None:
+                tp_size = get_tensor_model_parallel_world_size()
+                assert final_shape[output_dim] % tp_size == 0
+                final_shape[output_dim] = final_shape[output_dim] // tp_size
+            param.materialize(final_shape, dtype=loaded_weight.dtype)
 
         param_data = param.data
         if output_dim is not None and not is_sharded_weight:
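A minimal standalone sketch of the shape logic the second hunk introduces: each TP rank materializes only its `1/tp_size` slice along `output_dim`. The function name here is illustrative, not vLLM's internals:

    def per_rank_shape(full_shape, output_dim, tp_size):
        shape = list(full_shape)
        # The full output dimension must split evenly across ranks,
        # mirroring the assert in the diff above.
        assert shape[output_dim] % tp_size == 0
        shape[output_dim] //= tp_size
        return shape

    # A 4096x4096 ColumnParallelLinear weight sharded on output_dim=0
    # across 2 ranks yields a 2048x4096 slice per rank.
    print(per_rank_shape([4096, 4096], 0, 2))  # [2048, 4096]

Materializing at this per-rank shape keeps the subsequent `narrow`-based copy in the loader consistent: the destination buffer already matches the shard that `tp_rank` selects from the full weight.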