diff --git a/tests/spec_decode/test_dynamic_spec_decode.py b/tests/spec_decode/test_dynamic_spec_decode.py
index 48fa862b2e41a..bb6d1c23a0039 100644
--- a/tests/spec_decode/test_dynamic_spec_decode.py
+++ b/tests/spec_decode/test_dynamic_spec_decode.py
@@ -68,13 +68,13 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
     if queue_size < disable_by_batch_size:
         # Should raise exception when executing the mocked draft model.
         with pytest.raises(ValueError, match=exception_secret):
-            proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+            proposer.get_spec_proposals(execute_model_req=ExecuteModelRequest(
                 seq_group_metadata_list=seq_group_metadata_list,
                 num_lookahead_slots=k), )
     else:
         # Should not execute the draft model because spec decode is disabled
         # for all requests. Accordingly, the proposal length should be 0.
-        proposals = proposer.get_proposals(
+        proposals = proposer.get_spec_proposals(
             execute_model_req=ExecuteModelRequest(
                 seq_group_metadata_list=seq_group_metadata_list,
                 num_lookahead_slots=k), )
diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py
index cb2de97a4af94..6cea6668acc91 100644
--- a/tests/spec_decode/test_multi_step_worker.py
+++ b/tests/spec_decode/test_multi_step_worker.py
@@ -307,9 +307,10 @@ def test_draft_proposals_full_speculation_len():
 
     seq_group_metadata_list, _, _ = create_batch(batch_size, k)
 
-    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
-        seq_group_metadata_list=seq_group_metadata_list,
-        num_lookahead_slots=k), )
+    proposals = proposer.get_spec_proposals(
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=k), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -344,9 +345,10 @@ def test_draft_proposals_no_speculations():
                                                  k,
                                                  prompt_len=prompt_len)
 
-    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
-        seq_group_metadata_list=seq_group_metadata_list,
-        num_lookahead_slots=k), )
+    proposals = proposer.get_spec_proposals(
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=k), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -415,9 +417,10 @@ def test_draft_proposals_mixed_k():
         prev_output_token_len=prev_output_token_len,
     )
 
-    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
-        seq_group_metadata_list=seq_group_metadata_list,
-        num_lookahead_slots=k), )
+    proposals = proposer.get_spec_proposals(
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=k), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
diff --git a/tests/spec_decode/test_ngram_worker.py b/tests/spec_decode/test_ngram_worker.py
index 88b40d1eb4674..b1537884f896e 100644
--- a/tests/spec_decode/test_ngram_worker.py
+++ b/tests/spec_decode/test_ngram_worker.py
@@ -50,9 +50,10 @@ def test_ngram_algo_correctness_for_single_no_match():
         block_size,
         final_prompt_lens=final_prompt_lens)
 
-    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
-        seq_group_metadata_list=seq_group_metadata_list,
-        num_lookahead_slots=proposal_len), )
+    proposals = proposer.get_spec_proposals(
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=proposal_len), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -117,9 +118,10 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
         block_size,
         final_prompt_lens=final_prompt_lens)
 
-    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
-        seq_group_metadata_list=seq_group_metadata_list,
-        num_lookahead_slots=proposal_len), )
+    proposals = proposer.get_spec_proposals(
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=proposal_len), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -188,9 +190,10 @@ def test_ngram_algo_correctness_for_batches_match_all():
         block_size,
         final_prompt_lens=final_prompt_lens)
 
-    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
-        seq_group_metadata_list=seq_group_metadata_list,
-        num_lookahead_slots=proposal_len), )
+    proposals = proposer.get_spec_proposals(
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=proposal_len), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py
index d311bfe984cbc..72d7818eb1177 100644
--- a/vllm/spec_decode/interfaces.py
+++ b/vllm/spec_decode/interfaces.py
@@ -55,7 +55,7 @@ class SpeculativeScores:
 class SpeculativeProposer(ABC):
 
     @abstractmethod
-    def get_proposals(
+    def get_spec_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
     ) -> SpeculativeProposals:
diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py
index b5a805278d273..fe15ea33b5f36 100644
--- a/vllm/spec_decode/multi_step_worker.py
+++ b/vllm/spec_decode/multi_step_worker.py
@@ -7,11 +7,12 @@ import torch
 from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
                            SequenceGroupMetadata)
 from vllm.spec_decode.interfaces import SpeculativeProposals
+from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker import Worker
 
 
-class MultiStepWorker(Worker):
+class MultiStepWorker(Worker, ProposerWorkerBase):
     """The MultiStepWorker is equivalent to a Worker except that it allows
     multiple forward passes in a single call, assuming the scheduler has
     allocated enough space to store the additional KV. This reduces overhead
@@ -33,7 +34,7 @@ class MultiStepWorker(Worker):
         super().init_device()
 
         self._proposer = Top1Proposer(
-            weakref.proxy(self),
+            weakref.proxy(self),  # type: ignore[arg-type]
             self.device,
             self.vocab_size,
             max_proposal_len=self.max_model_len,
@@ -92,11 +93,12 @@ class MultiStepWorker(Worker):
         speculative tokens per sequence is determined by max_proposal_len.
         """
 
-        return self._proposer.get_proposals(execute_model_req)
+        return self._proposer.get_spec_proposals(execute_model_req)
 
+    @staticmethod
     def _append_new_tokens(
-            self, model_output: SamplerOutput,
-            seq_group_metadata_list: SequenceGroupMetadata) -> None:
+            model_output: List[SamplerOutput],
+            seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
         """Given model output from a single run, append the tokens to the
         sequences. This is normally done outside of the worker, but it is
         required if the worker is to perform multiple forward passes.
@@ -116,8 +118,9 @@ class MultiStepWorker(Worker):
             seq.append_token_id(token_id, token_logprob.logprob)
             seq.update_num_computed_tokens(1)
 
+    @staticmethod
     def _shallow_copy_inputs(
-            self, seq_group_metadata_list: List[SequenceGroupMetadata]
+        seq_group_metadata_list: List[SequenceGroupMetadata]
     ) -> List[SequenceGroupMetadata]:
         """Copy input data structures to remove side-effects when input data
         structures are shared with other modules.
diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py
index c2b22f2acd7b4..33af588d0ba29 100644
--- a/vllm/spec_decode/ngram_worker.py
+++ b/vllm/spec_decode/ngram_worker.py
@@ -5,15 +5,16 @@ import torch
 
 from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.spec_decode.interfaces import SpeculativeProposals
+from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker_base import LoraNotSupportedWorkerBase
 
 
-class NGramWorker(LoraNotSupportedWorkerBase):
+class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
     """NGramWorker provides a light drafter without need for model.
 
     Current NGramWorker only implement prompt lookup decoding,
-    and in future we may also do RAG type drafter and other scenerios
+    and in future we may also do RAG type drafter and other scenarios
     which don't rely on LLM model to give proposals.
     """
 
@@ -38,34 +39,11 @@ class NGramWorker(LoraNotSupportedWorkerBase):
 
         # Current only support Top1Proposer
         self._proposer = Top1Proposer(
-            weakref.proxy(self),
+            weakref.proxy(self),  # type: ignore[arg-type]
             device=self.device,
             vocab_size=self.vocab_size,
         )
 
-    def set_include_gpu_probs_tensor(self):
-        # NGram don't need gpu sampler
-        pass
-
-    def execute_model(
-        self,
-        execute_model_req: Optional[ExecuteModelRequest] = None) -> None:
-        """NGram doesn't depend on model execution, just pass this function"""
-        pass
-
-    def determine_num_available_blocks(self) -> None:
-        """NGram doesn't depend on model execution, no need to check blocks"""
-        pass
-
-    def initialize_cache(self, num_gpu_blocks: int,
-                         num_cpu_blocks: int) -> None:
-        """As there is no cache need to handle, just pass this function"""
-        pass
-
-    def get_cache_block_size_bytes(self):
-        """Return the size of a cache block in bytes."""
-        return 0
-
     def sampler_output(
         self,
         execute_model_req: ExecuteModelRequest,
@@ -97,7 +75,6 @@ class NGramWorker(LoraNotSupportedWorkerBase):
                 -1,
         ):
             ngram_tensor = input_ids[-ngram_size:]
-            proposal_start_idx = None
             if ngram_size == 1:
                 # Do not match itself and do not use unfold and all
                 matches = (input_ids[:-1] == ngram_tensor)
@@ -161,7 +138,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):
         speculative tokens per sequence is determined by max_proposal_len.
""" - return self._proposer.get_proposals(execute_model_req) + return self._proposer.get_spec_proposals(execute_model_req) def _raise_if_unsupported( self, diff --git a/vllm/spec_decode/proposer_worker_base.py b/vllm/spec_decode/proposer_worker_base.py new file mode 100644 index 0000000000000..fd67ceb912eee --- /dev/null +++ b/vllm/spec_decode/proposer_worker_base.py @@ -0,0 +1,44 @@ +from abc import ABC, abstractmethod +from typing import List, Optional, Tuple + +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.spec_decode.interfaces import SpeculativeProposer +from vllm.worker.worker_base import WorkerBase + + +class ProposerWorkerBase(WorkerBase, SpeculativeProposer): + """Interface for proposer workers""" + + @abstractmethod + def sampler_output( + self, + execute_model_req: ExecuteModelRequest, + sample_len: int, + ) -> Tuple[Optional[List[SamplerOutput]], bool]: + raise NotImplementedError + + def set_include_gpu_probs_tensor(self): + """Implementation optional""" + pass + + +class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC): + """Proposer worker which does not use a model with kvcache""" + + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + """get_spec_proposals is used to get the proposals""" + return [] + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """This is never called on the proposer, only the target model""" + raise NotImplementedError + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + pass + + def get_cache_block_size_bytes(self) -> int: + return 0 diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 150e8db0c8aad..45d9d5735efc6 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -14,6 +14,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals, from vllm.spec_decode.metrics import AsyncMetricsCollector from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.ngram_worker import NGramWorker +from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase from vllm.spec_decode.util import (create_sequence_group_output, get_all_num_logprobs, get_all_seq_ids, get_sampled_token_logprobs, nvtx_range, @@ -117,7 +118,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): def __init__( self, - proposer_worker: WorkerBase, + proposer_worker: ProposerWorkerBase, scorer_worker: WorkerBase, rejection_sampler: RejectionSampler, metrics_collector: Optional[AsyncMetricsCollector] = None, @@ -260,7 +261,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): # This is required as if the number of draft model runs changes # dynamically, the non-driver workers won't know unless we perform a - # communication to inform then. + # communication to inform them. 
         broadcast_dict = dict(
             num_lookahead_slots=num_lookahead_slots,
             disable_all_speculation=disable_all_speculation,
diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py
index 6c7e22207f6b2..fdef2833a399f 100644
--- a/vllm/spec_decode/top1_proposer.py
+++ b/vllm/spec_decode/top1_proposer.py
@@ -6,8 +6,8 @@ from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
                            SequenceGroupMetadata)
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeProposer)
+from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.util import sampler_output_to_torch
-from vllm.worker.worker_base import WorkerBase
 
 
 class Top1Proposer(SpeculativeProposer):
@@ -29,7 +29,7 @@ class Top1Proposer(SpeculativeProposer):
 
     def __init__(
         self,
-        worker: WorkerBase,
+        worker: ProposerWorkerBase,
         device: str,
         vocab_size: int,
         max_proposal_len: Optional[int] = None,
@@ -39,7 +39,7 @@ class Top1Proposer(SpeculativeProposer):
         self.max_proposal_len = max_proposal_len
         self._vocab_size = vocab_size
 
-    def get_proposals(
+    def get_spec_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
    ) -> SpeculativeProposals:
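
Editorial note, not part of the patch: the sketch below shows how the pieces introduced here are meant to fit together. A model-free drafter subclasses NonLLMProposerWorkerBase (inheriting the no-op execute_model and cache methods that this patch deletes from NGramWorker), implements sampler_output, and delegates the renamed get_spec_proposals entry point to Top1Proposer, the same wiring NGramWorker and MultiStepWorker use after this change. DummyProposerWorker and its trivial "propose nothing" behaviour are illustrative assumptions, not code from this change.

import weakref
from typing import List, Optional, Tuple

from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker_base import LoraNotSupportedWorkerBase


class DummyProposerWorker(NonLLMProposerWorkerBase,
                          LoraNotSupportedWorkerBase):
    """Hypothetical drafter: no model, no KV cache, no LoRA support."""

    def __init__(self, device: str, vocab_size: int):
        self.device = device
        self.vocab_size = vocab_size

    def init_device(self) -> None:
        # Mirrors NGramWorker.init_device: wrap this worker in a Top1Proposer.
        self._proposer = Top1Proposer(
            weakref.proxy(self),  # type: ignore[arg-type]
            device=self.device,
            vocab_size=self.vocab_size,
        )

    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
    ) -> Tuple[Optional[List[SamplerOutput]], bool]:
        # A real drafter would return draft tokens here; None tells
        # Top1Proposer there is nothing to speculate for this batch
        # (NGramWorker takes a similar shortcut when it finds no matches).
        return None, False

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> SpeculativeProposals:
        # The renamed entry point that SpecDecodeWorker calls.
        return self._proposer.get_spec_proposals(execute_model_req)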