From 46cdd59577978f893dbf9c733cacd920011fc7fd Mon Sep 17 00:00:00 2001 From: shangmingc Date: Mon, 17 Feb 2025 11:32:26 +0800 Subject: [PATCH] [Feature][Spec Decode] Simplify the use of Eagle Spec Decode (#12304) Signed-off-by: Shangming Cai --- docs/source/features/spec_decode.md | 16 +- .../spec_decode/e2e/test_eagle_correctness.py | 144 ++++++++++++++++++ tests/spec_decode/test_spec_decode_worker.py | 40 ++++- vllm/config.py | 9 ++ vllm/model_executor/models/eagle.py | 24 ++- vllm/spec_decode/multi_step_worker.py | 12 ++ .../spec_decode/smaller_tp_proposer_worker.py | 19 +++ vllm/spec_decode/spec_decode_worker.py | 27 +++- 8 files changed, 273 insertions(+), 18 deletions(-) diff --git a/docs/source/features/spec_decode.md b/docs/source/features/spec_decode.md index 1e468962cc9c5..d2255eff608be 100644 --- a/docs/source/features/spec_decode.md +++ b/docs/source/features/spec_decode.md @@ -175,7 +175,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) llm = LLM( model="meta-llama/Meta-Llama-3-8B-Instruct", tensor_parallel_size=4, - speculative_model="path/to/modified/eagle/model", + speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B", speculative_draft_tensor_parallel_size=1, ) @@ -190,14 +190,12 @@ for output in outputs: A few important things to consider when using the EAGLE based draft models: -1. The EAGLE draft models available in the [HF repository for EAGLE models](https://huggingface.co/yuhuili) cannot be - used directly with vLLM due to differences in the expected layer names and model definition. - To use these models with vLLM, use the [following script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d) - to convert them. Note that this script does not modify the model's weights. - - In the above example, use the script to first convert - the [yuhuili/EAGLE-LLaMA3-Instruct-8B](https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B) model - and then use the converted checkpoint as the draft model in vLLM. +1. The EAGLE draft models available in the [HF repository for EAGLE models](https://huggingface.co/yuhuili) should + be able to be loaded and used directly by vLLM after [PR 12304](https://github.com/vllm-project/vllm/pull/12304). + If you are using vllm version before [PR 12304](https://github.com/vllm-project/vllm/pull/12304), please use the + [script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d) to convert the speculative model, + and specify `speculative_model="path/to/modified/eagle/model"`. If weight-loading problems still occur when using + the latest version of vLLM, please leave a comment or raise an issue. 2. The EAGLE based draft models need to be run without tensor parallelism (i.e. speculative_draft_tensor_parallel_size is set to 1), although diff --git a/tests/spec_decode/e2e/test_eagle_correctness.py b/tests/spec_decode/e2e/test_eagle_correctness.py index 6d1803f8bc632..42a84071d94d5 100644 --- a/tests/spec_decode/e2e/test_eagle_correctness.py +++ b/tests/spec_decode/e2e/test_eagle_correctness.py @@ -305,6 +305,150 @@ def test_eagle_disable_queue(vllm_runner, common_llm_kwargs, batch_size, output_len, seed) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Print spec metrics. 
+ "disable_log_stats": False, + + # Precision + "dtype": "float16", + + # Main model + "model_name": "meta-llama/Llama-2-7b-chat-hf", + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "yuhuili/EAGLE-llama2-chat-7B", + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, +]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize("seed", [1]) +def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, batch_size: int, + output_len: int, seed: int): + + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": "float16", + + # Main model + "model_name": "meta-llama/Meta-Llama-3-8B-Instruct", + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, +]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize("seed", [1]) +def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, batch_size: int, + output_len: int, seed: int): + + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": "float16", + + # Main model + "model_name": "Qwen/Qwen2-7B-Instruct", + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "yuhuili/EAGLE-Qwen2-7B-Instruct", + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, +]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. 
+ 32, + ]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize("seed", [1]) +def test_qwen2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, batch_size: int, + output_len: int, seed: int): + + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0) + + if __name__ == "__main__": import pytest pytest.main([__file__]) diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index eee0f4c89c898..e4b1a178b0c95 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -13,15 +13,18 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.utils import set_random_seed from vllm.sequence import ExecuteModelRequest, SequenceOutput from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer +from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.metrics import (AsyncMetricsCollector, SpecDecodeWorkerMetrics) from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker, split_num_cache_blocks_evenly) +from vllm.worker.worker import Worker from .test_utils import mock_spec_decode_sampler -from .utils import create_batch, create_sampler_output_list, mock_worker +from .utils import (create_batch, create_sampler_output_list, create_worker, + mock_worker) @pytest.mark.parametrize('k', [1, 2, 6]) @@ -905,3 +908,38 @@ def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str): worker.execute_model(execute_model_req=execute_model_req) # but first draft still counted assert draft_worker.get_spec_proposals.call_count == 1 + + +def test_correctly_load_weight_for_eagle(): + """ + Verify SpecDecodeWorker loads lm_head weight for eagle correctly. 
+ """ + seed = 100 + block_size = 32 + num_gpu_blocks = 8096 // block_size + target_worker = create_worker( + Worker, + "JackFram/llama-68m", + block_size, + num_gpu_blocks, + seed, + ) + draft_worker = create_worker( + MultiStepWorker, + "abhigoyal/vllm-eagle-llama-68m-random", + block_size, + num_gpu_blocks, + seed, + model_runner_cls=TP1DraftModelRunner, + ) + + spec_decode_sampler = mock_spec_decode_sampler("rejection_sampler") + worker = SpecDecodeWorker(draft_worker, + target_worker, + spec_decode_sampler, + disable_logprobs=False) + worker.proposer_worker.maybe_load_lm_head_weight( + target_worker.model_runner.model.lm_head.weight.data) + assert torch.allclose( + worker.proposer_worker.worker.model_runner.model.lm_head.weight.data, + worker.scorer_worker.model_runner.model.lm_head.weight.data) diff --git a/vllm/config.py b/vllm/config.py index 07499d5abbed4..5c220ed136301 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1833,6 +1833,15 @@ class SpeculativeConfig: draft_hf_config = draft_model_config.hf_config + # Detect EAGLE prefix to replace hf_config for EAGLE draft_model + if "eagle-" in draft_model_config.model.lower(): + from vllm.transformers_utils.configs.eagle import EAGLEConfig + if isinstance(draft_model_config.hf_config, EAGLEConfig): + pass + else: + eagle_config = EAGLEConfig(draft_model_config.hf_config) + draft_model_config.hf_config = eagle_config + if (num_speculative_tokens is not None and hasattr(draft_hf_config, "num_lookahead_tokens")): draft_hf_config.num_lookahead_tokens = num_speculative_tokens diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py index 373a728be89cb..ab3f0dc07f4da 100644 --- a/vllm/model_executor/models/eagle.py +++ b/vllm/model_executor/models/eagle.py @@ -7,6 +7,7 @@ import torch.nn as nn from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import VllmConfig +from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -18,6 +19,8 @@ from vllm.sequence import IntermediateTensors from .utils import maybe_prefix +logger = init_logger(__name__) + class DummyInputLayerNorm(nn.Module): @@ -190,8 +193,8 @@ class EAGLE(nn.Module): default_weight_loader) weight_loader(self.fc.bias, loaded_weight) else: - raise ValueError("Found bias in the loaded weights " - "but the model config doesn't have bias") + logger.warning_once("Found bias in the loaded weights but " + "the model config doesn't have bias.") elif name.startswith("model.lm_head.") or name.startswith( "model.model."): model_weights[name.split("model.", 1)[-1]] = loaded_weight @@ -200,12 +203,21 @@ class EAGLE(nn.Module): else: model_weights[f"model.{name}"] = loaded_weight - lm_head_weight = model_weights.pop("lm_head.weight") + if "lm_head.weight" in model_weights: + lm_head_weight = model_weights.pop("lm_head.weight") - if self.token_map is not None and\ - lm_head_weight.shape[0] > self.token_map.shape[0]: + if self.token_map is not None and\ + lm_head_weight.shape[0] > self.token_map.shape[0]: - lm_head_weight = lm_head_weight[self.token_map] + lm_head_weight = lm_head_weight[self.token_map] + + else: + # NOTE(Shangming): initialize the placeholder for lm_head weight. 
+ lm_head_weight = torch.zeros( + self.lm_head.org_vocab_size, + self.lm_head.embedding_dim, + dtype=self.config.torch_dtype, + ) weight_loader = getattr(self.lm_head.weight, "weight_loader", default_weight_loader) diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 5474917a6fab7..c28d413efe747 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -7,6 +7,7 @@ from typing import Dict, List, Set, Tuple import torch from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.platforms import current_platform from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData, SequenceGroupMetadata) @@ -386,3 +387,14 @@ class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase): execute_model_req.seq_group_metadata_list): raise NotImplementedError( "MultiStepWorker does not support beam search.") + + def maybe_load_lm_head_weight( + self, + lm_head_weight: torch.Tensor, + ) -> None: + weight_loader = getattr( + self.worker.model_runner.model_runner.model.lm_head.weight, + "weight_loader", default_weight_loader) + weight_loader( + self.worker.model_runner.model_runner.model.lm_head.weight, + lm_head_weight) diff --git a/vllm/spec_decode/smaller_tp_proposer_worker.py b/vllm/spec_decode/smaller_tp_proposer_worker.py index a1466ba5db756..6919562465097 100644 --- a/vllm/spec_decode/smaller_tp_proposer_worker.py +++ b/vllm/spec_decode/smaller_tp_proposer_worker.py @@ -10,6 +10,7 @@ from vllm.distributed.parallel_state import (get_tp_group, patch_tensor_parallel_group) from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import ExecuteModelRequest from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.multi_step_worker import MultiStepWorker @@ -173,3 +174,21 @@ class SmallerTpProposerWorker(ProposerWorkerBase): @property def vocab_size(self) -> int: return self._worker.vocab_size + + def maybe_load_lm_head_weight( + self, + lm_head_weight: torch.Tensor, + ) -> None: + if self._is_dummy: + return + + with self._patch_tensor_parallel_group(): + weight_loader = getattr( + self._worker.worker.model_runner.model_runner.model.\ + lm_head.weight, + "weight_loader", + default_weight_loader) + weight_loader( + self._worker.worker.model_runner.model_runner.model.\ + lm_head.weight, + lm_head_weight) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 8653bece8b5a5..33b1be54c8b3c 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -9,7 +9,8 @@ import torch import torch.nn as nn from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig -from vllm.distributed.communication_op import broadcast_tensor_dict +from vllm.distributed.communication_op import (broadcast_tensor_dict, + tensor_model_parallel_gather) from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.layers.sampler import SamplerOutput @@ -155,6 +156,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ) -> "SpecDecodeWorker": allow_zero_draft_token_step = True + enable_lm_head_weight_load = False ngram_prompt_lookup_max = ( draft_worker_kwargs.pop("ngram_prompt_lookup_max")) ngram_prompt_lookup_min = ( @@ 
-187,6 +189,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): "EAGLE does not support TP > 1 yet") allow_zero_draft_token_step = False + + # Load lm_head weight for eagle in init_device + if draft_model_config.hf_config.model_type == "eagle": + enable_lm_head_weight_load = True + proposer_worker = MultiStepWorker(**draft_worker_kwargs) proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker( @@ -239,7 +246,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): disable_log_stats=disable_log_stats, disable_by_batch_size=disable_by_batch_size, spec_decode_sampler=spec_decode_sampler, - allow_zero_draft_token_step=allow_zero_draft_token_step) + allow_zero_draft_token_step=allow_zero_draft_token_step, + enable_lm_head_weight_load=enable_lm_head_weight_load) def __init__( self, @@ -252,6 +260,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): metrics_collector: Optional[AsyncMetricsCollector] = None, disable_by_batch_size: Optional[int] = None, allow_zero_draft_token_step: Optional[bool] = True, + enable_lm_head_weight_load: Optional[bool] = False, ): """ Create a SpecDecodeWorker. @@ -282,6 +291,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): allow_zero_draft_token_step: whether to allow a step where the draft model generates no draft token; should disallow when the tp of draft model is larger than 1 (TODO: #5814) + enable_lm_head_weight_load: whether to load lm_head weight for + draft models like eagle. """ self.proposer_worker = proposer_worker self.scorer_worker = scorer_worker @@ -291,6 +302,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): self.disable_by_batch_size = disable_by_batch_size or float("inf") self.spec_decode_sampler = spec_decode_sampler self._allow_zero_draft_token_step = allow_zero_draft_token_step + self._enable_lm_head_weight_load = enable_lm_head_weight_load self._metrics = AsyncMetricsCollector( self.spec_decode_sampler ) if metrics_collector is None else metrics_collector @@ -327,6 +339,17 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): self.scorer_worker.load_model() self.proposer_worker.load_model() + if self._enable_lm_head_weight_load: + # NOTE(Shangming): gather lm_head weight when tp enabled + target_lm_head_weight: torch.Tensor = tensor_model_parallel_gather( + self.scorer_worker.model_runner.model_runner.model.lm_head.\ + weight.data, + dim=0, + ) + + self.proposer_worker.maybe_load_lm_head_weight( + target_lm_head_weight) + self._metrics.init_tensors(self.rank, device_type=self.device) self.spec_decode_sampler.init_tensors(self.rank, device_type=self.device)
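
For reference, a minimal end-to-end sketch of what this patch enables: the unmodified EAGLE checkpoints from the yuhuili HF repository are passed directly as `speculative_model`, with no conversion script. The Llama-3 pairing below mirrors the example in the updated `docs/source/features/spec_decode.md`; the pairings listed in the comment are the ones exercised by the new tests in `test_eagle_correctness.py`. Treat this as a sketch, not the full documentation example.

```python
from vllm import LLM, SamplingParams

prompts = ["The future of AI is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Target/draft pairings covered by the new e2e tests in this patch:
#   meta-llama/Llama-2-7b-chat-hf        + yuhuili/EAGLE-llama2-chat-7B
#   meta-llama/Meta-Llama-3-8B-Instruct  + yuhuili/EAGLE-LLaMA3-Instruct-8B
#   Qwen/Qwen2-7B-Instruct               + yuhuili/EAGLE-Qwen2-7B-Instruct
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    tensor_parallel_size=4,
    # Unmodified HF checkpoint used directly as the draft model.
    speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
    speculative_draft_tensor_parallel_size=1,
)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(f"Prompt: {output.prompt!r}, Generated: {output.outputs[0].text!r}")
```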
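
One readability note on the `vllm/config.py` hunk: the `isinstance`/`pass`/`else` branch can be collapsed without changing behavior. A possible equivalent form of that hunk, using the same import and calls as the patch:

```python
# Detect EAGLE prefix to replace hf_config for the EAGLE draft model.
if "eagle-" in draft_model_config.model.lower():
    from vllm.transformers_utils.configs.eagle import EAGLEConfig
    if not isinstance(draft_model_config.hf_config, EAGLEConfig):
        # Wrap the draft model's hf_config so it is treated as an
        # EAGLE model even though the HF checkpoint is unmodified.
        draft_model_config.hf_config = EAGLEConfig(
            draft_model_config.hf_config)
```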
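
The lm_head handling added across `eagle.py`, `multi_step_worker.py`, `smaller_tp_proposer_worker.py`, and `spec_decode_worker.py` works as follows: the HF EAGLE checkpoints do not include an `lm_head.weight` (which the new fallback in `eagle.py` handles), so the EAGLE model now initializes a zero placeholder at load time, and `init_device` gathers the target model's tensor-parallel-sharded `lm_head` weight along dim 0 and copies it into the draft model through its weight loader. Below is a self-contained toy sketch of that flow in plain PyTorch; the tensors, sizes, and `torch.cat` stand-in for `tensor_model_parallel_gather` are made up for illustration and are not vLLM objects.

```python
import torch

# Toy stand-ins: a target lm_head sharded across two tensor-parallel ranks
# along the vocab dimension, and a draft (EAGLE) lm_head whose checkpoint
# did not ship lm_head.weight.
vocab_size, hidden_size, tp_size = 8, 4, 2
target_shards = [
    torch.randn(vocab_size // tp_size, hidden_size) for _ in range(tp_size)
]

# Step 1 (mirrors the new fallback in eagle.py): with no lm_head.weight in
# the checkpoint, initialize a zero placeholder of the full vocab size.
draft_lm_head_weight = torch.zeros(vocab_size, hidden_size)

# Step 2 (mirrors init_device in spec_decode_worker.py): gather the target
# lm_head weight across TP ranks along dim=0. torch.cat emulates
# tensor_model_parallel_gather here.
target_lm_head_weight = torch.cat(target_shards, dim=0)

# Step 3 (mirrors maybe_load_lm_head_weight in multi_step_worker.py): copy
# the gathered weight into the draft model's placeholder in place, as
# default_weight_loader does for an unsharded parameter.
draft_lm_head_weight.copy_(target_lm_head_weight)

# Same check as the new test in test_spec_decode_worker.py: the draft and
# target lm_head weights now match.
assert torch.allclose(draft_lm_head_weight, target_lm_head_weight)
```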