[Core][Model] Use torch.compile to accelerate layernorm in commandr (#3985)
Author: youkaichao, 2024-04-10 18:48:26 -07:00 (committed by GitHub)
parent 67b4221a61
commit caada5e50a

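Reviewer's note on why compiling helps (not part of the commit): the eager-mode layernorm launches a separate CUDA kernel for each pointwise op (mean, subtract, square, rsqrt, multiply), while torch.compile fuses them into a single generated kernel, cutting memory traffic. Below is a minimal sketch of how one might measure the difference; the shapes, device, helper names (layer_norm_eager, bench), and timing harness are assumptions, not from the commit:

import time

import torch

def layer_norm_eager(hidden_states, weight, variance_epsilon):
    # Same math as the commit's layer_norm_func, run op by op in eager mode.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    mean = hidden_states.mean(-1, keepdim=True)
    variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
    hidden_states = (hidden_states - mean) * torch.rsqrt(variance +
                                                         variance_epsilon)
    return (weight.to(torch.float32) * hidden_states).to(input_dtype)

layer_norm_compiled = torch.compile(layer_norm_eager)

def bench(fn, x, w, eps, iters=100):
    fn(x, w, eps)  # warmup; the first call triggers compilation
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(x, w, eps)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

# Hypothetical shapes, roughly one batch of hidden states.
x = torch.randn(256, 8192, dtype=torch.float16, device="cuda")
w = torch.ones(8192, dtype=torch.float16, device="cuda")
print(f"eager:    {bench(layer_norm_eager, x, w, 1e-5) * 1e6:.1f} us")
print(f"compiled: {bench(layer_norm_compiled, x, w, 1e-5) * 1e6:.1f} us")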

@@ -48,6 +48,18 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
 from vllm.sequence import SamplerOutput
 
 
+@torch.compile
+def layer_norm_func(hidden_states, weight, variance_epsilon):
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.to(torch.float32)
+    mean = hidden_states.mean(-1, keepdim=True)
+    variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
+    hidden_states = (hidden_states - mean) * torch.rsqrt(variance +
+                                                         variance_epsilon)
+    hidden_states = weight.to(torch.float32) * hidden_states
+    return hidden_states.to(input_dtype)
+
+
 class LayerNorm(nn.Module):
 
     def __init__(self, param_shape=None, eps=1e-5):
@@ -57,14 +69,9 @@ class LayerNorm(nn.Module):
         set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
 
     def forward(self, hidden_states, residuals=None):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        mean = hidden_states.mean(-1, keepdim=True)
-        variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
-        hidden_states = (hidden_states -
-                         mean) * torch.rsqrt(variance + self.variance_epsilon)
-        hidden_states = self.weight.to(torch.float32) * hidden_states
-        return hidden_states.to(input_dtype), residuals
+        hidden_states = layer_norm_func(hidden_states, self.weight,
+                                        self.variance_epsilon)
+        return hidden_states, residuals
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         tp_rank = get_tensor_model_parallel_rank()
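A quick sanity check that the refactor is behavior-preserving (again a sketch with assumptions, not part of the commit): since the weight is applied after normalization and there is no bias, layer_norm_func should agree with torch.nn.functional.layer_norm up to reduction-order rounding. The function body is copied verbatim from the diff above so the snippet is self-contained; the test shapes and tolerances are assumptions.

import torch
import torch.nn.functional as F

@torch.compile
def layer_norm_func(hidden_states, weight, variance_epsilon):
    # Copied from the diff above.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    mean = hidden_states.mean(-1, keepdim=True)
    variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
    hidden_states = (hidden_states - mean) * torch.rsqrt(variance +
                                                         variance_epsilon)
    hidden_states = weight.to(torch.float32) * hidden_states
    return hidden_states.to(input_dtype)

x = torch.randn(4, 1024, dtype=torch.float32)
w = torch.randn(1024, dtype=torch.float32)
ref = F.layer_norm(x, (1024,), weight=w, eps=1e-5)
# Slightly loose tolerances: the two implementations may reduce in different orders.
torch.testing.assert_close(layer_norm_func(x, w, 1e-5), ref, rtol=1e-4, atol=1e-5)
print("layer_norm_func matches F.layer_norm")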