Add activation registry (#126)

Woosuk Kwon authored 2023-05-25 00:09:07 -07:00 (committed by GitHub)
parent 057daef778
commit 4a151dd453
5 changed files with 22 additions and 13 deletions


@@ -61,7 +61,7 @@ class LLM:
         while self.llm_server.has_unfinished_requests():
             step_outputs = self.llm_server.step()
             for output in step_outputs:
-                if output.done:
+                if output.finished():
                     outputs.append(output)
                     if use_tqdm:
                         pbar.update(1)
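The hunk above swaps the completion check on the server's step outputs from a done attribute to a finished() method. A minimal sketch of that polling pattern, using a hypothetical RequestOutput stand-in rather than the project's real output class:

import dataclasses
from typing import List, Optional


@dataclasses.dataclass
class RequestOutput:
    # Hypothetical stand-in for the objects returned by llm_server.step().
    request_id: int
    text: str
    finish_reason: Optional[str] = None

    def finished(self) -> bool:
        # A request counts as finished once a finish reason has been recorded.
        return self.finish_reason is not None


def collect_finished(step_outputs: List[RequestOutput]) -> List[RequestOutput]:
    # Mirrors the loop above: keep only the outputs whose requests completed.
    return [out for out in step_outputs if out.finished()]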


@@ -4,6 +4,21 @@ import torch.nn as nn
 
 from cacheflow import activation_ops
 
+_ACTIVATION_REGISTRY = {
+    "gelu": nn.GELU(),
+    "gelu_new": nn.GELU(approximate="tanh"),  # NOTE: This may introduce small rounding errors.
+    "gelu_fast": nn.GELU(approximate="tanh"),  # NOTE: This may introduce small rounding errors.
+    "relu": nn.ReLU(),
+}
+
+
+def get_act_fn(act_fn: str) -> nn.Module:
+    """Get an activation function by name."""
+    act_fn = act_fn.lower()
+    if act_fn in _ACTIVATION_REGISTRY:
+        return _ACTIVATION_REGISTRY[act_fn]
+    raise ValueError(f"Activation function {act_fn!r} is not supported.")
+
 
 class SiluAndMul(nn.Module):
     """An activation function for SwiGLU.


@@ -27,6 +27,7 @@ from torch import nn
 from transformers import GPT2Config
 
 from cacheflow.model_executor.input_metadata import InputMetadata
+from cacheflow.model_executor.layers.activation import get_act_fn
 from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention
 from cacheflow.model_executor.layers.sampler import Sampler
 from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
@@ -92,12 +93,7 @@ class GPT2MLP(nn.Module):
         self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
                                         bias=True, input_is_parallel=True,
                                         perform_initialization=False)
-        act_fn = config.activation_function
-        if act_fn != "gelu_new":
-            raise ValueError(f"Unsupported activation: {act_fn}. "
-                             "GPT-2 only supports gelu_new for now.")
-        self.act = torch.nn.GELU(approximate="tanh")
+        self.act = get_act_fn(config.activation_function)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)
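The registry maps gelu_new (GPT-2's activation) and gelu_fast to the tanh-approximated GELU, hence the rounding-error NOTE above. A quick, illustrative way to gauge the discrepancy against the exact erf-based GELU:

import torch
import torch.nn as nn

exact = nn.GELU()                      # erf-based GELU
approx = nn.GELU(approximate="tanh")   # what gelu_new / gelu_fast resolve to

x = torch.randn(4096, dtype=torch.float32)
max_diff = (exact(x) - approx(x)).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")  # typically on the order of 1e-4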


@@ -26,6 +26,7 @@ from torch import nn
 from transformers import GPTNeoXConfig
 
 from cacheflow.model_executor.input_metadata import InputMetadata
+from cacheflow.model_executor.layers.activation import get_act_fn
 from cacheflow.model_executor.layers.attention import GPTNeoXCacheFlowAttention
 from cacheflow.model_executor.layers.sampler import Sampler
 from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
@@ -94,10 +95,7 @@ class GPTNeoXMLP(nn.Module):
         self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, config.hidden_size,
                                                input_is_parallel=True,
                                                perform_initialization=False)
-        if config.hidden_act != 'gelu':
-            raise ValueError(f'Unsupported activation: {config.hidden_act}. '
-                             'Only gelu is supported for now.')
-        self.act = torch.nn.GELU()
+        self.act = get_act_fn(config.hidden_act)
 
     def forward(self, hidden_states):
         hidden_states, _ = self.dense_h_to_4h(hidden_states)
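With GPT-NeoX now going through the same lookup, supporting an additional activation becomes a one-line registry change rather than a per-model edit. A hypothetical extension (the silu entry below is not part of this commit):

import torch.nn as nn

_ACTIVATION_REGISTRY = {
    "gelu": nn.GELU(),
    "gelu_new": nn.GELU(approximate="tanh"),
    "gelu_fast": nn.GELU(approximate="tanh"),
    "relu": nn.ReLU(),
    "silu": nn.SiLU(),  # hypothetical new entry for configs that request silu
}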


@@ -26,6 +26,7 @@ from torch import nn
 from transformers import OPTConfig
 
 from cacheflow.model_executor.input_metadata import InputMetadata
+from cacheflow.model_executor.layers.activation import get_act_fn
 from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention
 from cacheflow.model_executor.layers.sampler import Sampler
 from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
@@ -105,8 +106,7 @@ class OPTDecoderLayer(nn.Module):
             bias=config.enable_bias,
         )
         self.do_layer_norm_before = config.do_layer_norm_before
-        assert config.activation_function == 'relu'
-        self.activation_fn = nn.ReLU()
+        self.activation_fn = get_act_fn(config.activation_function)
 
         self.self_attn_layer_norm = nn.LayerNorm(
             self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
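All three model files now resolve their activation from the HF config string instead of hard-coding one module per architecture. A stripped-down sketch of the shared pattern, with plain nn.Linear standing in for the tensor-parallel layers purely for brevity:

import torch
import torch.nn as nn

from cacheflow.model_executor.layers.activation import get_act_fn


class SimpleMLP(nn.Module):
    """Minimal stand-in for GPT2MLP / GPTNeoXMLP / the OPT feed-forward block."""

    def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str):
        super().__init__()
        # nn.Linear replaces the tensor-parallel linear layers for this sketch.
        self.up = nn.Linear(hidden_size, intermediate_size)
        self.down = nn.Linear(intermediate_size, hidden_size)
        # One registry lookup replaces the per-model if/assert blocks removed above.
        self.act = get_act_fn(hidden_act)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(self.act(self.up(x)))


mlp = SimpleMLP(hidden_size=16, intermediate_size=64, hidden_act="relu")
out = mlp(torch.randn(2, 16))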