Fix initializing GGUF weights for ColumnParallelLinear when using tensor parallel > 1 (#13023)
parent 6c4dbe23eb
commit 2b25b7d2e1
@@ -335,6 +335,12 @@ class ColumnParallelLinear(LinearBase):
         tp_rank = get_tensor_model_parallel_rank()
         output_dim = getattr(param, "output_dim", None)
 
+        is_sharded_weight = getattr(param, "is_sharded_weight", False)
+        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+        # bitsandbytes loads the weights of the specific portion
+        # no need to narrow
+        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
+
         # Special case for GGUF
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
         is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
@@ -343,13 +349,12 @@ class ColumnParallelLinear(LinearBase):
 
         # Materialize GGUF UninitializedParameter
         if is_gguf_weight and isinstance(param, UninitializedParameter):
-            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
-
-        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
-        is_sharded_weight = getattr(param, "is_sharded_weight", False)
-        # bitsandbytes loads the weights of the specific portion
-        # no need to narrow
-        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
+            final_shape = list(loaded_weight.shape)
+            if output_dim is not None:
+                tp_size = get_tensor_model_parallel_world_size()
+                assert final_shape[output_dim] % tp_size == 0
+                final_shape[output_dim] = final_shape[output_dim] // tp_size
+            param.materialize(final_shape, dtype=loaded_weight.dtype)
 
         param_data = param.data
         if output_dim is not None and not is_sharded_weight:
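
What the change fixes, in isolation: the removed line materialized the GGUF UninitializedParameter with the full, unsharded loaded_weight.shape, so with tensor parallel > 1 the parameter was larger than the shard each rank actually owns along output_dim. The added lines divide that dimension by the tensor-parallel world size before materializing, and the bitsandbytes/sharded-weight checks are hoisted above the GGUF block so is_sharded_weight is defined before it is used. A minimal, self-contained sketch of the same shape arithmetic follows; the helper name sharded_shape and the example sizes are illustrative assumptions, not vLLM API.

# Sketch only (assumptions, not vLLM code): mirrors the shape arithmetic
# that this commit adds before param.materialize().
def sharded_shape(full_shape, output_dim, tp_size):
    """Per-rank shape for a weight whose output_dim is split across TP ranks."""
    shape = list(full_shape)
    if output_dim is not None:
        # Same invariant as the assert in the diff: the sharded dim must
        # divide evenly across the tensor-parallel ranks.
        assert shape[output_dim] % tp_size == 0
        shape[output_dim] //= tp_size
    return shape

# Hypothetical example: an [8192, 4096] column-parallel weight with
# output_dim=0 and tp_size=2 is materialized as [4096, 4096] on each rank,
# matching the shard that the weight loader later narrows out of the full
# GGUF tensor.
print(sharded_shape([8192, 4096], output_dim=0, tp_size=2))  # [4096, 4096]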