mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:45:00 +08:00
[Core][Model] torch.compile for layernorm in commandr (#3985)
[Core][Model] Use torch.compile to accelerate layernorm in commandr (#3985)
This commit is contained in:
parent
67b4221a61
commit
caada5e50a
@ -48,6 +48,18 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
|
||||
@torch.compile
|
||||
def layer_norm_func(hidden_states, weight, variance_epsilon):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
mean = hidden_states.mean(-1, keepdim=True)
|
||||
variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = (hidden_states - mean) * torch.rsqrt(variance +
|
||||
variance_epsilon)
|
||||
hidden_states = weight.to(torch.float32) * hidden_states
|
||||
return hidden_states.to(input_dtype)
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
|
||||
def __init__(self, param_shape=None, eps=1e-5):
|
||||
@ -57,14 +69,9 @@ class LayerNorm(nn.Module):
|
||||
set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
|
||||
|
||||
def forward(self, hidden_states, residuals=None):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
mean = hidden_states.mean(-1, keepdim=True)
|
||||
variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = (hidden_states -
|
||||
mean) * torch.rsqrt(variance + self.variance_epsilon)
|
||||
hidden_states = self.weight.to(torch.float32) * hidden_states
|
||||
return hidden_states.to(input_dtype), residuals
|
||||
hidden_states = layer_norm_func(hidden_states, self.weight,
|
||||
self.variance_epsilon)
|
||||
return hidden_states, residuals
|
||||
|
||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user