vllm/vllm/model_executor/models/mlp_speculator.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from .utils import maybe_prefix
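
# sqrt(2): used to scale down the normalized base-model hidden state when
# `scale_input` is enabled (see `ln0` and the commented-out proposal loop below).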
SQRT2 = 2**0.5


class MLPSpeculatorLayerNorm(nn.Module):
"""
    An L2 normalization implementation
...
Args
----
normalized_shape : int
Dimensionality of input data (size of final tensor axis)
eps : float
Safety term to prevent division by zero. Make sure the chosen value
fits in the range of your encoding scheme
        (e.g. fp16 requires eps >= 6e-8).
elementwise_scale_and_shift : bool
Include a learned scaling and shift term after normalization.
"""

    def __init__(
self,
normalized_shape,
eps=1e-06,
elementwise_scale_and_shift=True,
):
super().__init__()
self.elementwise_scale_and_shift = elementwise_scale_and_shift
if self.elementwise_scale_and_shift:
self.weight = nn.Parameter(torch.empty(normalized_shape))
self.bias = nn.Parameter(torch.empty(normalized_shape))
self.eps = eps

    def forward(self, x):
xf = x
xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
x = xf.type_as(x)
if self.elementwise_scale_and_shift:
x = self.weight * x
x = x + self.bias
return x
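

# Illustrative usage sketch (not part of the original module; sizes are
# hypothetical): the scale-and-shift-free variant of this layer norm is how the
# base model's hidden state is normalized when `scale_input` is enabled:
#
#     ln0 = MLPSpeculatorLayerNorm(4096, elementwise_scale_and_shift=False)
#     hidden = torch.randn(2, 1, 4096)   # (batch, 1, emb_dim)
#     hidden = ln0(hidden) / SQRT2       # unit-RMS normalize, then shrink by sqrt(2)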


class MLPSpeculator(nn.Module):
"""
An implementation of the speculative models introduced in
"Accelerating Production LLMs with Combined Token/Embedding
Speculators"
https://arxiv.org/pdf/2404.19124
Trained speculators of this type are available on HF hub at:
https://huggingface.co/ibm-ai-platform and https://huggingface.co/ibm-granite
"""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
self.n_predict = config.n_predict
self.vocab_size = config.vocab_size
self.emb_dim = config.emb_dim
self.inner_dim = config.inner_dim if config.inner_dim != 0 else config.emb_dim
self.max_speculative_tokens = config.num_lookahead_tokens
self.tie_weights = config.tie_weights
self.scale_input = config.scale_input
if self.tie_weights:
assert self.n_predict > 1, (
"You cannot tie weights between stages when only 1 exists"
)
embedding = VocabParallelEmbedding(
config.vocab_size, self.inner_dim, org_num_embeddings=config.vocab_size
)
self.emb = nn.ModuleList([embedding] * self.max_speculative_tokens)
# the initial projection from the base model may
# have a different size, so that stays separate.
proj_first = nn.Linear(self.emb_dim, self.inner_dim, bias=False)
proj_tied = nn.Linear(self.inner_dim, self.inner_dim, bias=False)
self.proj = nn.ModuleList(
[proj_first] + [proj_tied] * (self.max_speculative_tokens - 1)
)
self.head = nn.ModuleList(
[
ParallelLMHead(
self.vocab_size,
self.inner_dim,
bias=False,
prefix=maybe_prefix(prefix, f"head.{i}"),
)
for i in range(self.max_speculative_tokens)
]
)
ln = MLPSpeculatorLayerNorm(
self.inner_dim, elementwise_scale_and_shift=True
)
self.ln = nn.ModuleList([ln] * self.max_speculative_tokens)
else:
self.emb = nn.ModuleList(
[
VocabParallelEmbedding(
config.vocab_size,
self.inner_dim,
)
for _ in range(self.max_speculative_tokens)
]
)
self.proj = nn.ModuleList(
[
nn.Linear(
(self.emb_dim if i == 0 else self.inner_dim),
self.inner_dim,
bias=False,
)
for i in range(self.max_speculative_tokens)
]
)
self.head = nn.ModuleList(
[
ParallelLMHead(
self.vocab_size,
self.inner_dim,
bias=False,
prefix=maybe_prefix(prefix, f"head.{i}"),
)
for i in range(self.max_speculative_tokens)
]
)
self.ln = nn.ModuleList(
[
MLPSpeculatorLayerNorm(
self.inner_dim, elementwise_scale_and_shift=True
)
for _ in range(self.max_speculative_tokens)
]
)
if self.scale_input:
self.ln0 = MLPSpeculatorLayerNorm(
self.emb_dim, elementwise_scale_and_shift=False
)
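        # Fixed mixing weights used when adding the token embedding to the
        # projected hidden state (see the weighted add in the commented-out
        # proposal loop below). Interpretation, not taken from the original
        # source: (state_weight ** 2) ** n_predict == 0.5, so the hidden
        # state's squared contribution decays to one half over n_predict
        # stages, and emb_weight scales the embedding to a comparable
        # magnitude before the layer norm renormalizes the sum.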
self.state_weight = 0.5 ** (0.5 / config.n_predict)
self.emb_weight = math.sqrt((1 - self.state_weight**2) * (self.inner_dim / 2))
self.activation = nn.GELU()
self.config = config
self.logits_processor = LogitsProcessor(
config.vocab_size, config.vocab_size, 1.0
)

    # NOTE(woosuk): This method is commented out because it is old code
# using V0. We should either port it to V1 or remove it.
# def generate_proposals(
# self,
# input_ids: torch.Tensor,
# previous_hidden_states: torch.Tensor,
# num_predict_tokens: int,
# sampling_metadata: SamplingMetadata,
# ) -> list[SamplerOutput]:
# if num_predict_tokens > self.max_speculative_tokens:
# raise ValueError(f"Max speculative tokens for model is "
# f"{self.max_speculative_tokens}, but "
# f"{num_predict_tokens} were requested")
# # b x 1 x d
# previous_hidden_states = previous_hidden_states.unsqueeze(1)
# if self.scale_input:
# previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2
# # b x 1
# last_tokens = input_ids.unsqueeze(1)
# next_tokens = []
# for head_index in range(num_predict_tokens):
# # Project and predict
# z = self.emb[head_index](last_tokens) # b k d
# states = self.proj[head_index](previous_hidden_states)
# # Weighted add of state_weight*state and emb_weight*z
# # Let subsequent LN take care of denominator
# # state_weight is close to 1, so shouldn't be any precision issues
# states.add_(z, alpha=self.emb_weight / self.state_weight)
# states = self.activation(self.ln[head_index](states)) # b k d
# previous_hidden_states = states
# # TODO: not yet supporting top_k_tokens_per_head
# states = states.flatten(0, 1)
# logits = self.logits_processor(self.head[head_index], states,
# sampling_metadata)
# output = self.sampler(logits, sampling_metadata)
# last_tokens = output.sampled_token_ids
# next_tokens.append(output)
# return next_tokens

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
name = name.replace("speculator.", "")
param = params_dict.get(name)
if param is not None:
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
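
    # Example of the name mapping above (illustrative): a checkpoint tensor
    # named "speculator.emb.0.weight" is loaded into this module's parameter
    # "emb.0.weight" once the "speculator." prefix is stripped; tensors whose
    # stripped name has no matching parameter are silently skipped.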