[Model] MLPSpeculator speculative decoding support (#4947)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>

Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Davis Wertheimer <Davis.Wertheimer@ibm.com>
Joshua Rosenkranz 2024-06-20 20:23:12 -04:00 committed by GitHub
parent 6c5b7af152
commit b12518d3cf
18 changed files with 523 additions and 40 deletions

View File

@@ -0,0 +1,59 @@
import gc
import time
from typing import List

from vllm import LLM, SamplingParams


def time_generation(llm: LLM, prompts: List[str],
                    sampling_params: SamplingParams):
    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    # Warmup first
    llm.generate(prompts, sampling_params)
    llm.generate(prompts, sampling_params)
    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    end = time.time()
    print((end - start) / sum([len(o.outputs[0].token_ids) for o in outputs]))
    # Print the outputs.
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"text: {generated_text!r}")


if __name__ == "__main__":

    template = (
        "Below is an instruction that describes a task. Write a response "
        "that appropriately completes the request.\n\n### Instruction:\n{}"
        "\n\n### Response:\n")

    # Sample prompts.
    prompts = [
        "Write about the president of the United States.",
    ]
    prompts = [template.format(prompt) for prompt in prompts]

    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.0, max_tokens=200)

    # Create an LLM without spec decoding
    llm = LLM(model="meta-llama/Llama-2-13b-chat-hf")

    print("Without speculation")
    time_generation(llm, prompts, sampling_params)

    del llm
    gc.collect()

    # Create an LLM with spec decoding
    llm = LLM(
        model="meta-llama/Llama-2-13b-chat-hf",
        speculative_model="ibm-fms/llama-13b-accelerator",
        # These are currently required for MLPSpeculator decoding
        use_v2_block_manager=True,
        enforce_eager=True,
    )

    print("With speculation")
    time_generation(llm, prompts, sampling_params)
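
For reference, the value printed by time_generation is average wall-clock seconds per generated token (elapsed time divided by the total number of tokens produced across outputs), so the run with speculation should report a noticeably smaller number than the baseline when the draft model predicts well.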

View File

@@ -456,7 +456,9 @@ def test_k_equals_zero(k: int, batch_size: int):
     rejection_sampler.token_id_dtype = torch.int64
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
-    target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
+    sampler_output = MagicMock(spec=SamplerOutput)
+    sampler_output.hidden_states = None
+    target_worker.execute_model.return_value = [sampler_output]
 
     draft_worker.device = 'cuda'
     target_worker.device = 'cuda'

@@ -497,7 +499,9 @@ def test_empty_input_batch(k: int, batch_size: int):
     rejection_sampler.token_id_dtype = torch.int64
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
-    target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
+    sampler_output = MagicMock(spec=SamplerOutput)
+    sampler_output.hidden_states = None
+    target_worker.execute_model.return_value = [sampler_output]
 
     draft_worker.device = 'cuda'
     target_worker.device = 'cuda'

View File

@@ -2,8 +2,8 @@ from unittest.mock import MagicMock
 import pytest
 
-from vllm.sequence import SequenceGroupMetadata
-from vllm.spec_decode.util import get_all_seq_ids, split_batch_by_proposal_len
+from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids
+from vllm.spec_decode.util import split_batch_by_proposal_len
 
 def test_get_all_seq_ids():

View File

@@ -230,7 +230,8 @@ class ModelConfig:
         self,
         parallel_config: "ParallelConfig",
     ) -> None:
-        total_num_attention_heads = self.hf_text_config.num_attention_heads
+        total_num_attention_heads = getattr(self.hf_text_config,
+                                            "num_attention_heads", 0)
         tensor_parallel_size = parallel_config.tensor_parallel_size
         if total_num_attention_heads % tensor_parallel_size != 0:
             raise ValueError(

@@ -238,7 +239,8 @@ class ModelConfig:
                 " must be divisible by tensor parallel size "
                 f"({tensor_parallel_size}).")
-        total_num_hidden_layers = self.hf_text_config.num_hidden_layers
+        total_num_hidden_layers = getattr(self.hf_text_config,
+                                          "num_hidden_layers", 0)
         pipeline_parallel_size = parallel_config.pipeline_parallel_size
         if total_num_hidden_layers % pipeline_parallel_size != 0:
             raise ValueError(

@@ -341,8 +343,8 @@ class ModelConfig:
     def get_num_attention_heads(self,
                                 parallel_config: "ParallelConfig") -> int:
-        return self.hf_text_config.num_attention_heads // \
-            parallel_config.tensor_parallel_size
+        num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
+        return num_heads // parallel_config.tensor_parallel_size
 
     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
         total_num_hidden_layers = self.hf_text_config.num_hidden_layers
@@ -818,7 +820,8 @@ class SpeculativeConfig:
             speculative_model (Optional[str]): The name of the speculative
                 model, if provided.
             num_speculative_tokens (Optional[int]): The number of speculative
-                tokens, if provided.
+                tokens, if provided. Will default to the number in the draft
+                model config if present, otherwise is required.
             speculative_max_model_len (Optional[int]): The maximum model len of
                 the speculative model. Used when testing the ability to skip
                 speculation for some sequences.

@@ -841,24 +844,18 @@ class SpeculativeConfig:
             the necessary conditions are met, else None.
         """
-        if speculative_model is None and num_speculative_tokens is None:
+        if speculative_model is None:
+            if num_speculative_tokens is not None:
+                raise ValueError("num_speculative_tokens was provided without "
+                                 "speculative_model.")
             return None
 
-        if speculative_model is not None and num_speculative_tokens is None:
-            raise ValueError(
-                "Expected both speculative_model and "
-                "num_speculative_tokens to be provided, but found "
-                f"{speculative_model=} and {num_speculative_tokens=}.")
-
         if (speculative_disable_by_batch_size is not None
                 and speculative_disable_by_batch_size < 2):
             raise ValueError("Expect the batch size threshold of disabling "
                              "speculative decoding is > 1, but got "
                              f"{speculative_disable_by_batch_size=}")
 
-        assert (speculative_model is not None
-                and num_speculative_tokens is not None)
-
         if enable_chunked_prefill:
             raise ValueError(
                 "Speculative decoding and chunked prefill are "
@@ -912,6 +909,27 @@ class SpeculativeConfig:
                 max_logprobs=target_model_config.max_logprobs,
             )
 
+            if (draft_model_config.hf_config.model_type == "mlp_speculator"
+                    and target_parallel_config.world_size != 1):
+                # MLPSpeculator TP support will be added very soon
+                raise ValueError(
+                    "Speculative decoding with mlp_speculator models does not "
+                    "yet support distributed inferencing (TP > 1).")
+
+            n_predict = getattr(draft_model_config.hf_config, "n_predict",
+                                None)
+            if n_predict is not None:
+                if num_speculative_tokens is None:
+                    # Default to max value defined in draft model config.
+                    num_speculative_tokens = n_predict
+                elif num_speculative_tokens > n_predict:
+                    # Verify provided value doesn't exceed the maximum
+                    # supported by the draft model.
+                    raise ValueError(
+                        f"Expected num_speculative_tokens <= {n_predict} for "
+                        f"draft model {speculative_model}, but got "
+                        f"{num_speculative_tokens=}.")
+
             draft_model_config.max_model_len = (
                 SpeculativeConfig._maybe_override_draft_max_model_len(
                     speculative_max_model_len,

@@ -923,6 +941,12 @@ class SpeculativeConfig:
                 SpeculativeConfig.create_draft_parallel_config(
                     target_parallel_config))
 
+        if num_speculative_tokens is None:
+            raise ValueError(
+                "num_speculative_tokens must be provided with "
+                "speculative_model unless the draft model config contains an "
+                "n_predict parameter.")
+
         return SpeculativeConfig(
             draft_model_config,
             draft_parallel_config,
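
Taken together, the SpeculativeConfig changes above let num_speculative_tokens be inferred from an MLPSpeculator draft model's n_predict field. A standalone sketch of that resolution order (not part of the diff; draft_hf_config is a hypothetical stand-in for draft_model_config.hf_config):

from typing import Optional


def resolve_num_speculative_tokens(
        draft_hf_config, num_speculative_tokens: Optional[int]) -> int:
    # Mirrors the logic added to SpeculativeConfig.maybe_create_spec_config.
    n_predict = getattr(draft_hf_config, "n_predict", None)
    if n_predict is not None:
        if num_speculative_tokens is None:
            # Default to the maximum the draft model supports.
            num_speculative_tokens = n_predict
        elif num_speculative_tokens > n_predict:
            raise ValueError(
                f"num_speculative_tokens={num_speculative_tokens} exceeds "
                f"the draft model's n_predict={n_predict}.")
    if num_speculative_tokens is None:
        raise ValueError(
            "num_speculative_tokens must be provided unless the draft model "
            "config contains an n_predict parameter.")
    return num_speculative_tokens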

View File

@@ -60,6 +60,7 @@ _GENERATION_MODELS = {
     "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
     "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
     "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
+    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
 }
 
 _EMBEDDING_MODELS = {

View File

@@ -0,0 +1,143 @@
import math
from typing import Iterable, List, Tuple

import torch
import torch.nn as nn

from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import SamplerOutput


class MLPSpeculatorLayerNorm(nn.Module):
    """
    An L2 normalization implementation
    ...
    Args
    ----
    normalized_shape : int
        Dimensionality of input data (size of final tensor axis)
    eps : float
        Safety term to prevent division by zero. Make sure the chosen value
        fits in the range of your encoding scheme
        (i.e. fp16 requires eps >= 6e-8).
    """

    def __init__(
        self,
        normalized_shape,
        eps=1e-06,
    ):
        super(MLPSpeculatorLayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.empty(normalized_shape))
        self.bias = nn.Parameter(torch.empty(normalized_shape))
        self.eps = eps

    def forward(self, x):
        xf = x
        xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        x = xf.type_as(x)
        x = self.weight * x
        x = x + self.bias
        return x


class MLPSpeculator(nn.Module):

    def __init__(self, config, **kwargs) -> None:
        super().__init__()
        self.n_predict = config.n_predict
        self.vocab_size = config.vocab_size
        self.emb_dim = config.emb_dim
        self.inner_dim = config.inner_dim if config.inner_dim != 0 \
            else config.emb_dim

        self.max_speculative_tokens = getattr(config, "max_speculative_tokens",
                                              self.n_predict)

        self.emb = nn.ModuleList([
            VocabParallelEmbedding(config.vocab_size,
                                   self.inner_dim,
                                   org_num_embeddings=config.vocab_size)
            for _ in range(self.max_speculative_tokens)
        ])

        self.proj = nn.ModuleList([
            nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
                      self.inner_dim,
                      bias=False) for i in range(self.max_speculative_tokens)
        ])

        self.head = nn.ModuleList([
            nn.Linear(self.inner_dim, self.vocab_size, bias=False)
            for _ in range(self.max_speculative_tokens)
        ])
        self.ln = nn.ModuleList([
            MLPSpeculatorLayerNorm(self.inner_dim)
            for _ in range(self.max_speculative_tokens)
        ])

        self.state_weight = 0.5**(0.5 / config.n_predict)
        self.emb_weight = math.sqrt(
            (1 - self.state_weight**2) * (self.inner_dim / 2))
        self.activation = nn.GELU()
        self.config = config
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                config.vocab_size, 1.0)
        self.sampler = Sampler()

    def generate_proposals(
        self,
        input_ids: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        num_predict_tokens: int,
        sampling_metadata: SamplingMetadata,
    ) -> List[SamplerOutput]:
        if num_predict_tokens > self.max_speculative_tokens:
            raise ValueError(f"Max speculative tokens for model is "
                             f"{self.max_speculative_tokens}, but "
                             f"{num_predict_tokens} were requested")

        # b x 1 x d
        previous_hidden_states = previous_hidden_states.unsqueeze(1)

        # b x 1
        last_tokens = input_ids.unsqueeze(1)

        next_tokens = []

        for head_index in range(num_predict_tokens):

            # Project and predict
            z = self.emb[head_index](last_tokens)  # b k d
            states = self.proj[head_index](previous_hidden_states)

            # Weighted add of state_weight*state and emb_weight*z
            # Let subsequent LN take care of denominator
            # state_weight is close to 1, so shouldn't be any precision issues
            states.add_(z, alpha=self.emb_weight / self.state_weight)

            states = self.activation(self.ln[head_index](states))  # b k d
            # TODO: not yet supporting top_k_tokens_per_head
            previous_hidden_states = states

            logits = self.logits_processor(self.head[head_index].weight,
                                           states, sampling_metadata)

            output = self.sampler(logits.flatten(0, 1), sampling_metadata)
            last_tokens = output.sampled_token_ids
            next_tokens.append(output)

        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            param = params_dict[name.replace("speculator.", "")]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
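
The per-head recurrence in generate_proposals is easier to read in isolation. Below is a standalone sketch of a single head update (not part of the diff; shapes and the state/emb weights are illustrative toy values, and nn.LayerNorm stands in for MLPSpeculatorLayerNorm):

import torch

batch, emb_dim, inner_dim, vocab = 2, 16, 16, 100
emb = torch.nn.Embedding(vocab, inner_dim)
proj = torch.nn.Linear(emb_dim, inner_dim, bias=False)
head = torch.nn.Linear(inner_dim, vocab, bias=False)
ln = torch.nn.LayerNorm(inner_dim)  # stand-in for MLPSpeculatorLayerNorm
gelu = torch.nn.GELU()
state_weight, emb_weight = 0.84, 0.65  # in the real model these derive from n_predict and inner_dim

with torch.inference_mode():
    state = torch.randn(batch, 1, emb_dim)          # last hidden state from the target model
    last_tokens = torch.randint(vocab, (batch, 1))  # last sampled/accepted token ids

    z = emb(last_tokens)                            # b x 1 x inner_dim
    state = proj(state)                             # b x 1 x inner_dim
    # Weighted add of state and token embedding; the following norm absorbs the scale.
    state = state + (emb_weight / state_weight) * z
    state = gelu(ln(state))
    logits = head(state)                            # b x 1 x vocab
    next_tokens = logits.argmax(-1)                 # greedy stand-in for the Sampler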

View File

@@ -794,6 +794,9 @@ class SamplerOutput:
     # Spec decode metrics populated by workers.
     spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
 
+    # Optional last hidden states from the model.
+    hidden_states: Optional[torch.Tensor] = None
+
     def __getitem__(self, idx: int):
         return self.outputs[idx]

@@ -842,6 +845,46 @@ class PoolerOutput:
                 self.__class__) and self.outputs == other.outputs
 
+def get_all_seq_ids(
+        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
+    """Given a list of SequenceGroupMetadata, create a list of all
+    sequence ids.
+    """
+    return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]
+
+
+class HiddenStates:
+    """Hidden states corresponding to in-progress sequences.
+    Used in speculative decoding to pass hidden states from
+    the target model to the proposer model in the subsequent step.
+
+    seq_ids are the sequence ids of each entry of the batch
+    dimension of the hidden_states tensor"""
+
+    def __init__(self, seq_group_metadata_list: List[SequenceGroupMetadata],
+                 hidden_states: torch.Tensor):
+        assert len(seq_group_metadata_list) == len(hidden_states)
+        self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list)
+        self.hidden_states: torch.Tensor = hidden_states
+
+    def update(self, seq_group_metadata_list: List[SequenceGroupMetadata],
+               hidden_states: torch.Tensor) -> None:
+        """Update hidden states from target model invocation."""
+        assert len(seq_group_metadata_list) == len(hidden_states)
+        self.seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
+        self.hidden_states = torch.cat([self.hidden_states, hidden_states])
+
+    def prune(self,
+              seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
+        """Prune to provided list of sequence ids."""
+        seq_ids = get_all_seq_ids(seq_group_metadata_list)
+        if seq_ids != self.seq_ids:
+            # Batch contents changed - prune removed sequences.
+            index = [self.seq_ids.index(seq_id) for seq_id in seq_ids]
+            self.hidden_states = self.hidden_states[index]
+            self.seq_ids = seq_ids
+
+
 @dataclass
 class ExecuteModelRequest:
     """The model execution request."""

@@ -857,6 +900,8 @@ class ExecuteModelRequest:
     num_lookahead_slots: int = 0
     # The number of requests in the running queue.
     running_queue_size: int = 0
+    # Optional hidden states from prior step.
+    previous_hidden_states: Optional[HiddenStates] = None
 
     def clone(
         self, seq_group_metadata_list: List[SequenceGroupMetadata]

@@ -869,4 +914,5 @@ class ExecuteModelRequest:
             blocks_to_copy=self.blocks_to_copy.copy(),
             num_lookahead_slots=self.num_lookahead_slots,
             running_queue_size=self.running_queue_size,
+            previous_hidden_states=self.previous_hidden_states,
         )
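
The HiddenStates.prune call added above is just an index-select keyed by sequence id. A standalone sketch with plain lists in place of SequenceGroupMetadata (not part of the diff):

import torch

# One hidden-state row per in-progress sequence (sequence ids 7, 8, 9; d = 4).
seq_ids = [7, 8, 9]
hidden_states = torch.arange(12, dtype=torch.float32).reshape(3, 4)

# The next batch only contains sequences 7 and 9 (sequence 8 finished).
kept_seq_ids = [7, 9]
if kept_seq_ids != seq_ids:
    index = [seq_ids.index(seq_id) for seq_id in kept_seq_ids]
    hidden_states = hidden_states[index]
    seq_ids = kept_seq_ids

print(hidden_states.shape)  # torch.Size([2, 4])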

View File

@@ -4,11 +4,10 @@ from typing import Iterator, List, Tuple
 import torch
 
 from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
-                           SequenceGroupMetadata)
+                           SequenceGroupMetadata, get_all_seq_ids)
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
-from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
-                                   sampler_output_to_torch,
+from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
                                    split_batch_by_proposal_len)
 from vllm.worker.worker_base import WorkerBase

@@ -98,6 +97,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
             probs=all_probs,
             token_ids=all_tokens,
             logprobs=spec_logprobs,
+            hidden_states=target_sampler_output.hidden_states,
         )
 
     def _expand_batch(

View File

@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from typing import Optional
 
 import torch

@@ -46,6 +47,9 @@ class SpeculativeScores:
     # tokens and also non-speculative normal decoding.
     token_ids: torch.Tensor
 
+    # Optional last hidden states from the scoring model.
+    hidden_states: Optional[torch.Tensor] = None
+
     def __repr__(self):
         return (f"SpeculativeScores("
                 f"probs={self.probs.shape}, "

View File

@@ -0,0 +1,87 @@
from typing import List, Optional, Tuple

import torch

from vllm.model_executor import SamplingMetadata
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
                           SequenceGroupMetadata)
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.worker.model_runner import ModelInput


class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
    """Worker for MLPSpeculator models.

    Not currently compatible with LoRA or chunked prefill.
    """

    @torch.inference_mode()
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the model forward pass to generate sample_len future tokens.
        Returns the list of sampler outputs, one per speculative token, along
        with an indicator of whether the torch tensors in the sampler output
        need to be transposed by the later sampler_output_to_torch logic.
        For the MLPSpeculator worker, this indicator is True.
        """
        self._raise_if_unsupported(execute_model_req)

        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        (input_tokens, seq_lens,
         query_lens) = self._prepare_input_tensors(seq_group_metadata_list)

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list, seq_lens, query_lens, self.device,
            self.model_runner.pin_memory)

        model_outputs = self.model_runner.model.generate_proposals(
            input_ids=input_tokens,
            previous_hidden_states=execute_model_req.previous_hidden_states.
            hidden_states,
            num_predict_tokens=sample_len,
            sampling_metadata=sampling_metadata)

        assert len(model_outputs) == sample_len

        return model_outputs, True

    def _prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, List[int], List[int]]:
        if not seq_group_metadata_list:
            return ModelInput.empty(self.device)

        input_tokens: List[int] = []
        seq_lens: List[int] = []
        query_lens: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            is_prompt = seq_group_metadata.is_prompt

            for seq_data in seq_group_metadata.seq_data.values():
                seq_data_len = seq_data.get_len()
                if is_prompt:
                    context_len = seq_data.get_num_computed_tokens()
                    seq_len = min(
                        seq_data_len,
                        context_len + seq_group_metadata.token_chunk_size)
                    tokens = seq_data.get_token_ids()[context_len:seq_len]
                    seq_lens.append(seq_len)
                    input_tokens.extend(tokens)
                    query_lens.append(seq_len - context_len)
                else:
                    seq_lens.append(seq_data_len)
                    input_tokens.append(seq_data.get_last_token_id())
                    query_lens.append(1)

        input_tokens_tensor = torch.tensor(input_tokens,
                                           dtype=torch.long,
                                           device=self.device)
        return input_tokens_tensor, seq_lens, query_lens
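
As an illustration of the flattened layout built by _prepare_input_tensors (plain numbers, not real sequence metadata): a prompt sequence with three not-yet-computed tokens and a decode sequence of length 11 whose last token id is 42 would yield:

input_tokens = [5, 6, 7, 42]  # all prompt-chunk tokens, then one token per decode sequence
seq_lens = [3, 11]            # total sequence lengths
query_lens = [3, 1]           # tokens processed this step: whole prompt chunk vs. last token only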

View File

@@ -8,16 +8,18 @@ from vllm.distributed.communication_op import broadcast_tensor_dict
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
 from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
-                           SamplerOutput, SequenceGroupMetadata)
+                           HiddenStates, SamplerOutput, SequenceGroupMetadata,
+                           get_all_seq_ids)
 from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
 from vllm.spec_decode.metrics import AsyncMetricsCollector
+from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.ngram_worker import NGramWorker
 from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.util import (create_sequence_group_output,
-                                   get_all_num_logprobs, get_all_seq_ids,
+                                   get_all_num_logprobs,
                                    get_sampled_token_logprobs, nvtx_range,
                                    split_batch_by_proposal_len)
 from vllm.worker.worker import Worker

@@ -104,6 +106,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             proposer_worker = NGramWorker(**draft_worker_kwargs)
             proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                   ngram_prompt_lookup_max)
+        elif draft_worker_kwargs[
+                "model_config"].hf_config.model_type == "mlp_speculator":
+            proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
+            disable_bonus_tokens = False
         else:
             proposer_worker = MultiStepWorker(**draft_worker_kwargs)

@@ -155,6 +161,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         # Lazy initiazliation.
         self.scorer: SpeculativeScorer
 
+        # Hidden states from target model to pass to proposer
+        # in the subsequent step.
+        self.previous_hidden_states: Optional[HiddenStates] = None
+
     def init_device(self) -> None:
         """Initialize both scorer and proposer models.
         """

@@ -337,6 +347,16 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         assert len(sampler_output) == 1
         sampler_output = sampler_output[0]
 
+        # Store hidden states from target model execution.
+        hidden_states = sampler_output.hidden_states
+        if hidden_states is not None:
+            if self.previous_hidden_states is None:
+                self.previous_hidden_states = HiddenStates(
+                    execute_model_req.seq_group_metadata_list, hidden_states)
+            else:
+                self.previous_hidden_states.update(
+                    execute_model_req.seq_group_metadata_list, hidden_states)
+
         # Clear device tensors from sampler output. This reduces communication
         # overhead when the engine runs in a different process than the workers.
         sampler_output.probs = None

@@ -383,6 +403,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         """
         assert num_lookahead_slots == execute_model_req.num_lookahead_slots
 
+        # Pass last hidden states from target model to proposer
+        execute_model_req.previous_hidden_states = self.previous_hidden_states
+        self.previous_hidden_states = None
+
         # Generate proposals using draft worker.
         proposals = self.proposer_worker.get_spec_proposals(execute_model_req)

@@ -466,6 +490,20 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         # metadata.
         accepted_token_ids[original_indices] = accepted_token_ids.clone()
 
+        hidden_states = proposal_scores.hidden_states
+        if hidden_states is not None:
+            # Contract hidden states based on accepted tokens
+            hs_size = hidden_states.shape[1]
+            hidden_states = hidden_states.reshape(-1, max_proposal_len + 1,
+                                                  hs_size)
+            accepted_index = accepted_token_ids + 1  # Convert -1 to 0
+            accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)
+            index = accepted_index[:, None, None].expand(-1, 1, hs_size)
+            hidden_states = hidden_states.gather(1, index).squeeze(1)  # b x d
+            # Store hidden states from target model for subsequent decode step
+            self.previous_hidden_states = HiddenStates(seq_group_metadata_list,
+                                                       hidden_states)
+
         return accepted_token_ids, logprobs
 
     def _create_output_sampler_list(
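
The hidden-state contraction above is easier to follow with toy shapes. A standalone sketch (not part of the diff) with a batch of 2 sequences, k = 2 proposed tokens plus one bonus token per sequence, and hidden size 4, where -1 marks rejected positions:

import torch

k, hs_size = 2, 4
batch = 2
# The scorer produces one hidden state per proposal token plus the bonus token.
hidden_states = torch.randn(batch * (k + 1), hs_size)
# Sequence 0 accepted both proposals (and the bonus token); sequence 1 accepted none.
accepted_token_ids = torch.tensor([[11, 12, 13],
                                   [21, -1, -1]])

hidden_states = hidden_states.reshape(-1, k + 1, hs_size)
accepted_index = (accepted_token_ids + 1).count_nonzero(dim=1).add_(-1)  # tensor([2, 0])
index = accepted_index[:, None, None].expand(-1, 1, hs_size)
last_accepted = hidden_states.gather(1, index).squeeze(1)  # b x hs_size
print(last_accepted.shape)  # torch.Size([2, 4])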

View File

@@ -65,9 +65,13 @@ class Top1Proposer(SpeculativeProposer):
             # token_ids is like [batch] format in proposal_len size list,
             # while if it is false, the format would be [proposal_len]
             # in batch size list
+            hidden_states = execute_model_req.previous_hidden_states
+            if hidden_states is not None:
+                hidden_states.prune(nonzero_proposal_len_seqs)
+
             nonzero_execute_model_req = ExecuteModelRequest(
                 seq_group_metadata_list=nonzero_proposal_len_seqs,
                 num_lookahead_slots=proposal_len,
+                previous_hidden_states=hidden_states,
             )
             maybe_sampler_output, transposed = self._worker.sampler_output(
                 execute_model_req=nonzero_execute_model_req,

View File

@@ -10,14 +10,6 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
 SeqId = int
 
-def get_all_seq_ids(
-        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[SeqId]:
-    """Given a list of SequenceGroupMetadata, create a list of all
-    sequence ids.
-    """
-    return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]
-
-
 def get_all_num_logprobs(
         seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
     """Given a list of SequenceGroupMetadata, create a list of all num_logprobs.

View File

@@ -1,3 +1,4 @@
+import contextlib
 from typing import Dict, Optional, Type
 
 from transformers import PretrainedConfig

@@ -5,7 +6,13 @@ from transformers import PretrainedConfig
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
 from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
-                                             JAISConfig, MPTConfig, RWConfig)
+                                             JAISConfig, MLPSpeculatorConfig,
+                                             MPTConfig, RWConfig)
+
+if VLLM_USE_MODELSCOPE:
+    from modelscope import AutoConfig
+else:
+    from transformers import AutoConfig
 
 logger = init_logger(__name__)

@@ -16,8 +23,13 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     "RefinedWeb": RWConfig,  # For tiiuae/falcon-40b(-instruct)
     "RefinedWebModel": RWConfig,  # For tiiuae/falcon-7b(-instruct)
     "jais": JAISConfig,
+    "mlp_speculator": MLPSpeculatorConfig,
 }
 
+for name, cls in _CONFIG_REGISTRY.items():
+    with contextlib.suppress(ValueError):
+        AutoConfig.register(name, cls)
+
 def get_config(model: str,
                trust_remote_code: bool,

@@ -26,10 +38,6 @@ def get_config(model: str,
                rope_scaling: Optional[dict] = None,
                rope_theta: Optional[float] = None) -> PretrainedConfig:
     try:
-        if VLLM_USE_MODELSCOPE:
-            from modelscope import AutoConfig
-        else:
-            from transformers import AutoConfig
         config = AutoConfig.from_pretrained(
             model,
             trust_remote_code=trust_remote_code,

View File

@@ -5,6 +5,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
 # `FalconConfig` class from the official HuggingFace transformers library.
 from vllm.transformers_utils.configs.falcon import RWConfig
 from vllm.transformers_utils.configs.jais import JAISConfig
+from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
 from vllm.transformers_utils.configs.mpt import MPTConfig
 
 __all__ = [

@@ -13,4 +14,5 @@ __all__ = [
     "MPTConfig",
     "RWConfig",
     "JAISConfig",
+    "MLPSpeculatorConfig",
 ]

View File

@@ -0,0 +1,50 @@
from typing import List, Optional

from transformers import PretrainedConfig


class MLPSpeculatorConfig(PretrainedConfig):
    model_type = "mlp_speculator"

    attribute_map = {
        "hidden_size": "emb_dim",
    }

    def __init__(self,
                 vocab_size: int = 32000,
                 emb_dim: int = 4096,
                 inner_dim: int = 0,
                 n_predict: int = 3,
                 top_k_tokens_per_head: Optional[List[int]] = None,
                 n_candidates: int = 5,
                 **kwargs):
        """
        Initialize an MLPSpeculatorConfig

        Args:
            vocab_size: int
                the model vocab size
            emb_dim: int
                the model embedding dimension
            inner_dim: int
                the inner dimension of the model. If 0, will be the emb_dim.
            n_predict: int
                the number of lookaheads for the speculator
            top_k_tokens_per_head: List[int]
                Number of tokens to consider from each head when forming the
                candidate tree.
                For each candidate branch in the tree, head n produces topk[n]
                additional sub-branches.
            n_candidates: int
                number of child candidates to create per sequence
        """
        if top_k_tokens_per_head is None:
            top_k_tokens_per_head = [5, 4, 3]
        assert len(top_k_tokens_per_head) == n_predict
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.inner_dim = inner_dim
        self.n_predict = n_predict
        self.top_k_tokens_per_head = top_k_tokens_per_head
        self.n_candidates = n_candidates
        super().__init__(**kwargs)
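
As a quick illustration of the defaulting behavior (the values below are made up, not taken from any released checkpoint):

from vllm.transformers_utils.configs import MLPSpeculatorConfig

# top_k_tokens_per_head defaults to [5, 4, 3], whose length must equal n_predict (3).
cfg = MLPSpeculatorConfig(vocab_size=32000, emb_dim=5120, inner_dim=4096)
print(cfg.n_predict, cfg.hidden_size)  # 3 5120  (hidden_size maps to emb_dim via attribute_map)

# A speculator with more heads needs an explicit top-k list of matching length.
cfg4 = MLPSpeculatorConfig(n_predict=4, top_k_tokens_per_head=[5, 4, 3, 2])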

View File

@@ -86,6 +86,7 @@ class ModelRunner:
         kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
         vision_language_config: Optional[VisionLanguageConfig] = None,
+        return_hidden_states: bool = False,
     ):
         self.model_config = model_config
         self.parallel_config = parallel_config

@@ -96,6 +97,7 @@ class ModelRunner:
         self.load_config = load_config
         self.is_driver_worker = is_driver_worker
         self.vision_language_config = vision_language_config
+        self.return_hidden_states = return_hidden_states
 
         self.device = self.device_config.device
         self.pin_memory = is_pin_memory_available()

@@ -116,15 +118,17 @@ class ModelRunner:
         self.graph_block_tables = np.zeros(
             (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
             dtype=np.int32)
+        num_attn_heads = self.model_config.get_num_attention_heads(
+            self.parallel_config)
         self.attn_backend = get_attn_backend(
-            self.model_config.get_num_attention_heads(self.parallel_config),
+            num_attn_heads,
             self.model_config.get_head_size(),
             self.model_config.get_num_kv_heads(self.parallel_config),
             self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
-        )
+        ) if num_attn_heads else None
 
         # Create processor for multi-modal data
         if self.vision_language_config is not None:

@@ -762,11 +766,19 @@ class ModelRunner:
             return None
 
         # Sample the next token.
-        output = self.model.sample(
+        output: SamplerOutput = self.model.sample(
             logits=logits,
             sampling_metadata=sampling_metadata,
         )
+
+        if self.return_hidden_states:
+            # we only need to pass hidden states of most recent token
+            assert seq_group_metadata_list is not None
+            if seq_group_metadata_list[0].is_prompt:
+                hidden_states = hidden_states.index_select(
+                    0, sampling_metadata.selected_token_indices)
+            output.hidden_states = hidden_states
+
         return output
 
     @torch.inference_mode()

View File

@@ -70,6 +70,14 @@ class Worker(WorkerBase):
             assert not self.lora_config, (
                 "To be tested: vision language model with LoRA settings.")
 
+        # Return hidden states from target model if the draft model is an
+        # mlp_speculator
+        speculative_args = {} if speculative_config is None \
+            or (speculative_config.draft_model_config.model ==
+                model_config.model) \
+            or (speculative_config.draft_model_config.hf_config.model_type !=
+                "mlp_speculator") else {"return_hidden_states": True}
+
         ModelRunnerClass = (EmbeddingModelRunner if
                             self.model_config.embedding_mode else ModelRunner)
         self.model_runner = ModelRunnerClass(

@@ -83,6 +91,7 @@ class Worker(WorkerBase):
             kv_cache_dtype=self.cache_config.cache_dtype,
             is_driver_worker=is_driver_worker,
             vision_language_config=vision_language_config,
+            **speculative_args,
         )
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.