mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 21:55:50 +08:00
Add activation registry (#126)
This commit is contained in:
parent
057daef778
commit
4a151dd453
@ -61,7 +61,7 @@ class LLM:
|
||||
while self.llm_server.has_unfinished_requests():
|
||||
step_outputs = self.llm_server.step()
|
||||
for output in step_outputs:
|
||||
if output.done:
|
||||
if output.finished():
|
||||
outputs.append(output)
|
||||
if use_tqdm:
|
||||
pbar.update(1)
|
||||
|
||||
@ -4,6 +4,21 @@ import torch.nn as nn
|
||||
|
||||
from cacheflow import activation_ops
|
||||
|
||||
_ACTIVATION_REGISTRY = {
|
||||
"gelu": nn.GELU(),
|
||||
"gelu_new": nn.GELU(approximate="tanh"), # NOTE: This may introduce small rounding errors.
|
||||
"gelu_fast": nn.GELU(approximate="tanh"), # NOTE: This may introduce small rounding errors.
|
||||
"relu": nn.ReLU(),
|
||||
}
|
||||
|
||||
|
||||
def get_act_fn(act_fn: str) -> nn.Module:
|
||||
"""Get an activation function by name."""
|
||||
act_fn = act_fn.lower()
|
||||
if act_fn in _ACTIVATION_REGISTRY:
|
||||
return _ACTIVATION_REGISTRY[act_fn]
|
||||
raise ValueError(f"Activation function {act_fn!r} is not supported.")
|
||||
|
||||
|
||||
class SiluAndMul(nn.Module):
|
||||
"""An activation function for SwiGLU.
|
||||
|
||||
@ -27,6 +27,7 @@ from torch import nn
|
||||
from transformers import GPT2Config
|
||||
|
||||
from cacheflow.model_executor.input_metadata import InputMetadata
|
||||
from cacheflow.model_executor.layers.activation import get_act_fn
|
||||
from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention
|
||||
from cacheflow.model_executor.layers.sampler import Sampler
|
||||
from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
@ -92,12 +93,7 @@ class GPT2MLP(nn.Module):
|
||||
self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
|
||||
bias=True, input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
|
||||
act_fn = config.activation_function
|
||||
if act_fn != "gelu_new":
|
||||
raise ValueError(f"Unsupported activation: {act_fn}. "
|
||||
"GPT-2 only supports gelu_new for now.")
|
||||
self.act = torch.nn.GELU(approximate="tanh")
|
||||
self.act = get_act_fn(config.activation_function)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, _ = self.c_fc(hidden_states)
|
||||
|
||||
@ -26,6 +26,7 @@ from torch import nn
|
||||
from transformers import GPTNeoXConfig
|
||||
|
||||
from cacheflow.model_executor.input_metadata import InputMetadata
|
||||
from cacheflow.model_executor.layers.activation import get_act_fn
|
||||
from cacheflow.model_executor.layers.attention import GPTNeoXCacheFlowAttention
|
||||
from cacheflow.model_executor.layers.sampler import Sampler
|
||||
from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
@ -94,10 +95,7 @@ class GPTNeoXMLP(nn.Module):
|
||||
self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, config.hidden_size,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
if config.hidden_act != 'gelu':
|
||||
raise ValueError(f'Unsupported activation: {config.hidden_act}. '
|
||||
'Only gelu is supported for now.')
|
||||
self.act = torch.nn.GELU()
|
||||
self.act = get_act_fn(config.hidden_act)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states, _ = self.dense_h_to_4h(hidden_states)
|
||||
|
||||
@ -26,6 +26,7 @@ from torch import nn
|
||||
from transformers import OPTConfig
|
||||
|
||||
from cacheflow.model_executor.input_metadata import InputMetadata
|
||||
from cacheflow.model_executor.layers.activation import get_act_fn
|
||||
from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention
|
||||
from cacheflow.model_executor.layers.sampler import Sampler
|
||||
from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
@ -105,8 +106,7 @@ class OPTDecoderLayer(nn.Module):
|
||||
bias=config.enable_bias,
|
||||
)
|
||||
self.do_layer_norm_before = config.do_layer_norm_before
|
||||
assert config.activation_function == 'relu'
|
||||
self.activation_fn = nn.ReLU()
|
||||
self.activation_fn = get_act_fn(config.activation_function)
|
||||
|
||||
self.self_attn_layer_norm = nn.LayerNorm(
|
||||
self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user