mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Core][Model] Use torch.compile to accelerate layernorm in commandr (#3985)
parent 67b4221a61
commit caada5e50a
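This commit moves the fp32 layernorm math in the Command-R model into a module-level layer_norm_func decorated with @torch.compile, so PyTorch can fuse the chain of elementwise ops into fewer kernels instead of launching one per op in eager mode. A minimal micro-benchmark sketch of the idea, not part of the commit: it assumes a CUDA build of PyTorch 2.x, and the shapes, dtype, and iteration count are illustrative only.

import time

import torch


def layer_norm_eager(hidden_states, weight, variance_epsilon):
    # Same math as layer_norm_func below, run eagerly: every op is its
    # own kernel launch.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    mean = hidden_states.mean(-1, keepdim=True)
    variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
    hidden_states = (hidden_states - mean) * torch.rsqrt(variance +
                                                         variance_epsilon)
    hidden_states = weight.to(torch.float32) * hidden_states
    return hidden_states.to(input_dtype)


layer_norm_compiled = torch.compile(layer_norm_eager)

x = torch.randn(8, 1024, 8192, dtype=torch.float16, device="cuda")
w = torch.ones(8192, dtype=torch.float16, device="cuda")
layer_norm_compiled(x, w, 1e-5)  # warm-up call triggers compilation

for name, fn in (("eager", layer_norm_eager),
                 ("compiled", layer_norm_compiled)):
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(100):
        fn(x, w, 1e-5)
    torch.cuda.synchronize()
    print(name, time.perf_counter() - start)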
@@ -48,6 +48,18 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
 from vllm.sequence import SamplerOutput
 
 
+@torch.compile
+def layer_norm_func(hidden_states, weight, variance_epsilon):
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.to(torch.float32)
+    mean = hidden_states.mean(-1, keepdim=True)
+    variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
+    hidden_states = (hidden_states - mean) * torch.rsqrt(variance +
+                                                         variance_epsilon)
+    hidden_states = weight.to(torch.float32) * hidden_states
+    return hidden_states.to(input_dtype)
+
+
 class LayerNorm(nn.Module):
 
     def __init__(self, param_shape=None, eps=1e-5):
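For reference, the new helper is a bias-free layer norm over the last dimension, so it should line up numerically with torch.nn.functional.layer_norm. A quick equivalence check, hypothetical and not part of the commit; it assumes a vllm install that includes this change (otherwise paste the function body from the hunk above):

import torch
import torch.nn.functional as F

from vllm.model_executor.models.commandr import layer_norm_func

torch.manual_seed(0)
x = torch.randn(4, 16, 64)
w = torch.randn(64)
eps = 1e-5

# F.layer_norm normalizes over the trailing dims and applies the same
# elementwise weight, with no bias here.
ref = F.layer_norm(x, (64,), weight=w, eps=eps)
out = layer_norm_func(x, w, eps)
assert torch.allclose(out, ref, atol=1e-5)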
@@ -57,14 +69,9 @@ class LayerNorm(nn.Module):
         set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
 
     def forward(self, hidden_states, residuals=None):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        mean = hidden_states.mean(-1, keepdim=True)
-        variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
-        hidden_states = (hidden_states -
-                         mean) * torch.rsqrt(variance + self.variance_epsilon)
-        hidden_states = self.weight.to(torch.float32) * hidden_states
-        return hidden_states.to(input_dtype), residuals
+        hidden_states = layer_norm_func(hidden_states, self.weight,
+                                        self.variance_epsilon)
+        return hidden_states, residuals
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         tp_rank = get_tensor_model_parallel_rank()
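The forward refactor is behavior-preserving: callers still get the (hidden_states, residuals) pair back. Decorating a module-level function instead of the forward method also means torch.compile traces the function once and every LayerNorm instance reuses the compiled code, rather than tracing through self per module. A hypothetical smoke test of the unchanged call signature (again assuming a vllm install with this commit):

import torch

from vllm.model_executor.models.commandr import LayerNorm

ln = LayerNorm(param_shape=(64,), eps=1e-5)
x = torch.randn(2, 8, 64)
out, residuals = ln(x)  # residuals defaults to None and passes through
assert out.shape == x.shape and residuals is None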