mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 21:15:28 +08:00
[FIX] Don't initialize parameter by default (#1067)
This commit is contained in:
parent
e21d7687a9
commit
90979c38f8
@ -83,7 +83,7 @@ class VocabParallelEmbedding(torch.nn.Module):
|
|||||||
init_method=init.xavier_normal_,
|
init_method=init.xavier_normal_,
|
||||||
params_dtype: torch.dtype=None,
|
params_dtype: torch.dtype=None,
|
||||||
use_cpu_initialization: bool=False,
|
use_cpu_initialization: bool=False,
|
||||||
perform_initialization: bool=True):
|
perform_initialization: bool=False):
|
||||||
super(VocabParallelEmbedding, self).__init__()
|
super(VocabParallelEmbedding, self).__init__()
|
||||||
assert not perform_initialization
|
assert not perform_initialization
|
||||||
assert not use_cpu_initialization
|
assert not use_cpu_initialization
|
||||||
@ -113,7 +113,7 @@ class VocabParallelEmbedding(torch.nn.Module):
|
|||||||
self.weight = Parameter(torch.empty(
|
self.weight = Parameter(torch.empty(
|
||||||
self.num_embeddings_per_partition, self.embedding_dim,
|
self.num_embeddings_per_partition, self.embedding_dim,
|
||||||
device=torch.cuda.current_device(), dtype=params_dtype))
|
device=torch.cuda.current_device(), dtype=params_dtype))
|
||||||
|
|
||||||
def forward(self, input_):
|
def forward(self, input_):
|
||||||
if self.tensor_model_parallel_size > 1:
|
if self.tensor_model_parallel_size > 1:
|
||||||
# Build the mask.
|
# Build the mask.
|
||||||
@ -172,7 +172,7 @@ class ColumnParallelLinear(torch.nn.Module):
|
|||||||
skip_bias_add=False,
|
skip_bias_add=False,
|
||||||
params_dtype=None,
|
params_dtype=None,
|
||||||
use_cpu_initialization=False,
|
use_cpu_initialization=False,
|
||||||
perform_initialization=True,
|
perform_initialization=False,
|
||||||
quant_config=None,
|
quant_config=None,
|
||||||
):
|
):
|
||||||
super(ColumnParallelLinear, self).__init__()
|
super(ColumnParallelLinear, self).__init__()
|
||||||
@ -288,7 +288,7 @@ class RowParallelLinear(torch.nn.Module):
|
|||||||
skip_bias_add=False,
|
skip_bias_add=False,
|
||||||
params_dtype=None,
|
params_dtype=None,
|
||||||
use_cpu_initialization=False,
|
use_cpu_initialization=False,
|
||||||
perform_initialization=True,
|
perform_initialization=False,
|
||||||
reduce_results=True,
|
reduce_results=True,
|
||||||
quant_config=None,
|
quant_config=None,
|
||||||
):
|
):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user