[Model] MLPSpeculator speculative decoding support (#4947)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Davis Wertheimer <Davis.Wertheimer@ibm.com>
parent 6c5b7af152
commit b12518d3cf
examples/offline_inference_mlpspeculator.py (new file, 59 lines)
@@ -0,0 +1,59 @@
import gc
import time
from typing import List

from vllm import LLM, SamplingParams


def time_generation(llm: LLM, prompts: List[str],
                    sampling_params: SamplingParams):
    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    # Warmup first
    llm.generate(prompts, sampling_params)
    llm.generate(prompts, sampling_params)
    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    end = time.time()
    print((end - start) / sum([len(o.outputs[0].token_ids) for o in outputs]))
    # Print the outputs.
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"text: {generated_text!r}")


if __name__ == "__main__":

    template = (
        "Below is an instruction that describes a task. Write a response "
        "that appropriately completes the request.\n\n### Instruction:\n{}"
        "\n\n### Response:\n")

    # Sample prompts.
    prompts = [
        "Write about the president of the United States.",
    ]
    prompts = [template.format(prompt) for prompt in prompts]
    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.0, max_tokens=200)

    # Create an LLM without spec decoding
    llm = LLM(model="meta-llama/Llama-2-13b-chat-hf")

    print("Without speculation")
    time_generation(llm, prompts, sampling_params)

    del llm
    gc.collect()

    # Create an LLM with spec decoding
    llm = LLM(
        model="meta-llama/Llama-2-13b-chat-hf",
        speculative_model="ibm-fms/llama-13b-accelerator",
        # These are currently required for MLPSpeculator decoding
        use_v2_block_manager=True,
        enforce_eager=True,
    )

    print("With speculation")
    time_generation(llm, prompts, sampling_params)
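The figure printed by time_generation is wall-clock seconds divided by the total number of generated token ids, i.e. seconds per output token, so a lower number with speculation means a speedup. A small sketch (plain Python, not part of this commit; the helper name is made up) of how the same measurement can be reported as tokens per second:

def tokens_per_second(elapsed_seconds: float, outputs) -> float:
    # Mirrors the example's timing logic, inverted: generated tokens / time.
    total_tokens = sum(len(o.outputs[0].token_ids) for o in outputs)
    return total_tokens / elapsed_seconds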
@@ -456,7 +456,9 @@ def test_k_equals_zero(k: int, batch_size: int):
     rejection_sampler.token_id_dtype = torch.int64
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)

-    target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
+    sampler_output = MagicMock(spec=SamplerOutput)
+    sampler_output.hidden_states = None
+    target_worker.execute_model.return_value = [sampler_output]

     draft_worker.device = 'cuda'
     target_worker.device = 'cuda'
@@ -497,7 +499,9 @@ def test_empty_input_batch(k: int, batch_size: int):
     rejection_sampler.token_id_dtype = torch.int64
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)

-    target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
+    sampler_output = MagicMock(spec=SamplerOutput)
+    sampler_output.hidden_states = None
+    target_worker.execute_model.return_value = [sampler_output]

     draft_worker.device = 'cuda'
     target_worker.device = 'cuda'
@@ -2,8 +2,8 @@ from unittest.mock import MagicMock

 import pytest

-from vllm.sequence import SequenceGroupMetadata
-from vllm.spec_decode.util import get_all_seq_ids, split_batch_by_proposal_len
+from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids
+from vllm.spec_decode.util import split_batch_by_proposal_len


 def test_get_all_seq_ids():
@@ -230,7 +230,8 @@ class ModelConfig:
         self,
         parallel_config: "ParallelConfig",
     ) -> None:
-        total_num_attention_heads = self.hf_text_config.num_attention_heads
+        total_num_attention_heads = getattr(self.hf_text_config,
+                                            "num_attention_heads", 0)
         tensor_parallel_size = parallel_config.tensor_parallel_size
         if total_num_attention_heads % tensor_parallel_size != 0:
             raise ValueError(
@@ -238,7 +239,8 @@ class ModelConfig:
                 " must be divisible by tensor parallel size "
                 f"({tensor_parallel_size}).")

-        total_num_hidden_layers = self.hf_text_config.num_hidden_layers
+        total_num_hidden_layers = getattr(self.hf_text_config,
+                                          "num_hidden_layers", 0)
         pipeline_parallel_size = parallel_config.pipeline_parallel_size
         if total_num_hidden_layers % pipeline_parallel_size != 0:
             raise ValueError(
@@ -341,8 +343,8 @@ class ModelConfig:

     def get_num_attention_heads(self,
                                 parallel_config: "ParallelConfig") -> int:
-        return self.hf_text_config.num_attention_heads // \
-            parallel_config.tensor_parallel_size
+        num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
+        return num_heads // parallel_config.tensor_parallel_size

     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
         total_num_hidden_layers = self.hf_text_config.num_hidden_layers
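Both getattr changes above exist so that draft configs which simply lack num_attention_heads or num_hidden_layers (such as the MLPSpeculatorConfig introduced later in this commit) fall back to 0 instead of raising AttributeError during the parallel-config checks. A minimal illustration with a toy config object (not vLLM code):

class ToyConfig:
    # deliberately defines no num_attention_heads / num_hidden_layers
    vocab_size = 32000

cfg = ToyConfig()
total_num_attention_heads = getattr(cfg, "num_attention_heads", 0)
print(total_num_attention_heads)      # 0 rather than AttributeError
print(total_num_attention_heads % 4)  # 0, so the divisibility check passes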
@@ -818,7 +820,8 @@ class SpeculativeConfig:
             speculative_model (Optional[str]): The name of the speculative
                 model, if provided.
             num_speculative_tokens (Optional[int]): The number of speculative
-                tokens, if provided.
+                tokens, if provided. Will default to the number in the draft
+                model config if present, otherwise is required.
             speculative_max_model_len (Optional[int]): The maximum model len of
                 the speculative model. Used when testing the ability to skip
                 speculation for some sequences.
@@ -841,24 +844,18 @@ class SpeculativeConfig:
             the necessary conditions are met, else None.
         """

-        if speculative_model is None and num_speculative_tokens is None:
+        if speculative_model is None:
+            if num_speculative_tokens is not None:
+                raise ValueError("num_speculative_tokens was provided without "
+                                 "speculative_model.")
             return None

-        if speculative_model is not None and num_speculative_tokens is None:
-            raise ValueError(
-                "Expected both speculative_model and "
-                "num_speculative_tokens to be provided, but found "
-                f"{speculative_model=} and {num_speculative_tokens=}.")
-
         if (speculative_disable_by_batch_size is not None
                 and speculative_disable_by_batch_size < 2):
             raise ValueError("Expect the batch size threshold of disabling "
                              "speculative decoding is > 1, but got "
                              f"{speculative_disable_by_batch_size=}")

-        assert (speculative_model is not None
-                and num_speculative_tokens is not None)
-
         if enable_chunked_prefill:
             raise ValueError(
                 "Speculative decoding and chunked prefill are "
@@ -912,6 +909,27 @@ class SpeculativeConfig:
             max_logprobs=target_model_config.max_logprobs,
         )

+        if (draft_model_config.hf_config.model_type == "mlp_speculator"
+                and target_parallel_config.world_size != 1):
+            # MLPSpeculator TP support will be added very soon
+            raise ValueError(
+                "Speculative decoding with mlp_speculator models does not "
+                "yet support distributed inferencing (TP > 1).")
+
+        n_predict = getattr(draft_model_config.hf_config, "n_predict",
+                            None)
+        if n_predict is not None:
+            if num_speculative_tokens is None:
+                # Default to max value defined in draft model config.
+                num_speculative_tokens = n_predict
+            elif num_speculative_tokens > n_predict:
+                # Verify provided value doesn't exceed the maximum
+                # supported by the draft model.
+                raise ValueError(
+                    "Expected both speculative_model and "
+                    "num_speculative_tokens to be provided, but found "
+                    f"{speculative_model=} and {num_speculative_tokens=}.")
+
         draft_model_config.max_model_len = (
             SpeculativeConfig._maybe_override_draft_max_model_len(
                 speculative_max_model_len,
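The n_predict handling above is what allows num_speculative_tokens to be omitted for MLPSpeculator drafts: if the draft config carries n_predict, that value is used as the default, and an explicitly supplied value may not exceed it. A hedged usage sketch, reusing the model names from the example script earlier in this commit:

from vllm import LLM

# num_speculative_tokens is left unset; with an mlp_speculator draft it is
# expected to default to the draft config's n_predict value.
llm = LLM(
    model="meta-llama/Llama-2-13b-chat-hf",
    speculative_model="ibm-fms/llama-13b-accelerator",
    use_v2_block_manager=True,
    enforce_eager=True,
)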
@@ -923,6 +941,12 @@ class SpeculativeConfig:
             SpeculativeConfig.create_draft_parallel_config(
                 target_parallel_config))

+        if num_speculative_tokens is None:
+            raise ValueError(
+                "num_speculative_tokens must be provided with "
+                "speculative_model unless the draft model config contains an "
+                "n_predict parameter.")
+
         return SpeculativeConfig(
             draft_model_config,
             draft_parallel_config,
@@ -60,6 +60,7 @@ _GENERATION_MODELS = {
     "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
     "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
     "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
+    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
 }

 _EMBEDDING_MODELS = {
vllm/model_executor/models/mlp_speculator.py (new file, 143 lines)
@@ -0,0 +1,143 @@
import math
from typing import Iterable, List, Tuple

import torch
import torch.nn as nn

from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import SamplerOutput


class MLPSpeculatorLayerNorm(nn.Module):
    """
    A L2 normalization implementation
    ...
    Args
    ----
    normalized_shape : int
        Dimensionality of input data (size of final tensor axis)
    eps : float
        Safety term to prevent division by zero. Make sure the chosen value
        fits in the range of your encoding scheme
        (i.e. fp16 requires eps >= 6e-8).
    """

    def __init__(
        self,
        normalized_shape,
        eps=1e-06,
    ):
        super(MLPSpeculatorLayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.empty(normalized_shape))
        self.bias = nn.Parameter(torch.empty(normalized_shape))
        self.eps = eps

    def forward(self, x):
        xf = x
        xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        x = xf.type_as(x)
        x = self.weight * x
        x = x + self.bias
        return x


class MLPSpeculator(nn.Module):

    def __init__(self, config, **kwargs) -> None:
        super().__init__()
        self.n_predict = config.n_predict
        self.vocab_size = config.vocab_size
        self.emb_dim = config.emb_dim
        self.inner_dim = config.inner_dim if config.inner_dim != 0 \
            else config.emb_dim

        self.max_speculative_tokens = getattr(config, "max_speculative_tokens",
                                              self.n_predict)

        self.emb = nn.ModuleList([
            VocabParallelEmbedding(config.vocab_size,
                                   self.inner_dim,
                                   org_num_embeddings=config.vocab_size)
            for _ in range(self.max_speculative_tokens)
        ])

        self.proj = nn.ModuleList([
            nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
                      self.inner_dim,
                      bias=False) for i in range(self.max_speculative_tokens)
        ])

        self.head = nn.ModuleList([
            nn.Linear(self.inner_dim, self.vocab_size, bias=False)
            for _ in range(self.max_speculative_tokens)
        ])
        self.ln = nn.ModuleList([
            MLPSpeculatorLayerNorm(self.inner_dim)
            for _ in range(self.max_speculative_tokens)
        ])

        self.state_weight = 0.5**(0.5 / config.n_predict)
        self.emb_weight = math.sqrt(
            (1 - self.state_weight**2) * (self.inner_dim / 2))
        self.activation = nn.GELU()
        self.config = config
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                config.vocab_size, 1.0)
        self.sampler = Sampler()

    def generate_proposals(
        self,
        input_ids: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        num_predict_tokens: int,
        sampling_metadata: SamplingMetadata,
    ) -> List[SamplerOutput]:
        if num_predict_tokens > self.max_speculative_tokens:
            raise ValueError(f"Max speculative tokens for model is "
                             f"{self.max_speculative_tokens}, but "
                             f"{num_predict_tokens} were requested")

        # b x 1 x d
        previous_hidden_states = previous_hidden_states.unsqueeze(1)

        # b x 1
        last_tokens = input_ids.unsqueeze(1)

        next_tokens = []

        for head_index in range(num_predict_tokens):

            # Project and predict
            z = self.emb[head_index](last_tokens)  # b k d
            states = self.proj[head_index](previous_hidden_states)

            # Weighted add of state_weight*state and emb_weight*z
            # Let subsequent LN take care of denominator
            # state_weight is close to 1, so shouldn't be any precision issues
            states.add_(z, alpha=self.emb_weight / self.state_weight)

            states = self.activation(self.ln[head_index](states))  # b k d
            # TODO: not yet supporting top_k_tokens_per_head
            previous_hidden_states = states

            logits = self.logits_processor(self.head[head_index].weight,
                                           states, sampling_metadata)

            output = self.sampler(logits.flatten(0, 1), sampling_metadata)
            last_tokens = output.sampled_token_ids
            next_tokens.append(output)

        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            param = params_dict[name.replace("speculator.", "")]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
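MLPSpeculatorLayerNorm above is an RMS-style normalization with a learned elementwise scale and bias: the input is divided by the root mean square of its last axis (plus eps) before weight and bias are applied. A standalone sketch of the same arithmetic in plain torch, handy for checking shapes (it assumes nothing beyond the forward pass shown above):

import torch

def mlp_speculator_layer_norm(x: torch.Tensor, weight: torch.Tensor,
                              bias: torch.Tensor, eps: float = 1e-06):
    # Same computation as MLPSpeculatorLayerNorm.forward: RMS-normalize the
    # last axis, then apply the learned scale and bias.
    xf = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return weight * xf.type_as(x) + bias

x = torch.randn(2, 1, 8)
out = mlp_speculator_layer_norm(x, torch.ones(8), torch.zeros(8))
print(out.shape)  # torch.Size([2, 1, 8])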
@@ -794,6 +794,9 @@ class SamplerOutput:
     # Spec decode metrics populated by workers.
     spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

+    # Optional last hidden states from the model.
+    hidden_states: Optional[torch.Tensor] = None
+
     def __getitem__(self, idx: int):
         return self.outputs[idx]

@@ -842,6 +845,46 @@ class PoolerOutput:
                           self.__class__) and self.outputs == other.outputs


+def get_all_seq_ids(
+        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
+    """Given a list of SequenceGroupMetadata, create a list of all
+    sequence ids.
+    """
+    return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]
+
+
+class HiddenStates:
+    """Hidden states corresponding to in-progress sequences.
+    Used in speculative decoding to pass hidden states from
+    the target model to the proposer model in the subsequent step.
+
+    seq_ids are the sequence ids of each entry of the batch
+    dimension of the hidden_states tensor"""
+
+    def __init__(self, seq_group_metadata_list: List[SequenceGroupMetadata],
+                 hidden_states: torch.Tensor):
+        assert len(seq_group_metadata_list) == len(hidden_states)
+        self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list)
+        self.hidden_states: torch.Tensor = hidden_states
+
+    def update(self, seq_group_metadata_list: List[SequenceGroupMetadata],
+               hidden_states: torch.Tensor) -> None:
+        """Update hidden states from target model invocation."""
+        assert len(seq_group_metadata_list) == len(hidden_states)
+        self.seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
+        self.hidden_states = torch.cat([self.hidden_states, hidden_states])
+
+    def prune(self,
+              seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
+        """Prune to provided list of sequence ids."""
+        seq_ids = get_all_seq_ids(seq_group_metadata_list)
+        if seq_ids != self.seq_ids:
+            # Batch contents changed - prune removed sequences.
+            index = [self.seq_ids.index(seq_id) for seq_id in seq_ids]
+            self.hidden_states = self.hidden_states[index]
+            self.seq_ids = seq_ids
+
+
 @dataclass
 class ExecuteModelRequest:
     """The model execution request."""
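To illustrate HiddenStates.prune: when the batch composition changes between steps, the rows of the hidden-state tensor are re-gathered in the order of the surviving sequence ids. A toy sketch of that indexing with plain lists and a tensor (the SequenceGroupMetadata plumbing is assumed away):

import torch

# Hidden states for seq ids [7, 8, 9]; suppose only [9, 7] remain next step.
seq_ids = [7, 8, 9]
hidden_states = torch.arange(6.0).reshape(3, 2)  # one row per sequence

new_seq_ids = [9, 7]
index = [seq_ids.index(seq_id) for seq_id in new_seq_ids]  # [2, 0]
hidden_states = hidden_states[index]  # rows pruned and reordered to match
seq_ids = new_seq_ids
print(hidden_states)  # row for seq 9, then row for seq 7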
@@ -857,6 +900,8 @@ class ExecuteModelRequest:
     num_lookahead_slots: int = 0
     # The number of requests in the running queue.
     running_queue_size: int = 0
+    # Optional hidden states from prior step.
+    previous_hidden_states: Optional[HiddenStates] = None

     def clone(
         self, seq_group_metadata_list: List[SequenceGroupMetadata]
@@ -869,4 +914,5 @@ class ExecuteModelRequest:
             blocks_to_copy=self.blocks_to_copy.copy(),
             num_lookahead_slots=self.num_lookahead_slots,
             running_queue_size=self.running_queue_size,
+            previous_hidden_states=self.previous_hidden_states,
         )
@@ -4,11 +4,10 @@ from typing import Iterator, List, Tuple
 import torch

 from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
-                           SequenceGroupMetadata)
+                           SequenceGroupMetadata, get_all_seq_ids)
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
-from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
-                                   sampler_output_to_torch,
+from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
                                    split_batch_by_proposal_len)
 from vllm.worker.worker_base import WorkerBase

@@ -98,6 +97,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
             probs=all_probs,
             token_ids=all_tokens,
             logprobs=spec_logprobs,
+            hidden_states=target_sampler_output.hidden_states,
         )

     def _expand_batch(
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from typing import Optional

 import torch

@@ -46,6 +47,9 @@ class SpeculativeScores:
     # tokens and also non-speculative normal decoding.
     token_ids: torch.Tensor

+    # Optional last hidden states from the scoring model.
+    hidden_states: Optional[torch.Tensor] = None
+
     def __repr__(self):
         return (f"SpeculativeScores("
                 f"probs={self.probs.shape}, "
vllm/spec_decode/mlp_speculator_worker.py (new file, 87 lines)
@@ -0,0 +1,87 @@
from typing import List, Optional, Tuple

import torch

from vllm.model_executor import SamplingMetadata
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
                           SequenceGroupMetadata)
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.worker.model_runner import ModelInput


class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
    """Worker for MLPSpeculator models.

    Not currently compatible with LoRA or chunked prefill.
    """

    @torch.inference_mode()
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the model forward pass to generate sample_len future tokens.
        Returns the list of sampler output, one per layer, along with indicator
        of whether torch tensor in sampler output need to be transposed in
        latter sampler_output_to_torch logic.

        For mlp spec worker, this indicator shall be True.
        """
        self._raise_if_unsupported(execute_model_req)

        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        (input_tokens, seq_lens,
         query_lens) = self._prepare_input_tensors(seq_group_metadata_list)

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list, seq_lens, query_lens, self.device,
            self.model_runner.pin_memory)

        model_outputs = self.model_runner.model.generate_proposals(
            input_ids=input_tokens,
            previous_hidden_states=execute_model_req.previous_hidden_states.
            hidden_states,
            num_predict_tokens=sample_len,
            sampling_metadata=sampling_metadata)

        assert len(model_outputs) == sample_len

        return model_outputs, True

    def _prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, List[int], List[int]]:
        if not seq_group_metadata_list:
            return ModelInput.empty(self.device)

        input_tokens: List[int] = []
        seq_lens: List[int] = []
        query_lens: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            is_prompt = seq_group_metadata.is_prompt

            for seq_data in seq_group_metadata.seq_data.values():
                seq_data_len = seq_data.get_len()
                if is_prompt:
                    context_len = seq_data.get_num_computed_tokens()
                    seq_len = min(
                        seq_data_len,
                        context_len + seq_group_metadata.token_chunk_size)
                    tokens = seq_data.get_token_ids()[context_len:seq_len]
                    seq_lens.append(seq_len)
                    input_tokens.extend(tokens)
                    query_lens.append(seq_len - context_len)
                else:
                    seq_lens.append(seq_data_len)
                    input_tokens.append(seq_data.get_last_token_id())
                    query_lens.append(1)

        input_tokens_tensor = torch.tensor(input_tokens,
                                           dtype=torch.long,
                                           device=self.device)
        return input_tokens_tensor, seq_lens, query_lens
@@ -8,16 +8,18 @@ from vllm.distributed.communication_op import broadcast_tensor_dict
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
 from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
-                           SamplerOutput, SequenceGroupMetadata)
+                           HiddenStates, SamplerOutput, SequenceGroupMetadata,
+                           get_all_seq_ids)
 from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
 from vllm.spec_decode.metrics import AsyncMetricsCollector
+from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.ngram_worker import NGramWorker
 from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.util import (create_sequence_group_output,
-                                   get_all_num_logprobs, get_all_seq_ids,
+                                   get_all_num_logprobs,
                                    get_sampled_token_logprobs, nvtx_range,
                                    split_batch_by_proposal_len)
 from vllm.worker.worker import Worker
@@ -104,6 +106,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             proposer_worker = NGramWorker(**draft_worker_kwargs)
             proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                   ngram_prompt_lookup_max)
+        elif draft_worker_kwargs[
+                "model_config"].hf_config.model_type == "mlp_speculator":
+            proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
+            disable_bonus_tokens = False
         else:
             proposer_worker = MultiStepWorker(**draft_worker_kwargs)

@@ -155,6 +161,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         # Lazy initiazliation.
         self.scorer: SpeculativeScorer

+        # Hidden states from target model to pass to proposer
+        # in the subsequent step.
+        self.previous_hidden_states: Optional[HiddenStates] = None
+
     def init_device(self) -> None:
         """Initialize both scorer and proposer models.
         """
@@ -337,6 +347,16 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         assert len(sampler_output) == 1
         sampler_output = sampler_output[0]

+        # Store hidden states from target model execution.
+        hidden_states = sampler_output.hidden_states
+        if hidden_states is not None:
+            if self.previous_hidden_states is None:
+                self.previous_hidden_states = HiddenStates(
+                    execute_model_req.seq_group_metadata_list, hidden_states)
+            else:
+                self.previous_hidden_states.update(
+                    execute_model_req.seq_group_metadata_list, hidden_states)
+
         # Clear device tensors from sampler output. This reduces communication
         # overhead when the engine runs in a different process than the workers.
         sampler_output.probs = None
@@ -383,6 +403,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         """
         assert num_lookahead_slots == execute_model_req.num_lookahead_slots

+        # Pass last hidden states from target model to proposer
+        execute_model_req.previous_hidden_states = self.previous_hidden_states
+        self.previous_hidden_states = None
+
         # Generate proposals using draft worker.
         proposals = self.proposer_worker.get_spec_proposals(execute_model_req)

@@ -466,6 +490,20 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         # metadata.
         accepted_token_ids[original_indices] = accepted_token_ids.clone()

+        hidden_states = proposal_scores.hidden_states
+        if hidden_states is not None:
+            # Contract hidden states based on accepted tokens
+            hs_size = hidden_states.shape[1]
+            hidden_states = hidden_states.reshape(-1, max_proposal_len + 1,
+                                                  hs_size)
+            accepted_index = accepted_token_ids + 1  # Convert -1 to 0
+            accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)
+            index = accepted_index[:, None, None].expand(-1, 1, hs_size)
+            hidden_states = hidden_states.gather(1, index).squeeze(1)  # b x d
+            # Store hidden states from target model for subsequent decode step
+            self.previous_hidden_states = HiddenStates(seq_group_metadata_list,
+                                                       hidden_states)
+
         return accepted_token_ids, logprobs

     def _create_output_sampler_list(
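The contraction above selects, per sequence, the hidden state of the last accepted token: accepted_token_ids uses -1 for rejected positions, so adding 1 and counting non-zero entries gives the number of accepted tokens, and subtracting 1 yields the index to gather. A small worked example with toy tensors (no vLLM objects involved):

import torch

# Two sequences, k + 1 = 4 candidate positions each; -1 marks rejection.
accepted_token_ids = torch.tensor([[11, 12, -1, -1],
                                   [21, 22, 23, 24]])
hs_size = 3
hidden_states = torch.randn(2, 4, hs_size)  # b x (k + 1) x d

accepted_index = (accepted_token_ids + 1).count_nonzero(dim=1).add_(-1)
# -> tensor([1, 3]): index of the last accepted token per sequence
index = accepted_index[:, None, None].expand(-1, 1, hs_size)
last_hidden = hidden_states.gather(1, index).squeeze(1)  # b x d
print(last_hidden.shape)  # torch.Size([2, 3])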
@@ -65,9 +65,13 @@ class Top1Proposer(SpeculativeProposer):
             # token_ids is like [batch] format in proposal_len size list,
             # while if it is false, the format would be [proposal_len]
             # in batch size list
+            hidden_states = execute_model_req.previous_hidden_states
+            if hidden_states is not None:
+                hidden_states.prune(nonzero_proposal_len_seqs)
             nonzero_execute_model_req = ExecuteModelRequest(
                 seq_group_metadata_list=nonzero_proposal_len_seqs,
                 num_lookahead_slots=proposal_len,
+                previous_hidden_states=hidden_states,
             )
             maybe_sampler_output, transposed = self._worker.sampler_output(
                 execute_model_req=nonzero_execute_model_req,
@@ -10,14 +10,6 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
 SeqId = int


-def get_all_seq_ids(
-        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[SeqId]:
-    """Given a list of SequenceGroupMetadata, create a list of all
-    sequence ids.
-    """
-    return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]
-
-
 def get_all_num_logprobs(
         seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
     """Given a list of SequenceGroupMetadata, create a list of all num_logprobs.
@@ -1,3 +1,4 @@
+import contextlib
 from typing import Dict, Optional, Type

 from transformers import PretrainedConfig
@@ -5,7 +6,13 @@ from transformers import PretrainedConfig
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
 from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
-                                             JAISConfig, MPTConfig, RWConfig)
+                                             JAISConfig, MLPSpeculatorConfig,
+                                             MPTConfig, RWConfig)
+
+if VLLM_USE_MODELSCOPE:
+    from modelscope import AutoConfig
+else:
+    from transformers import AutoConfig

 logger = init_logger(__name__)

@@ -16,8 +23,13 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     "RefinedWeb": RWConfig,  # For tiiuae/falcon-40b(-instruct)
     "RefinedWebModel": RWConfig,  # For tiiuae/falcon-7b(-instruct)
     "jais": JAISConfig,
+    "mlp_speculator": MLPSpeculatorConfig,
 }

+for name, cls in _CONFIG_REGISTRY.items():
+    with contextlib.suppress(ValueError):
+        AutoConfig.register(name, cls)
+

 def get_config(model: str,
                trust_remote_code: bool,
@@ -26,10 +38,6 @@ def get_config(model: str,
                rope_scaling: Optional[dict] = None,
                rope_theta: Optional[float] = None) -> PretrainedConfig:
     try:
-        if VLLM_USE_MODELSCOPE:
-            from modelscope import AutoConfig
-        else:
-            from transformers import AutoConfig
         config = AutoConfig.from_pretrained(
             model,
             trust_remote_code=trust_remote_code,
@@ -5,6 +5,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
 # `FalconConfig` class from the official HuggingFace transformers library.
 from vllm.transformers_utils.configs.falcon import RWConfig
 from vllm.transformers_utils.configs.jais import JAISConfig
+from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
 from vllm.transformers_utils.configs.mpt import MPTConfig

 __all__ = [
@@ -13,4 +14,5 @@ __all__ = [
     "MPTConfig",
     "RWConfig",
    "JAISConfig",
+    "MLPSpeculatorConfig",
 ]
vllm/transformers_utils/configs/mlp_speculator.py (new file, 50 lines)
@@ -0,0 +1,50 @@
from typing import List, Optional

from transformers import PretrainedConfig


class MLPSpeculatorConfig(PretrainedConfig):
    model_type = "mlp_speculator"

    attribute_map = {
        "hidden_size": "emb_dim",
    }

    def __init__(self,
                 vocab_size: int = 32000,
                 emb_dim: int = 4096,
                 inner_dim: int = 0,
                 n_predict: int = 3,
                 top_k_tokens_per_head: Optional[List[int]] = None,
                 n_candidates: int = 5,
                 **kwargs):
        """
        Initialize an MLPSpeculatorConfig

        Args:
            vocab_size: int
                the model vocab size
            emb_dim: int
                the model embedding dimension
            inner_dim: int
                the inner dimension of the model. If 0, will be the emb_dim.
            n_predict: int
                the number of lookaheads for the speculator
            top_k_tokens_per_head: List[int]
                Number of tokens to consider from each head when forming the
                candidate tree.
                For each candidate branch in the tree, head n produces topk[n]
                additional sub-branches.
            n_candidates: int
                number of child candidates to create per sequence
        """
        if top_k_tokens_per_head is None:
            top_k_tokens_per_head = [5, 4, 3]
        assert len(top_k_tokens_per_head) == n_predict
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.inner_dim = inner_dim
        self.n_predict = n_predict
        self.top_k_tokens_per_head = top_k_tokens_per_head
        self.n_candidates = n_candidates
        super().__init__(**kwargs)
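Because MLPSpeculatorConfig subclasses PretrainedConfig and is registered under the "mlp_speculator" model type (see the transformers_utils/config.py hunk above), it can also be constructed directly. A small sketch with made-up dimensions; real checkpoints ship their own config.json:

from vllm.transformers_utils.configs import MLPSpeculatorConfig

cfg = MLPSpeculatorConfig(vocab_size=32000,
                          emb_dim=5120,
                          inner_dim=4096,
                          n_predict=3,
                          top_k_tokens_per_head=[5, 4, 3])
# attribute_map exposes emb_dim under the standard hidden_size name.
print(cfg.model_type, cfg.hidden_size)  # mlp_speculator 5120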
@@ -86,6 +86,7 @@ class ModelRunner:
         kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
         vision_language_config: Optional[VisionLanguageConfig] = None,
+        return_hidden_states: bool = False,
     ):
         self.model_config = model_config
         self.parallel_config = parallel_config
@@ -96,6 +97,7 @@ class ModelRunner:
         self.load_config = load_config
         self.is_driver_worker = is_driver_worker
         self.vision_language_config = vision_language_config
+        self.return_hidden_states = return_hidden_states

         self.device = self.device_config.device
         self.pin_memory = is_pin_memory_available()
@@ -116,15 +118,17 @@ class ModelRunner:
         self.graph_block_tables = np.zeros(
             (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
             dtype=np.int32)
+        num_attn_heads = self.model_config.get_num_attention_heads(
+            self.parallel_config)
         self.attn_backend = get_attn_backend(
-            self.model_config.get_num_attention_heads(self.parallel_config),
+            num_attn_heads,
             self.model_config.get_head_size(),
             self.model_config.get_num_kv_heads(self.parallel_config),
             self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
-        )
+        ) if num_attn_heads else None

         # Create processor for multi-modal data
         if self.vision_language_config is not None:
@@ -762,11 +766,19 @@ class ModelRunner:
             return None

         # Sample the next token.
-        output = self.model.sample(
+        output: SamplerOutput = self.model.sample(
             logits=logits,
             sampling_metadata=sampling_metadata,
         )

+        if self.return_hidden_states:
+            # we only need to pass hidden states of most recent token
+            assert seq_group_metadata_list is not None
+            if seq_group_metadata_list[0].is_prompt:
+                hidden_states = hidden_states.index_select(
+                    0, sampling_metadata.selected_token_indices)
+            output.hidden_states = hidden_states
+
         return output

     @torch.inference_mode()
@@ -70,6 +70,14 @@ class Worker(WorkerBase):
             assert not self.lora_config, (
                 "To be tested: vision language model with LoRA settings.")

+        # Return hidden states from target model if the draft model is an
+        # mlp_speculator
+        speculative_args = {} if speculative_config is None \
+            or (speculative_config.draft_model_config.model ==
+                model_config.model) \
+            or (speculative_config.draft_model_config.hf_config.model_type !=
+                "mlp_speculator") else {"return_hidden_states": True}
+
         ModelRunnerClass = (EmbeddingModelRunner if
                             self.model_config.embedding_mode else ModelRunner)
         self.model_runner = ModelRunnerClass(
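The conditional expression above only forwards return_hidden_states=True when the draft model is a separate mlp_speculator model; spelled out as an equivalent if statement for readability (a sketch of the same condition, not a proposed change):

speculative_args = {}
if (speculative_config is not None
        and speculative_config.draft_model_config.model != model_config.model
        and speculative_config.draft_model_config.hf_config.model_type
        == "mlp_speculator"):
    speculative_args = {"return_hidden_states": True}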
@@ -83,6 +91,7 @@ class Worker(WorkerBase):
             kv_cache_dtype=self.cache_config.cache_dtype,
             is_driver_worker=is_driver_worker,
             vision_language_config=vision_language_config,
+            **speculative_args,
         )
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.