mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-04 20:04:35 +08:00
Fix initializing GGUF weights for ColumnParallelLinear when using tensor parallel > 1 (#13023)
This commit is contained in:
parent
6c4dbe23eb
commit
2b25b7d2e1
@@ -335,6 +335,12 @@ class ColumnParallelLinear(LinearBase):
|
|||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
output_dim = getattr(param, "output_dim", None)
|
output_dim = getattr(param, "output_dim", None)
|
||||||
|
|
||||||
|
is_sharded_weight = getattr(param, "is_sharded_weight", False)
|
||||||
|
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
|
||||||
|
# bitsandbytes loads the weights of the specific portion
|
||||||
|
# no need to narrow
|
||||||
|
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
|
||||||
|
|
||||||
# Special case for GGUF
|
# Special case for GGUF
|
||||||
is_gguf_weight = getattr(param, "is_gguf_weight", False)
|
is_gguf_weight = getattr(param, "is_gguf_weight", False)
|
||||||
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
|
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
|
||||||
@@ -343,13 +349,12 @@ class ColumnParallelLinear(LinearBase):
|
|||||||
|
|
||||||
# Materialize GGUF UninitializedParameter
|
# Materialize GGUF UninitializedParameter
|
||||||
if is_gguf_weight and isinstance(param, UninitializedParameter):
|
if is_gguf_weight and isinstance(param, UninitializedParameter):
|
||||||
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
|
final_shape = list(loaded_weight.shape)
|
||||||
|
if output_dim is not None:
|
||||||
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
is_sharded_weight = getattr(param, "is_sharded_weight", False)
|
assert final_shape[output_dim] % tp_size == 0
|
||||||
# bitsandbytes loads the weights of the specific portion
|
final_shape[output_dim] = final_shape[output_dim] // tp_size
|
||||||
# no need to narrow
|
param.materialize(final_shape, dtype=loaded_weight.dtype)
|
||||||
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
|
|
||||||
|
|
||||||
param_data = param.data
|
param_data = param.data
|
||||||
if output_dim is not None and not is_sharded_weight:
|
if output_dim is not None and not is_sharded_weight:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user