From 321331b8ae41f13e519a63f99a0c427dc3907126 Mon Sep 17 00:00:00 2001
From: Alex Brooks
Date: Wed, 28 May 2025 09:58:24 -0600
Subject: [PATCH] [Core] Add Lora Support to Beam Search (#18346)

Signed-off-by: Alex-Brooks
---
 .../entrypoints/openai/test_lora_adapters.py  | 34 ++++++++++
 tests/lora/test_qwen2vl.py                    | 62 ++++++++++++++++++-
 vllm/beam_search.py                           |  4 ++
 vllm/engine/protocol.py                       | 22 ++++---
 vllm/entrypoints/llm.py                       | 42 +++++++++++--
 vllm/entrypoints/openai/serving_chat.py       |  1 +
 vllm/entrypoints/openai/serving_completion.py |  1 +
 7 files changed, 150 insertions(+), 16 deletions(-)

diff --git a/tests/entrypoints/openai/test_lora_adapters.py b/tests/entrypoints/openai/test_lora_adapters.py
index 2fc08b47513e6..cd07ca46ca651 100644
--- a/tests/entrypoints/openai/test_lora_adapters.py
+++ b/tests/entrypoints/openai/test_lora_adapters.py
@@ -313,3 +313,37 @@ async def test_loading_invalid_adapters_does_not_break_others(
         prompt=["Hello there", "Foo bar bazz buzz"],
         max_tokens=5,
     )
+
+
+@pytest.mark.asyncio
+async def test_beam_search_with_lora_adapters(
+    client: openai.AsyncOpenAI,
+    tmp_path,
+    zephyr_lora_files,
+):
+    """Validate that async beam search can be used with lora."""
+
+    async def load_and_run_adapter(adapter_name: str):
+        await client.post("load_lora_adapter",
+                          cast_to=str,
+                          body={
+                              "lora_name": adapter_name,
+                              "lora_path": str(zephyr_lora_files)
+                          })
+        for _ in range(3):
+            await client.completions.create(
+                model=adapter_name,
+                prompt=["Hello there", "Foo bar bazz buzz"],
+                max_tokens=5,
+                extra_body=dict(use_beam_search=True),
+            )
+
+    lora_tasks = []
+    for i in range(3):
+        lora_tasks.append(
+            asyncio.create_task(load_and_run_adapter(f"adapter_{i}")))
+
+    results = await asyncio.gather(*lora_tasks, return_exceptions=True)
+
+    for r in results:
+        assert not isinstance(r, Exception), f"Got exception {r}"
diff --git a/tests/lora/test_qwen2vl.py b/tests/lora/test_qwen2vl.py
index 7bd3e3d0fe27f..162714df2f130 100644
--- a/tests/lora/test_qwen2vl.py
+++ b/tests/lora/test_qwen2vl.py
@@ -10,6 +10,7 @@ import vllm
 from vllm.assets.image import ImageAsset
 from vllm.lora.request import LoRARequest
 from vllm.platforms import current_platform
+from vllm.sampling_params import BeamSearchParams
 
 
 @pytest.fixture(autouse=not current_platform.is_cpu())
@@ -69,7 +70,7 @@ class Qwen2VLTester:
                  expected_outputs: list[str],
                  lora_id: Optional[int] = None,
                  temperature: float = 0,
-                 max_tokens: int = 5) -> list[str]:
+                 max_tokens: int = 5):
 
         sampling_params = vllm.SamplingParams(
             temperature=temperature,
@@ -97,7 +98,35 @@ class Qwen2VLTester:
                             generated), f"Generated text {generated} doesn't "
             f"match expected pattern {expected}"
 
-        return generated_texts
+    def run_beam_search_test(self,
+                             images: list[ImageAsset],
+                             expected_outputs: list[list[str]],
+                             lora_id: Optional[int] = None,
+                             temperature: float = 0,
+                             beam_width: int = 2,
+                             max_tokens: int = 5):
+
+        beam_search_params = BeamSearchParams(beam_width=beam_width,
+                                              max_tokens=max_tokens,
+                                              temperature=temperature)
+
+        inputs = [{
+            "prompt": self.PROMPT_TEMPLATE,
+            "multi_modal_data": {
+                "image": asset.pil_image
+            },
+        } for asset in images]
+
+        lora_request = LoRARequest(str(lora_id), lora_id,
+                                   self.config.lora_path)
+        outputs = self.llm.beam_search(inputs,
+                                       beam_search_params,
+                                       lora_request=lora_request)
+
+        for output_obj, expected_outs in zip(outputs, expected_outputs):
+            output_texts = [seq.text for seq in output_obj.sequences]
+            assert output_texts == expected_outs, \
+                f"Generated texts {output_texts} do not match expected {expected_outs}"  # noqa: E501
 
 
 TEST_IMAGES = [
@@ -110,6 +139,14 @@ EXPECTED_OUTPUTS = [
     "A majestic skyscraper stands tall, partially obscured by a vibrant canopy of cherry blossoms, against a clear blue sky.",  # noqa: E501
 ]
 
+# NOTE - beam search .text contains the whole text (prompt included)
+EXPECTED_BEAM_SEARCH_OUTPUTS = [
+    [
+        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>What is in the image?<|im_end|>\n<|im_start|>assistant\nA majestic skyscraper stands",  # noqa: E501
+        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>What is in the image?<|im_end|>\n<|im_start|>assistant\nA majestic tower stands tall",  # noqa: E501
+    ],
+]
+
 QWEN2VL_MODEL_PATH = "Qwen/Qwen2-VL-2B-Instruct"
 QWEN25VL_MODEL_PATH = "Qwen/Qwen2.5-VL-3B-Instruct"
@@ -130,6 +167,27 @@ def test_qwen2vl_lora(qwen2vl_lora_files):
                         lora_id=lora_id)
 
 
+@pytest.mark.xfail(
+    current_platform.is_rocm(),
+    reason="Qwen2-VL dependency xformers incompatible with ROCm")
+def test_qwen2vl_lora_beam_search(qwen2vl_lora_files):
+    """Test Qwen 2.0 VL model with LoRA through beam search."""
+    config = TestConfig(model_path=QWEN2VL_MODEL_PATH,
+                        lora_path=qwen2vl_lora_files)
+    tester = Qwen2VLTester(config)
+
+    # Test with different LoRA IDs
+    for lora_id in [1, 2]:
+        # NOTE: currently we only test cherry blossom, since the stop sign
+        # output is slightly different for v1; the root cause is likely
+        # independent of the intent of this test, which is to ensure beam
+        # search passes lora through correctly.
+        tester.run_beam_search_test(
+            [ImageAsset("cherry_blossom")],
+            expected_outputs=EXPECTED_BEAM_SEARCH_OUTPUTS,
+            lora_id=lora_id)
+
+
 @pytest.mark.xfail(
     current_platform.is_rocm(),
     reason="Qwen2.5-VL dependency xformers incompatible with ROCm",
diff --git a/vllm/beam_search.py b/vllm/beam_search.py
index 967510abaeb9b..ddacc669551b9 100644
--- a/vllm/beam_search.py
+++ b/vllm/beam_search.py
@@ -3,6 +3,7 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Optional, Union
 
+from vllm.lora.request import LoRARequest
 from vllm.sequence import Logprob
 
 if TYPE_CHECKING:
@@ -19,6 +20,7 @@ class BeamSearchSequence:
     # The tokens includes the prompt.
     tokens: list[int]
     logprobs: list[dict[int, Logprob]]
+    lora_request: Optional[LoRARequest] = None
     cum_logprob: float = 0.0
     text: Optional[str] = None
     finish_reason: Optional[str] = None
@@ -41,6 +43,7 @@ class BeamSearchInstance:
     def __init__(
         self,
         prompt_tokens: list[int],
+        lora_request: Optional[LoRARequest] = None,
         logprobs: Optional[list[dict[int, Logprob]]] = None,
         **kwargs,
     ):
@@ -48,6 +51,7 @@ class BeamSearchInstance:
             BeamSearchSequence(
                 tokens=prompt_tokens,
                 logprobs=[] if logprobs is None else list(logprobs),
+                lora_request=lora_request,
                 **kwargs,
             )
         ]
diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index a837a2d288a9c..28341c2c633e8 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -65,6 +65,7 @@ class EngineClient(ABC):
         prompt: PromptType,
         request_id: str,
         params: BeamSearchParams,
+        lora_request: Optional[LoRARequest] = None,
     ) -> AsyncGenerator[RequestOutput, None]:
 
         beam_width = params.beam_width
@@ -106,27 +107,31 @@
                                cum_logprob=0,
                                logprobs=[],
                                multi_modal_data=multi_modal_data,
-                               mm_processor_kwargs=mm_processor_kwargs)
+                               mm_processor_kwargs=mm_processor_kwargs,
+                               lora_request=lora_request)
         ]
         completed = []
 
         for _ in range(max_tokens):
-            prompts_batch = [
+            prompts_batch, lora_req_batch = zip(*[(
                 TokensPrompt(prompt_token_ids=beam.tokens,
                              multi_modal_data=beam.multi_modal_data,
-                             mm_processor_kwargs=beam.mm_processor_kwargs)
-                for beam in all_beams
-            ]
+                             mm_processor_kwargs=beam.mm_processor_kwargs),
+                beam.lora_request,
+            ) for beam in all_beams])
 
             tasks = []
 
             request_id = f"beam_search-{random_uuid()}"
-            for i, individual_prompt in enumerate(prompts_batch):
+            for i, (individual_prompt,
+                    lora_req) in enumerate(zip(prompts_batch, lora_req_batch)):
                 request_id_item = f"{request_id}-{i}"
                 task = asyncio.create_task(
                     collect_from_async_generator(
-                        self.generate(individual_prompt, beam_search_params,
-                                      request_id_item)))
+                        self.generate(individual_prompt,
+                                      beam_search_params,
+                                      request_id_item,
+                                      lora_request=lora_req)))
                 tasks.append(task)
 
             output = await asyncio.gather(*tasks)
@@ -159,6 +164,7 @@
                                 tokens=current_beam.tokens + [token_id],
                                 logprobs=current_beam.logprobs +
                                 [logprobs],
+                                lora_request=current_beam.lora_request,
                                 cum_logprob=current_beam.cum_logprob +
                                 logprob_obj.logprob,
                                 multi_modal_data=current_beam.
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 7e2e5161c420b..f8eeae61fc913 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -522,10 +522,28 @@ class LLM:
             executor = self.llm_engine.model_executor
             return executor.apply_model(func)
 
+    def _get_beam_search_lora_requests(
+        self,
+        lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
+        prompts: list[Union[TokensPrompt, TextPrompt]],
+    ) -> list[Optional[LoRARequest]]:
+        """Get the optional lora request corresponding to each prompt."""
+        if isinstance(lora_request, Sequence):
+            if len(lora_request) != len(prompts):
+                raise ValueError(
+                    "Lora request list should be the same length as prompts")
+            return list(lora_request)
+
+        if lora_request is None or isinstance(lora_request, LoRARequest):
+            return [lora_request] * len(prompts)
+
+        raise TypeError(f"Invalid lora_request type {type(lora_request)}")
+
     def beam_search(
         self,
         prompts: list[Union[TokensPrompt, TextPrompt]],
         params: BeamSearchParams,
+        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
     ) -> list[BeamSearchOutput]:
         """
         Generate sequences using beam search.
@@ -534,6 +552,7 @@
         Args:
             prompts: A list of prompts.
                 Each prompt can be a string or a list of token IDs.
             params: The beam search parameters.
+            lora_request: LoRA request to use for generation, if any.
         """
         # TODO: how does beam search work together with length penalty,
         # frequency, penalty, and stopping criteria, etc.?
@@ -543,6 +562,9 @@
         ignore_eos = params.ignore_eos
         length_penalty = params.length_penalty
 
+        lora_requests = self._get_beam_search_lora_requests(
+            lora_request, prompts)
+
         def sort_beams_key(x: BeamSearchSequence) -> float:
             return get_beam_search_score(x.tokens, x.cum_logprob,
                                          tokenizer.eos_token_id,
@@ -570,7 +592,7 @@
                                             temperature=temperature)
 
         instances: list[BeamSearchInstance] = []
-        for prompt in prompts:
+        for lora_req, prompt in zip(lora_requests, prompts):
             # Add multimodal processor kwargs & data
             mm_kwargs = {}
             if "multi_modal_data" in prompt:
@@ -586,7 +608,12 @@
                 prompt_tokens = tokenizer.encode(prompt["prompt"])
 
             instances.append(
-                BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))
+                BeamSearchInstance(
+                    prompt_tokens,
+                    lora_request=lora_req,
+                    logprobs=None,
+                    **mm_kwargs,
+                ), )
 
         for _ in range(max_tokens):
             all_beams: list[BeamSearchSequence] = list(
@@ -600,15 +627,17 @@
             if len(all_beams) == 0:
                 break
 
-            prompts_batch = [
-                create_tokens_prompt_from_beam(beam) for beam in all_beams
-            ]
+            # create the corresponding batch entries for prompt & optional lora
+            prompts_batch, lora_req_batch = zip(
+                *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
+                  for beam in all_beams])
 
             # only runs for one step
             # we don't need to use tqdm here
             output = self.generate(prompts_batch,
                                    sampling_params=beam_search_params,
-                                   use_tqdm=False)
+                                   use_tqdm=False,
+                                   lora_request=lora_req_batch)
 
             for (start, end), instance in zip(instance_start_and_end,
                                               instances):
@@ -626,6 +655,7 @@
                         new_beam = BeamSearchSequence(
                             tokens=current_beam.tokens + [token_id],
                             logprobs=current_beam.logprobs + [logprobs],
+                            lora_request=current_beam.lora_request,
                             cum_logprob=current_beam.cum_logprob +
                             logprob_obj.logprob,
                             multi_modal_data=current_beam.multi_modal_data,
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index bc11686d7be89..6a0e3b14d07bb 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -236,6 +236,7 @@
                     prompt=engine_prompt,
                     request_id=request_id,
                     params=sampling_params,
+                    lora_request=lora_request,
                 )
             else:
                 generator = self.engine_client.generate(
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 7beaae287de99..1c06070cb3154 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -186,6 +186,7 @@
                     prompt=engine_prompt,
                     request_id=request_id,
                     params=sampling_params,
+                    lora_request=lora_request,
                 )
             else:
                 generator = self.engine_client.generate(
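
Usage sketch (illustrative, not part of the patch): with this change, a LoRA
adapter can be passed directly into offline beam search. The base model path
is taken from the tests above; the adapter name, ID, and path below are
placeholders, and the engine is assumed to be started with LoRA enabled
(e.g. `enable_lora=True`) for an adapter actually trained on that base model:

    from vllm import LLM
    from vllm.lora.request import LoRARequest
    from vllm.sampling_params import BeamSearchParams

    # Placeholder adapter path -- substitute a real LoRA adapter.
    llm = LLM(model="Qwen/Qwen2-VL-2B-Instruct", enable_lora=True)
    lora = LoRARequest("my_adapter", 1, "/path/to/adapter")

    # A single LoRARequest (or None) is broadcast across all prompts; a list
    # must match the prompts one-to-one (see _get_beam_search_lora_requests).
    outputs = llm.beam_search(
        [{"prompt": "What is in the image?"}],
        BeamSearchParams(beam_width=2, max_tokens=5, temperature=0.0),
        lora_request=lora,
    )

    for out in outputs:
        # Each output holds beam_width sequences; .text includes the prompt,
        # as noted in the beam search test expectations above.
        print([seq.text for seq in out.sequences])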