diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index 2147b16b8a49..caa0e3194f3d 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -6,6 +6,10 @@ import torch.nn as nn
 
 from vllm import activation_ops
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.parallel_utils.parallel_state import (
+    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from vllm.model_executor.parallel_utils.utils import divide
+from vllm.model_executor.utils import set_weight_attrs
 
 
 class SiluAndMul(nn.Module):
@@ -51,17 +55,38 @@ class ScaledActivation(nn.Module):
     def __init__(
         self,
         act_module: nn.Module,
-        hidden_size: int,
-        params_dtype: torch.dtype,
+        intermediate_size: int,
+        input_is_parallel: bool = True,
+        params_dtype: Optional[torch.dtype] = None,
     ):
         super().__init__()
         self.act = act_module
+        if input_is_parallel:
+            tp_size = get_tensor_model_parallel_world_size()
+            intermediate_size_per_partition = divide(intermediate_size,
+                                                     tp_size)
+        else:
+            intermediate_size_per_partition = intermediate_size
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
         self.scales = nn.Parameter(
-            torch.empty(hidden_size, dtype=params_dtype, device="cuda"))
+            torch.empty(intermediate_size_per_partition,
+                        dtype=params_dtype,
+                        device="cuda"))
+        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
 
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.act(x) / self.scales
 
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+        tp_rank = get_tensor_model_parallel_rank()
+        param_data = param.data
+        shard_size = param_data.shape[0]
+        start_idx = tp_rank * shard_size
+        loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
 
 _ACTIVATION_REGISTRY = {
     "gelu": nn.GELU(),
@@ -76,6 +101,8 @@ def get_act_fn(
     act_fn_name: str,
     quant_config: Optional[QuantizationConfig] = None,
     intermediate_size: Optional[int] = None,
+    input_is_parallel: bool = True,
+    params_dtype: Optional[torch.dtype] = None,
 ) -> nn.Module:
     """Get an activation function by name."""
     act_fn_name = act_fn_name.lower()
@@ -84,14 +111,11 @@ def get_act_fn(
             f"Activation function {act_fn_name!r} is not supported.")
 
     act_fn = _ACTIVATION_REGISTRY[act_fn_name]
-    if quant_config is not None and act_fn_name in quant_config.get_scaled_act_names(
-    ):
+    if (quant_config is not None
+            and act_fn_name in quant_config.get_scaled_act_names()):
         if intermediate_size is None:
             raise ValueError("intermediate_size must be specified for scaled "
                              "activation functions.")
-        return ScaledActivation(
-            act_fn,
-            intermediate_size,
-            params_dtype=torch.get_default_dtype(),
-        )
+        return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
+                                params_dtype)
     return act_fn
diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py
index 47b991e58602..4c8ff596b473 100644
--- a/vllm/model_executor/models/opt.py
+++ b/vllm/model_executor/models/opt.py
@@ -129,9 +129,6 @@ class OPTDecoderLayer(nn.Module):
             linear_method=linear_method,
         )
         self.do_layer_norm_before = config.do_layer_norm_before
-        quant_config = getattr(linear_method, "quant_config", None)
-        self.activation_fn = get_act_fn(config.activation_function,
-                                        quant_config, config.ffn_dim)
 
         self.self_attn_layer_norm = nn.LayerNorm(
             self.embed_dim,
@@ -142,6 +139,9 @@ class OPTDecoderLayer(nn.Module):
             bias=config.enable_bias,
             linear_method=linear_method,
         )
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.activation_fn = get_act_fn(config.activation_function,
+                                        quant_config, config.ffn_dim)
         self.fc2 = RowParallelLinear(
             config.ffn_dim,
             self.embed_dim,
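For reference, a minimal standalone sketch of the sharding behaviour the new `weight_loader` gives `ScaledActivation.scales` under tensor parallelism: each rank keeps a contiguous `intermediate_size // tp_size` slice of the checkpoint's full scales tensor, selected with `narrow`. The sizes, the rank loop, and `full_scales` below are illustrative assumptions (plain PyTorch, no distributed group), not code from the patch:

```python
import torch

# Toy illustration (hypothetical sizes): how each tensor-parallel rank picks
# its contiguous shard of the checkpoint's full `scales` tensor, mirroring
# the narrow-based logic in ScaledActivation.weight_loader above.
intermediate_size = 8
tp_size = 2
full_scales = torch.arange(intermediate_size, dtype=torch.float32)

for tp_rank in range(tp_size):
    shard_size = intermediate_size // tp_size  # divide(intermediate_size, tp_size)
    start_idx = tp_rank * shard_size
    shard = full_scales.narrow(0, start_idx, shard_size)
    print(f"rank {tp_rank}: {shard.tolist()}")
# rank 0: [0.0, 1.0, 2.0, 3.0]
# rank 1: [4.0, 5.0, 6.0, 7.0]
```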