[Model] Changes to MLPSpeculator to support tie_weights and input_scale (#5965)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Joshua Rosenkranz <jmrosenk@us.ibm.com>
Author: Thomas Parnell, 2024-07-02 01:40:02 +02:00 (committed by GitHub)
Commit: 54600709b6 (parent: e373853e12)
2 changed files with 79 additions and 23 deletions

vllm/model_executor/models/mlp_speculator.py

@@ -13,6 +13,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import MLPSpeculatorConfig
 
+SQRT2 = 2**0.5
+
 
 class MLPSpeculatorLayerNorm(nn.Module):
     """
@@ -26,24 +28,30 @@ class MLPSpeculatorLayerNorm(nn.Module):
         Safety term to prevent division by zero. Make sure the chosen value
         fits in the range of your encoding scheme
         (i.e. fp16 requires eps >= 6e-8).
+    elementwise_scale_and_shift : bool
+        Include a learned scaling and shift term after normalization.
     """
 
     def __init__(
         self,
         normalized_shape,
         eps=1e-06,
+        elementwise_scale_and_shift=True,
     ):
         super(MLPSpeculatorLayerNorm, self).__init__()
-        self.weight = nn.Parameter(torch.empty(normalized_shape))
-        self.bias = nn.Parameter(torch.empty(normalized_shape))
+        self.elementwise_scale_and_shift = elementwise_scale_and_shift
+        if self.elementwise_scale_and_shift:
+            self.weight = nn.Parameter(torch.empty(normalized_shape))
+            self.bias = nn.Parameter(torch.empty(normalized_shape))
         self.eps = eps
 
     def forward(self, x):
         xf = x
         xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
         x = xf.type_as(x)
-        x = self.weight * x
-        x = x + self.bias
+        if self.elementwise_scale_and_shift:
+            x = self.weight * x
+            x = x + self.bias
         return x
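For reference, a minimal sketch of what the forward pass above computes when elementwise_scale_and_shift=False (the mode used for the new ln0 further down): a plain RMS normalization with no learned weight or bias. The tensor shape and values here are arbitrary illustrative choices, not taken from the model.

import torch

eps = 1e-06
x = torch.randn(2, 1, 4096)  # (batch, 1, emb_dim), illustrative only

# same arithmetic as MLPSpeculatorLayerNorm.forward with no scale/shift branch
xf = x
xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
out = xf.type_as(x)

print(out.pow(2).mean(-1).sqrt())  # per-token RMS is ~1.0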
@@ -59,27 +67,60 @@ class MLPSpeculator(nn.Module):
         self.max_speculative_tokens = config.num_lookahead_tokens
 
-        self.emb = nn.ModuleList([
-            VocabParallelEmbedding(config.vocab_size,
-                                   self.inner_dim,
-                                   org_num_embeddings=config.vocab_size)
-            for _ in range(self.max_speculative_tokens)
-        ])
+        self.tie_weights = config.tie_weights
+        self.scale_input = config.scale_input
 
-        self.proj = nn.ModuleList([
-            nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
-                      self.inner_dim,
-                      bias=False) for i in range(self.max_speculative_tokens)
-        ])
+        if self.tie_weights:
+            assert (
+                self.n_predict >
+                1), "You cannot tie weights between stages when only 1 exists"
+            embedding = VocabParallelEmbedding(
+                config.vocab_size,
+                self.inner_dim,
+                org_num_embeddings=config.vocab_size)
+
+            self.emb = nn.ModuleList([embedding] * self.max_speculative_tokens)
+
+            # the initial projection from the base model may
+            # have a different size, so that stays separate.
+            proj_first = nn.Linear(self.emb_dim, self.inner_dim, bias=False)
+            proj_tied = nn.Linear(self.inner_dim, self.inner_dim, bias=False)
+            self.proj = nn.ModuleList([proj_first] + [proj_tied] *
+                                      (self.max_speculative_tokens - 1))
 
-        self.head = nn.ModuleList([
-            nn.Linear(self.inner_dim, self.vocab_size, bias=False)
-            for _ in range(self.max_speculative_tokens)
-        ])
-        self.ln = nn.ModuleList([
-            MLPSpeculatorLayerNorm(self.inner_dim)
-            for _ in range(self.max_speculative_tokens)
-        ])
+            head = nn.Linear(self.inner_dim, self.vocab_size, bias=False)
+            self.head = nn.ModuleList([head] * self.max_speculative_tokens)
+
+            ln = MLPSpeculatorLayerNorm(self.inner_dim,
+                                        elementwise_scale_and_shift=True)
+            self.ln = nn.ModuleList([ln] * self.max_speculative_tokens)
+
+        else:
+            self.emb = nn.ModuleList([
+                VocabParallelEmbedding(config.vocab_size,
+                                       self.inner_dim,
+                                       org_num_embeddings=config.vocab_size)
+                for _ in range(self.max_speculative_tokens)
+            ])
+
+            self.proj = nn.ModuleList([
+                nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
+                          self.inner_dim,
+                          bias=False)
+                for i in range(self.max_speculative_tokens)
+            ])
+
+            self.head = nn.ModuleList([
+                nn.Linear(self.inner_dim, self.vocab_size, bias=False)
+                for _ in range(self.max_speculative_tokens)
+            ])
+            self.ln = nn.ModuleList([
+                MLPSpeculatorLayerNorm(self.inner_dim,
+                                       elementwise_scale_and_shift=True)
+                for _ in range(self.max_speculative_tokens)
+            ])
+        if self.scale_input:
+            self.ln0 = MLPSpeculatorLayerNorm(
+                self.emb_dim, elementwise_scale_and_shift=False)
 
         self.state_weight = 0.5**(0.5 / config.n_predict)
         self.emb_weight = math.sqrt(
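The tying above relies on a standard PyTorch property: repeating the same module object inside an nn.ModuleList stores references to one module, so every stage shares a single set of parameters instead of getting a copy. A small self-contained sketch of that mechanism (the dimensions are made up for illustration):

import torch.nn as nn

inner_dim, vocab_size, n_stages = 8, 32, 3

head = nn.Linear(inner_dim, vocab_size, bias=False)
heads = nn.ModuleList([head] * n_stages)  # three references to one module

assert heads[0] is heads[2]  # same object, not a copy
assert heads[0].weight.data_ptr() == heads[1].weight.data_ptr()

# shared parameters are counted once, so the speculator shrinks accordingly
print(sum(p.numel() for p in heads.parameters()))  # 8 * 32 = 256, not 768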
@@ -105,6 +146,9 @@ class MLPSpeculator(nn.Module):
         # b x 1 x d
         previous_hidden_states = previous_hidden_states.unsqueeze(1)
 
+        if self.scale_input:
+            previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2
+
         # b x 1
         last_tokens = input_ids.unsqueeze(1)
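The new scale_input path RMS-normalizes the incoming base-model hidden states with the affine-free layer norm and divides by sqrt(2), so the speculator sees inputs with a fixed per-token scale regardless of the base model's activation magnitudes. A rough numerical sketch of that effect (the shape and the exaggerated input scale are arbitrary):

import torch

SQRT2 = 2**0.5
eps = 1e-06

h = 100.0 * torch.randn(2, 1, 4096)  # deliberately large base-model hidden states

# what ln0 (elementwise_scale_and_shift=False) followed by / SQRT2 does here
h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)
h = h / SQRT2

print(h.pow(2).mean(-1).sqrt())  # per-token RMS is ~0.707 for every position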

vllm/transformers_utils/configs/mlp_speculator.py

@@ -17,6 +17,8 @@ class MLPSpeculatorConfig(PretrainedConfig):
                  n_predict: int = 3,
                  top_k_tokens_per_head: Optional[List[int]] = None,
                  n_candidates: int = 5,
+                 tie_weights: bool = False,
+                 scale_input: bool = False,
                  **kwargs):
         """
         Initialize an MLPSpeculatorConfig
@@ -38,6 +40,14 @@ class MLPSpeculatorConfig(PretrainedConfig):
                 NOTE: This parameter is currently unused.
             n_candidates: int
                 number of child candidates to create per sequence
+            tie_weights: bool
+                If true, use a single set of weights for every model
+                head/stage after the first. The initial projection
+                from the base model may have a different size, so that
+                stays separate.
+            scale_input: bool
+                If true, will scale the initial hidden states from
+                the base model.
         """
         if top_k_tokens_per_head is None:
             top_k_tokens_per_head = [5, 4, 3]
@@ -49,5 +59,7 @@ class MLPSpeculatorConfig(PretrainedConfig):
         self.top_k_tokens_per_head = top_k_tokens_per_head
         self.n_candidates = n_candidates
         self.num_lookahead_tokens = n_predict
+        self.tie_weights = tie_weights
+        self.scale_input = scale_input
 
         super().__init__(**kwargs)
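A hedged usage sketch of the extended config: only the two new keyword arguments come from this change, and the example assumes the earlier config arguments keep their defaults; n_predict and n_candidates are the values visible in the signature above.

from vllm.transformers_utils.configs import MLPSpeculatorConfig

cfg = MLPSpeculatorConfig(
    n_predict=4,        # number of speculator stages / lookahead tokens
    n_candidates=5,
    tie_weights=True,   # share weights across stages after the first
    scale_input=True,   # normalize and scale the base-model hidden states
)

print(cfg.tie_weights, cfg.scale_input, cfg.num_lookahead_tokens)  # True True 4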