From 4a151dd45308f812d125b6ea239b1730b7f647c4 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 25 May 2023 00:09:07 -0700
Subject: [PATCH] Add activation registry (#126)

---
 cacheflow/entrypoints/llm.py                  |  2 +-
 cacheflow/model_executor/layers/activation.py | 15 +++++++++++++++
 cacheflow/model_executor/models/gpt2.py       |  8 ++------
 cacheflow/model_executor/models/gpt_neox.py   |  6 ++----
 cacheflow/model_executor/models/opt.py        |  4 ++--
 5 files changed, 22 insertions(+), 13 deletions(-)

diff --git a/cacheflow/entrypoints/llm.py b/cacheflow/entrypoints/llm.py
index acb9a7473ad98..164bee4409740 100644
--- a/cacheflow/entrypoints/llm.py
+++ b/cacheflow/entrypoints/llm.py
@@ -61,7 +61,7 @@ class LLM:
         while self.llm_server.has_unfinished_requests():
             step_outputs = self.llm_server.step()
             for output in step_outputs:
-                if output.done:
+                if output.finished():
                     outputs.append(output)
                     if use_tqdm:
                         pbar.update(1)
diff --git a/cacheflow/model_executor/layers/activation.py b/cacheflow/model_executor/layers/activation.py
index f82d57769fa86..467609a500a64 100644
--- a/cacheflow/model_executor/layers/activation.py
+++ b/cacheflow/model_executor/layers/activation.py
@@ -4,6 +4,21 @@
 import torch.nn as nn
 
 from cacheflow import activation_ops
 
+_ACTIVATION_REGISTRY = {
+    "gelu": nn.GELU(),
+    "gelu_new": nn.GELU(approximate="tanh"),  # NOTE: This may introduce small rounding errors.
+    "gelu_fast": nn.GELU(approximate="tanh"),  # NOTE: This may introduce small rounding errors.
+    "relu": nn.ReLU(),
+}
+
+
+def get_act_fn(act_fn: str) -> nn.Module:
+    """Get an activation function by name."""
+    act_fn = act_fn.lower()
+    if act_fn in _ACTIVATION_REGISTRY:
+        return _ACTIVATION_REGISTRY[act_fn]
+    raise ValueError(f"Activation function {act_fn!r} is not supported.")
+
 
 class SiluAndMul(nn.Module):
     """An activation function for SwiGLU.
diff --git a/cacheflow/model_executor/models/gpt2.py b/cacheflow/model_executor/models/gpt2.py
index 690bd7803c74c..5802d399bcc6f 100644
--- a/cacheflow/model_executor/models/gpt2.py
+++ b/cacheflow/model_executor/models/gpt2.py
@@ -27,6 +27,7 @@
 from torch import nn
 from transformers import GPT2Config
 
 from cacheflow.model_executor.input_metadata import InputMetadata
+from cacheflow.model_executor.layers.activation import get_act_fn
 from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention
 from cacheflow.model_executor.layers.sampler import Sampler
 from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
@@ -92,12 +93,7 @@ class GPT2MLP(nn.Module):
         self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
                                         bias=True, input_is_parallel=True,
                                         perform_initialization=False)
-
-        act_fn = config.activation_function
-        if act_fn != "gelu_new":
-            raise ValueError(f"Unsupported activation: {act_fn}. "
-                             "GPT-2 only supports gelu_new for now.")
-        self.act = torch.nn.GELU(approximate="tanh")
+        self.act = get_act_fn(config.activation_function)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)
diff --git a/cacheflow/model_executor/models/gpt_neox.py b/cacheflow/model_executor/models/gpt_neox.py
index c98514fd49ed9..1804c2e710d7a 100644
--- a/cacheflow/model_executor/models/gpt_neox.py
+++ b/cacheflow/model_executor/models/gpt_neox.py
@@ -26,6 +26,7 @@
 from torch import nn
 from transformers import GPTNeoXConfig
 
 from cacheflow.model_executor.input_metadata import InputMetadata
+from cacheflow.model_executor.layers.activation import get_act_fn
 from cacheflow.model_executor.layers.attention import GPTNeoXCacheFlowAttention
 from cacheflow.model_executor.layers.sampler import Sampler
 from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
@@ -94,10 +95,7 @@ class GPTNeoXMLP(nn.Module):
         self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
                                                config.hidden_size,
                                                input_is_parallel=True,
                                                perform_initialization=False)
-        if config.hidden_act != 'gelu':
-            raise ValueError(f'Unsupported activation: {config.hidden_act}. '
-                             'Only gelu is supported for now.')
-        self.act = torch.nn.GELU()
+        self.act = get_act_fn(config.hidden_act)
 
     def forward(self, hidden_states):
         hidden_states, _ = self.dense_h_to_4h(hidden_states)
diff --git a/cacheflow/model_executor/models/opt.py b/cacheflow/model_executor/models/opt.py
index e340f68370be9..864f001a1c4a7 100644
--- a/cacheflow/model_executor/models/opt.py
+++ b/cacheflow/model_executor/models/opt.py
@@ -26,6 +26,7 @@
 from torch import nn
 from transformers import OPTConfig
 
 from cacheflow.model_executor.input_metadata import InputMetadata
+from cacheflow.model_executor.layers.activation import get_act_fn
 from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention
 from cacheflow.model_executor.layers.sampler import Sampler
 from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
@@ -105,8 +106,7 @@ class OPTDecoderLayer(nn.Module):
             bias=config.enable_bias,
         )
         self.do_layer_norm_before = config.do_layer_norm_before
-        assert config.activation_function == 'relu'
-        self.activation_fn = nn.ReLU()
+        self.activation_fn = get_act_fn(config.activation_function)
         self.self_attn_layer_norm = nn.LayerNorm(
             self.embed_dim,
             elementwise_affine=config.layer_norm_elementwise_affine)