[Doc] [SpecDecode] Update MLPSpeculator documentation (#7100)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
Thomas Parnell 2024-08-06 01:29:43 +02:00 committed by GitHub
parent dfb1a15dcb
commit 789937af2e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 58 additions and 0 deletions

View File

@ -69,6 +69,55 @@ matching n-grams in the prompt. For more information read `this thread. <https:/
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
Speculating using MLP speculators
---------------------------------
The following code configures vLLM to use speculative decoding where proposals are generated by
draft models that conditioning draft predictions on both context vectors and sampled tokens.
For more information see `this blog <https://pytorch.org/blog/hitchhikers-guide-speculative-decoding/>`_ or
`this technical report <https://arxiv.org/abs/2404.19124>`_.
.. code-block:: python
from vllm import LLM, SamplingParams
prompts = [
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(
model="meta-llama/Meta-Llama-3.1-70B-Instruct",
tensor_parallel_size=4,
speculative_model="ibm-fms/llama3-70b-accelerator",
speculative_draft_tensor_parallel_size=1,
use_v2_block_manager=True,
)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
Note that these speculative models currently need to be run without tensor parallelism, although
it is possible to run the main model using tensor parallelism (see example above). Since the
speculative models are relatively small, we still see significant speedups. However, this
limitation will be fixed in a future release.
A variety of speculative models of this type are available on HF hub:
* `llama-13b-accelerator <https://huggingface.co/ibm-fms/llama-13b-accelerator>`_
* `llama3-8b-accelerator <https://huggingface.co/ibm-fms/llama3-8b-accelerator>`_
* `codellama-34b-accelerator <https://huggingface.co/ibm-fms/codellama-34b-accelerator>`_
* `llama2-70b-accelerator <https://huggingface.co/ibm-fms/llama2-70b-accelerator>`_
* `llama3-70b-accelerator <https://huggingface.co/ibm-fms/llama3-70b-accelerator>`_
* `granite-3b-code-instruct-accelerator <https://huggingface.co/ibm-granite/granite-3b-code-instruct-accelerator>`_
* `granite-8b-code-instruct-accelerator <https://huggingface.co/ibm-granite/granite-8b-code-instruct-accelerator>`_
* `granite-7b-instruct-accelerator <https://huggingface.co/ibm-granite/granite-7b-instruct-accelerator>`_
* `granite-20b-code-instruct-accelerator <https://huggingface.co/ibm-granite/granite-20b-code-instruct-accelerator>`_
Resources for vLLM contributors
-------------------------------
* `A Hacker's Guide to Speculative Decoding in vLLM <https://www.youtube.com/watch?v=9wNAgpX6z_4>`_

View File

@ -56,6 +56,15 @@ class MLPSpeculatorLayerNorm(nn.Module):
class MLPSpeculator(nn.Module):
"""
An implementation of the speculative models introduced in
"Accelerating Production LLMs with Combined Token/Embedding
Speculators"
https://arxiv.org/pdf/2404.19124
Trained speculators of this type are available on HF hub at:
https://huggingface.co/ibm-fms and https://huggingface.co/ibm-granite
"""
def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None:
super().__init__()