diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py
index b1ab8a07ca63..eda7293ea7ce 100644
--- a/tests/spec_decode/e2e/conftest.py
+++ b/tests/spec_decode/e2e/conftest.py
@@ -55,7 +55,7 @@ class AsyncLLM:
     ) -> None:
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
-        self.engine_args = AsyncEngineArgs(
+        engine_args = AsyncEngineArgs(
             model=model,
             tokenizer=tokenizer,
             tokenizer_mode=tokenizer_mode,
@@ -76,6 +76,8 @@ class AsyncLLM:
             **kwargs,
         )
         self.request_counter = Counter()
+        self.llm_engine = AsyncLLMEngine.from_engine_args(
+            engine_args, usage_context=UsageContext.LLM_CLASS)

     def generate(
         self,
@@ -88,9 +90,6 @@ class AsyncLLM:
         multi_modal_data: Optional[MultiModalData] = None,
     ) -> List[RequestOutput]:

-        llm_engine = AsyncLLMEngine.from_engine_args(
-            self.engine_args, usage_context=UsageContext.LLM_CLASS)
-
         if prompts is None:
             raise ValueError("prompts must be provided.")
         if isinstance(prompts, str):
@@ -111,8 +110,8 @@ class AsyncLLM:

         async def get_output(prompt, sampling_param) -> str:
             request_id = random_uuid()
-            results_generator = llm_engine.generate(prompt, sampling_param,
-                                                    request_id)
+            results_generator = self.llm_engine.generate(
+                prompt, sampling_param, request_id)
             final_output = None
             async for request_output in results_generator:
                 final_output = request_output
@@ -185,12 +184,25 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
     return generator_outer


+def maybe_assert_ngram_worker(llm):
+    # Verify the proposer worker is ngram if ngram is specified.
+    if (not isinstance(llm, AsyncLLM)
+            and llm.llm_engine.speculative_config is not None
+            and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0):
+        from vllm.spec_decode.ngram_worker import NGramWorker
+        assert isinstance(
+            llm.llm_engine.model_executor.driver_worker.proposer_worker,
+            NGramWorker)
+
+
 def get_output_from_llm_generator(
         llm_generator, prompts,
         sampling_params) -> Tuple[List[str], List[List[int]]]:
     tokens = []
     token_ids = []
     for llm in llm_generator():
+        maybe_assert_ngram_worker(llm)
+
         outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
         token_ids = [output.outputs[0].token_ids for output in outputs]
         tokens = [output.outputs[0].text for output in outputs]
diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py
index 1af3bcf38084..e8559b6a5c0f 100644
--- a/vllm/executor/gpu_executor.py
+++ b/vllm/executor/gpu_executor.py
@@ -82,6 +82,10 @@ class GPUExecutor(ExecutorBase):
         draft_worker_kwargs.update(
             model_config=self.speculative_config.draft_model_config,
             parallel_config=self.speculative_config.draft_parallel_config,
+            ngram_prompt_lookup_max=self.speculative_config.
+            ngram_prompt_lookup_max,
+            ngram_prompt_lookup_min=self.speculative_config.
+            ngram_prompt_lookup_min,
             # TODO allow draft-model specific load config.
             #load_config=self.load_config,
         )
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index c2b119fbd503..84ec974806c7 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -57,13 +57,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         draft_worker_kwargs,
     ) -> "SpecDecodeWorker":

-        if "ngram_prompt_lookup_max" in draft_worker_kwargs:
-            ngram_prompt_lookup_max = (
-                draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
-            ngram_prompt_lookup_min = (
-                draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
-        else:
-            ngram_prompt_lookup_max = 0
+        ngram_prompt_lookup_max = (
+            draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
+        ngram_prompt_lookup_min = (
+            draft_worker_kwargs.pop("ngram_prompt_lookup_min"))

         if ngram_prompt_lookup_max > 0:
             proposer_worker = NGramWorker(**draft_worker_kwargs)
@@ -72,6 +69,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         else:
             proposer_worker = MultiStepWorker(**draft_worker_kwargs)

+        logger.info("Configuring SpecDecodeWorker with proposer=%s",
+                    type(proposer_worker))
+
         return SpecDecodeWorker(
             proposer_worker,
             scorer_worker,
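
Usage sketch (not part of the patch above): the new maybe_assert_ngram_worker check only fires when the engine is built with an ngram speculative config, i.e. when ngram_prompt_lookup_max > 0 flows from the SpeculativeConfig through GPUExecutor's draft_worker_kwargs into SpecDecodeWorker.create_worker, which then picks NGramWorker over MultiStepWorker. The snippet below shows one way such an engine might be configured; the "[ngram]" speculative_model convention, the JackFram/llama-68m model name, and use_v2_block_manager=True are assumptions taken from the surrounding e2e tests, not from this diff.

    # Hypothetical end-to-end configuration exercising the NGramWorker path.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="JackFram/llama-68m",    # small model used in the spec-decode e2e tests (assumed)
        speculative_model="[ngram]",   # "[ngram]" selects prompt-lookup speculation (assumed)
        num_speculative_tokens=5,
        ngram_prompt_lookup_max=3,
        ngram_prompt_lookup_min=1,
        use_v2_block_manager=True,     # spec decode requires block manager v2 at this point (assumed)
    )
    outputs = llm.generate(["The future of AI is"],
                           SamplingParams(temperature=0.0, max_tokens=32))
    print(outputs[0].outputs[0].text)

    # With this config, llm.llm_engine.model_executor.driver_worker.proposer_worker
    # should be an NGramWorker, which is exactly what maybe_assert_ngram_worker asserts.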