diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index 109451d4c8f98..2e774a1c7320f 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -1,8 +1,11 @@
 """Custom activation functions."""
+from typing import Optional
+
 import torch
 import torch.nn as nn
 
 from vllm import activation_ops
+from vllm.model_executor.layers.quantization import QuantizationConfig
 
 
 class SiluAndMul(nn.Module):
@@ -39,6 +42,27 @@ class FastGELU(nn.Module):
         return out
 
 
+class ScaledActivation(nn.Module):
+    """An activation function with post-scale parameters.
+
+    This is used for some quantization methods like AWQ.
+    """
+
+    def __init__(
+        self,
+        act_module: nn.Module,
+        hidden_size: int,
+        params_dtype: torch.dtype,
+    ):
+        super().__init__()
+        self.act = act_module
+        self.scales = nn.Parameter(
+            torch.empty(hidden_size, dtype=params_dtype, device="cuda"))
+
+    def forward(self, x: torch.Tensor):
+        return self.act(x) / self.scales
+
+
 _ACTIVATION_REGISTRY = {
     "gelu": nn.GELU(),
     "gelu_fast": FastGELU(),
@@ -48,9 +72,27 @@ _ACTIVATION_REGISTRY = {
 }
 
 
-def get_act_fn(act_fn: str) -> nn.Module:
+def get_act_fn(
+    act_fn_name: str,
+    quant_config: Optional[QuantizationConfig] = None,
+    intermediate_size: Optional[int] = None,
+) -> nn.Module:
     """Get an activation function by name."""
-    act_fn = act_fn.lower()
-    if act_fn in _ACTIVATION_REGISTRY:
-        return _ACTIVATION_REGISTRY[act_fn]
-    raise ValueError(f"Activation function {act_fn!r} is not supported.")
+    act_fn_name = act_fn_name.lower()
+    if act_fn_name not in _ACTIVATION_REGISTRY:
+        raise ValueError(
+            f"Activation function {act_fn_name!r} is not supported.")
+
+    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
+    if quant_config is not None:
+        if act_fn_name in quant_config.get_scaled_act_names():
+            if intermediate_size is None:
+                raise ValueError(
+                    "intermediate_size must be specified for scaled "
+                    "activation functions.")
+            return ScaledActivation(
+                act_fn,
+                intermediate_size,
+                params_dtype=torch.get_default_dtype(),
+            )
+    return act_fn
diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py
index 2a077b439e49d..44e572bdc12f5 100644
--- a/vllm/model_executor/layers/quantization/awq.py
+++ b/vllm/model_executor/layers/quantization/awq.py
@@ -63,6 +63,9 @@ class AWQConfig(QuantizationConfig):
     def get_linear_method(self) -> "AWQLinearMethod":
         return AWQLinearMethod(self)
 
+    def get_scaled_act_names(self) -> List[str]:
+        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
+
 
 class AWQLinearMethod(LinearMethodBase):
     """Linear method for AWQ.
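A minimal sketch (not part of this patch) of what `ScaledActivation` computes. `ScaledActivationSketch` is a hypothetical stand-in that keeps the scales on the CPU and initializes them to ones so the example runs without a GPU; in the patched layer the scales are an empty CUDA tensor that the checkpoint loader fills with per-channel values from the AWQ checkpoint.

```python
import torch
import torch.nn as nn


class ScaledActivationSketch(nn.Module):
    """CPU stand-in for the ScaledActivation layer added above."""

    def __init__(self, act_module: nn.Module, hidden_size: int,
                 params_dtype: torch.dtype):
        super().__init__()
        self.act = act_module
        # Patched layer: torch.empty(..., device="cuda"), populated by the
        # weight loader; ones here so the sketch is runnable as-is.
        self.scales = nn.Parameter(
            torch.ones(hidden_size, dtype=params_dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Apply the wrapped activation, then divide elementwise by the
        # per-channel scales (broadcast over the batch dimensions).
        return self.act(x) / self.scales


act = ScaledActivationSketch(nn.GELU(), hidden_size=8,
                             params_dtype=torch.float32)
out = act(torch.randn(2, 8))
print(out.shape)  # torch.Size([2, 8])
```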
+ """ + raise NotImplementedError diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index a85dd91be7dbd..61ec8b79b6ddc 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -52,6 +52,9 @@ class SqueezeLLMConfig(QuantizationConfig): def get_linear_method(self) -> "SqueezeLLMLinearMethod": return SqueezeLLMLinearMethod(self) + def get_scaled_act_names(self) -> List[str]: + return [] + class SqueezeLLMLinearMethod(LinearMethodBase): """Linear method for SqueezeLLM. diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 1d379a623c76d..6a5f8c516f317 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -145,7 +145,8 @@ class BloomMLP(nn.Module): 4 * hidden_size, linear_method=linear_method, ) - self.act = get_act_fn("gelu") + quant_config = getattr(linear_method, "quant_config", None) + self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size) self.dense_4h_to_h = RowParallelLinear( 4 * hidden_size, hidden_size, @@ -154,7 +155,7 @@ class BloomMLP(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: x, _ = self.dense_h_to_4h(x) - x = self.act(x) + x = self.gelu_impl(x) x, _ = self.dense_4h_to_h(x) return x diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 3307d05494429..f1b5d1da3601a 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -27,6 +27,7 @@ from torch.nn import LayerNorm from transformers import FalconConfig as HF_FalconConfig from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import (PagedAttention, PagedAttentionWithALiBi, PagedAttentionWithRoPE) @@ -131,6 +132,7 @@ class FalconAttention(nn.Module): self.hidden_size, bias=config.bias, skip_bias_add=True, + linear_method=linear_method, reduce_results=self.reduce_row_parallel_results) self.use_rotary = config.rotary @@ -206,7 +208,8 @@ class FalconMLP(nn.Module): bias=config.bias, skip_bias_add=True, linear_method=linear_method) - self.act = nn.GELU() + quant_config = getattr(linear_method, "quant_config", None) + self.act = get_act_fn("gelu", quant_config, 4 * hidden_size) self.reduce_row_parallel_results = not (config.new_decoder_architecture or config.parallel_attn) self.dense_4h_to_h = RowParallelLinear( diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index d540f74724202..1de3d85e233ff 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -118,7 +118,9 @@ class GPT2MLP(nn.Module): bias=True, linear_method=linear_method, ) - self.act = get_act_fn(config.activation_function) + quant_config = getattr(linear_method, "quant_config", None) + self.act = get_act_fn(config.activation_function, quant_config, + intermediate_size) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.c_fc(hidden_states) diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 1e489e97052a7..c2f9611c0fef2 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -137,7 +137,9 @@ class GPTBigMLP(nn.Module): bias=True, linear_method=linear_method, ) - self.act = get_act_fn(config.activation_function) + quant_config = 
diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index 1e489e97052a7..c2f9611c0fef2 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -137,7 +137,9 @@ class GPTBigMLP(nn.Module):
             bias=True,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.activation_function)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              intermediate_size)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)
diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py
index a5b77138bd17f..a5bb6f0fbefc5 100644
--- a/vllm/model_executor/models/gpt_j.py
+++ b/vllm/model_executor/models/gpt_j.py
@@ -128,7 +128,9 @@ class GPTJMLP(nn.Module):
             hidden_size,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.activation_function)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              intermediate_size)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc_in(hidden_states)
diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py
index 5c40783262ce7..97ac5ca243557 100644
--- a/vllm/model_executor/models/gpt_neox.py
+++ b/vllm/model_executor/models/gpt_neox.py
@@ -124,7 +124,9 @@ class GPTNeoXMLP(nn.Module):
             config.hidden_size,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.hidden_act)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.hidden_act, quant_config,
+                              config.intermediate_size)
 
     def forward(self, hidden_states):
         hidden_states, _ = self.dense_h_to_4h(hidden_states)
diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py
index a0f74ced1d156..c9cf16475ca21 100644
--- a/vllm/model_executor/models/mpt.py
+++ b/vllm/model_executor/models/mpt.py
@@ -130,7 +130,8 @@ class MPTMLP(nn.Module):
             bias=not config.no_bias,
             linear_method=linear_method,
         )
-        self.act = get_act_fn("gelu")
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn("gelu", quant_config, intermediate_size)
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py
index 2dde92577bff6..2d1df29a59cf1 100644
--- a/vllm/model_executor/models/opt.py
+++ b/vllm/model_executor/models/opt.py
@@ -129,7 +129,9 @@ class OPTDecoderLayer(nn.Module):
             linear_method=linear_method,
         )
         self.do_layer_norm_before = config.do_layer_norm_before
-        self.activation_fn = get_act_fn(config.activation_function)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.activation_fn = get_act_fn(config.activation_function,
+                                        quant_config, config.ffn_dim)
 
         self.self_attn_layer_norm = nn.LayerNorm(
             self.embed_dim,
@@ -251,7 +253,7 @@
         inputs_embeds = self.embed_tokens(input_ids)
         pos_embeds = self.embed_positions(positions)
         if self.project_in is not None:
-            inputs_embeds = self.project_in(inputs_embeds)
+            inputs_embeds, _ = self.project_in(inputs_embeds)
         hidden_states = inputs_embeds + pos_embeds
 
         for i in range(len(self.layers)):
@@ -266,7 +268,7 @@
         if self.final_layer_norm is not None:
             hidden_states = self.final_layer_norm(hidden_states)
         if self.project_out is not None:
-            hidden_states = self.project_out(hidden_states)
+            hidden_states, _ = self.project_out(hidden_states)
         return hidden_states
 
 
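The two `project_in`/`project_out` changes in opt.py above are bug fixes riding along with this patch: vLLM's parallel linear layers return an `(output, bias)` tuple rather than a bare tensor, so the old calls bound a tuple and broke the arithmetic on the following line. A toy stand-in (hypothetical `ToyParallelLinear`, not vLLM's actual class) illustrating the convention:

```python
import torch
import torch.nn as nn


class ToyParallelLinear(nn.Module):
    """Stand-in mimicking the (output, bias) return convention."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))

    def forward(self, x: torch.Tensor):
        # The bias is returned separately (None here) instead of being
        # added, in the style of a skip_bias_add interface.
        return x @ self.weight.t(), None


proj = ToyParallelLinear(4, 4)
x = torch.randn(1, 4)
out = proj(x)      # old call style: `out` is a tuple, not a tensor
out, _ = proj(x)   # fixed call style: unpack (output, bias)
assert isinstance(out, torch.Tensor)
```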
diff --git a/vllm/model_executor/models/phi_1_5.py b/vllm/model_executor/models/phi_1_5.py
index 2ae88519a6cf7..fbf7aa0a1491e 100644
--- a/vllm/model_executor/models/phi_1_5.py
+++ b/vllm/model_executor/models/phi_1_5.py
@@ -168,7 +168,9 @@ class PhiMLP(nn.Module):
             config.hidden_size,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.activation_function)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              n_inner)
 
     def forward(self, hidden_states):
         hidden_states, _ = self.fc1(hidden_states)
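Putting the pieces together, a usage sketch (not part of this patch) of the new `get_act_fn` signature. The `AWQConfig` constructor arguments here are illustrative, and the quantized path needs a CUDA device because `ScaledActivation` allocates its scales on the GPU:

```python
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization.awq import AWQConfig

# No quantization: plain registry lookup, same behavior as before.
act = get_act_fn("gelu_fast")

# AWQ: "gelu" appears in get_scaled_act_names(), so intermediate_size
# is required and a ScaledActivation wrapper is returned.
quant_config = AWQConfig(weight_bits=4, group_size=128, zero_point=True)
act = get_act_fn("gelu", quant_config, intermediate_size=11008)
print(type(act).__name__)  # ScaledActivation
```

Without a `quant_config`, or for activations not listed by `get_scaled_act_names()`, the call degrades to the old registry lookup, which is why every model call site can pass the config unconditionally.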