Mirror of https://git.datalinker.icu/vllm-project/vllm.git, last synced 2025-12-13 23:55:44 +08:00.
[FIX] Fix the case when input_is_parallel=False for ScaledActivation (#1737)
This commit is contained in:
parent
cf35d8f3d7
commit
7d761fe3c1
@@ -61,6 +61,7 @@ class ScaledActivation(nn.Module):
     ):
         super().__init__()
         self.act = act_module
+        self.input_is_parallel = input_is_parallel
         if input_is_parallel:
             tp_size = get_tensor_model_parallel_world_size()
             intermediate_size_per_partition = divide(intermediate_size,
@@ -79,11 +80,12 @@ class ScaledActivation(nn.Module):
         return self.act(x) / self.scales

     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
         param_data = param.data
-        shard_size = param_data.shape[0]
-        start_idx = tp_rank * shard_size
-        loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
+        if self.input_is_parallel:
+            tp_rank = get_tensor_model_parallel_rank()
+            shard_size = param_data.shape[0]
+            start_idx = tp_rank * shard_size
+            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)

|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user