[Core] Add Lora Support to Beam Search (#18346)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Alex Brooks 2025-05-28 09:58:24 -06:00 committed by GitHub
parent 6e4cea1cc5
commit 321331b8ae
7 changed files with 150 additions and 16 deletions
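For context, a minimal sketch of the offline entrypoint this change enables; the base model path, adapter name, and adapter path below are illustrative placeholders rather than values from this diff:

from vllm import LLM
from vllm.lora.request import LoRARequest
from vllm.sampling_params import BeamSearchParams

# Placeholder model/adapter; any LoRA-enabled model works the same way.
llm = LLM(model="/path/to/base-model", enable_lora=True)
params = BeamSearchParams(beam_width=2, max_tokens=5, temperature=0.0)
outputs = llm.beam_search(
    [{"prompt": "Hello there"}],
    params,
    lora_request=LoRARequest("my_adapter", 1, "/path/to/adapter"),  # new in this change
)
for output in outputs:
    print([seq.text for seq in output.sequences])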

View File

@@ -313,3 +313,37 @@ async def test_loading_invalid_adapters_does_not_break_others(
prompt=["Hello there", "Foo bar bazz buzz"],
max_tokens=5,
)
@pytest.mark.asyncio
async def test_beam_search_with_lora_adapters(
client: openai.AsyncOpenAI,
tmp_path,
zephyr_lora_files,
):
"""Validate that async beam search can be used with lora."""
async def load_and_run_adapter(adapter_name: str):
await client.post("load_lora_adapter",
cast_to=str,
body={
"lora_name": adapter_name,
"lora_path": str(zephyr_lora_files)
})
for _ in range(3):
await client.completions.create(
model=adapter_name,
prompt=["Hello there", "Foo bar bazz buzz"],
max_tokens=5,
extra_body=dict(use_beam_search=True),
)
lora_tasks = []
for i in range(3):
lora_tasks.append(
asyncio.create_task(load_and_run_adapter(f"adapter_{i}")))
results, _ = await asyncio.wait(lora_tasks)
for r in results:
assert not isinstance(r, Exception), f"Got exception {r}"
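Outside of pytest, the same online flow can be sketched against a running server; this assumes an OpenAI-compatible vLLM server started with --enable-lora and runtime adapter loading allowed, and the URL, adapter name, and adapter path are placeholders:

import asyncio

import openai


async def main():
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    # Dynamically register the adapter, then run beam search against it.
    await client.post("load_lora_adapter",
                      cast_to=str,
                      body={
                          "lora_name": "my_adapter",
                          "lora_path": "/path/to/adapter",
                      })
    completion = await client.completions.create(
        model="my_adapter",
        prompt=["Hello there"],
        max_tokens=5,
        extra_body=dict(use_beam_search=True),
    )
    print(completion.choices[0].text)


asyncio.run(main())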

View File

@@ -10,6 +10,7 @@ import vllm
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform
from vllm.sampling_params import BeamSearchParams
@pytest.fixture(autouse=not current_platform.is_cpu())
@@ -69,7 +70,7 @@ class Qwen2VLTester:
expected_outputs: list[str],
lora_id: Optional[int] = None,
temperature: float = 0,
max_tokens: int = 5) -> list[str]:
max_tokens: int = 5):
sampling_params = vllm.SamplingParams(
temperature=temperature,
@@ -97,7 +98,35 @@ class Qwen2VLTester:
generated), f"Generated text {generated} doesn't "
f"match expected pattern {expected}"
return generated_texts
def run_beam_search_test(self,
images: list[ImageAsset],
expected_outputs: list[list[str]],
lora_id: Optional[int] = None,
temperature: float = 0,
beam_width: int = 2,
max_tokens: int = 5):
beam_search_params = BeamSearchParams(beam_width=beam_width,
max_tokens=max_tokens,
temperature=temperature)
inputs = [{
"prompt": self.PROMPT_TEMPLATE,
"multi_modal_data": {
"image": asset.pil_image
},
} for asset in images]
lora_request = LoRARequest(str(lora_id), lora_id,
self.config.lora_path)
outputs = self.llm.beam_search(inputs,
beam_search_params,
lora_request=lora_request)
for output_obj, expected_outs in zip(outputs, expected_outputs):
output_texts = [seq.text for seq in output_obj.sequences]
assert output_texts == expected_outs, \
f"Generated texts {output_texts} do not match expected {expected_outs}" # noqa: E501
TEST_IMAGES = [
@@ -110,6 +139,14 @@ EXPECTED_OUTPUTS = [
"A majestic skyscraper stands tall, partially obscured by a vibrant canopy of cherry blossoms, against a clear blue sky.", # noqa: E501
]
# NOTE - beam search .text contains the whole text (prompt + generated continuation)
EXPECTED_BEAM_SEARCH_OUTPUTS = [
[
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>What is in the image?<|im_end|>\n<|im_start|>assistant\nA majestic skyscraper stands", # noqa: E501
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>What is in the image?<|im_end|>\n<|im_start|>assistant\nA majestic tower stands tall", # noqa: E501
],
]
QWEN2VL_MODEL_PATH = "Qwen/Qwen2-VL-2B-Instruct"
QWEN25VL_MODEL_PATH = "Qwen/Qwen2.5-VL-3B-Instruct"
@@ -130,6 +167,27 @@ def test_qwen2vl_lora(qwen2vl_lora_files):
lora_id=lora_id)
@pytest.mark.xfail(
current_platform.is_rocm(),
reason="Qwen2-VL dependency xformers incompatible with ROCm")
def test_qwen2vl_lora_beam_search(qwen2vl_lora_files):
"""Test Qwen 2.0 VL model with LoRA through beam search."""
config = TestConfig(model_path=QWEN2VL_MODEL_PATH,
lora_path=qwen2vl_lora_files)
tester = Qwen2VLTester(config)
# Test with different LoRA IDs
for lora_id in [1, 2]:
# NOTE: currently we only test the cherry blossom image, since the stop
# sign output is slightly different for v1; the root cause is likely
# independent of the intent of this test, which is to ensure beam
# search passes LoRA through correctly.
tester.run_beam_search_test(
[ImageAsset("cherry_blossom")],
expected_outputs=EXPECTED_BEAM_SEARCH_OUTPUTS,
lora_id=lora_id)
@pytest.mark.xfail(
current_platform.is_rocm(),
reason="Qwen2.5-VL dependency xformers incompatible with ROCm",

View File

@@ -3,6 +3,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Union
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob
if TYPE_CHECKING:
@@ -19,6 +20,7 @@ class BeamSearchSequence:
# The tokens include the prompt.
tokens: list[int]
logprobs: list[dict[int, Logprob]]
lora_request: Optional[LoRARequest] = None
cum_logprob: float = 0.0
text: Optional[str] = None
finish_reason: Optional[str] = None
@@ -41,6 +43,7 @@ class BeamSearchInstance:
def __init__(
self,
prompt_tokens: list[int],
lora_request: Optional[LoRARequest] = None,
logprobs: Optional[list[dict[int, Logprob]]] = None,
**kwargs,
):
@@ -48,6 +51,7 @@
BeamSearchSequence(
tokens=prompt_tokens,
logprobs=[] if logprobs is None else list(logprobs),
lora_request=lora_request,
**kwargs,
)
]
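In short, the adapter now rides along on every beam. A small construction sketch using only the fields shown above (token ids and the adapter name/path are placeholders):

from vllm.beam_search import BeamSearchInstance, BeamSearchSequence
from vllm.lora.request import LoRARequest

req = LoRARequest("my_adapter", 1, "/path/to/adapter")

# The instance seeds its initial beam with the request...
instance = BeamSearchInstance(prompt_tokens=[1, 2, 3], lora_request=req)

# ...and each expanded candidate copies it from its parent, so every
# decoding step keeps using the same adapter.
parent = BeamSearchSequence(tokens=[1, 2, 3], logprobs=[], lora_request=req)
child = BeamSearchSequence(
    tokens=parent.tokens + [42],
    logprobs=parent.logprobs + [{}],
    lora_request=parent.lora_request,
    cum_logprob=parent.cum_logprob - 0.5,
)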

View File

@@ -65,6 +65,7 @@ class EngineClient(ABC):
prompt: PromptType,
request_id: str,
params: BeamSearchParams,
lora_request: Optional[LoRARequest] = None,
) -> AsyncGenerator[RequestOutput, None]:
beam_width = params.beam_width
@@ -106,27 +107,31 @@ class EngineClient(ABC):
cum_logprob=0,
logprobs=[],
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs)
mm_processor_kwargs=mm_processor_kwargs,
lora_request=lora_request)
]
completed = []
for _ in range(max_tokens):
prompts_batch = [
prompts_batch, lora_req_batch = zip(*[(
TokensPrompt(prompt_token_ids=beam.tokens,
multi_modal_data=beam.multi_modal_data,
mm_processor_kwargs=beam.mm_processor_kwargs)
for beam in all_beams
]
mm_processor_kwargs=beam.mm_processor_kwargs),
beam.lora_request,
) for beam in all_beams])
tasks = []
request_id = f"beam_search-{random_uuid()}"
for i, individual_prompt in enumerate(prompts_batch):
for i, (individual_prompt,
lora_req) in enumerate(zip(prompts_batch, lora_req_batch)):
request_id_item = f"{request_id}-{i}"
task = asyncio.create_task(
collect_from_async_generator(
self.generate(individual_prompt, beam_search_params,
request_id_item)))
self.generate(individual_prompt,
beam_search_params,
request_id_item,
lora_request=lora_req)))
tasks.append(task)
output = await asyncio.gather(*tasks)
@@ -159,6 +164,7 @@ class EngineClient(ABC):
tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs +
[logprobs],
lora_request=current_beam.lora_request,
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob,
multi_modal_data=current_beam.

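For the async path above, a minimal calling sketch; it assumes an already-constructed EngineClient implementation, and the adapter name/path and request id are placeholders:

from vllm.lora.request import LoRARequest
from vllm.sampling_params import BeamSearchParams


async def beam_search_with_adapter(engine_client, prompt):
    params = BeamSearchParams(beam_width=2, max_tokens=5, temperature=0.0)
    lora = LoRARequest("my_adapter", 1, "/path/to/adapter")
    final = None
    # The generator yields incremental RequestOutput objects; keep the last.
    async for out in engine_client.beam_search(prompt=prompt,
                                               request_id="beam-demo-0",
                                               params=params,
                                               lora_request=lora):
        final = out
    return final
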
View File

@@ -522,10 +522,28 @@ class LLM:
executor = self.llm_engine.model_executor
return executor.apply_model(func)
def _get_beam_search_lora_requests(
self,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
prompts: list[Union[TokensPrompt, TextPrompt]],
) -> list[Optional[LoRARequest]]:
"""Get the optional lora request corresponding to each prompt."""
if isinstance(lora_request,
Sequence) and len(lora_request) != len(prompts):
raise ValueError(
"Lora request list should be the same length as the prompts")
return lora_request
if lora_request is None or isinstance(lora_request, LoRARequest):
return [lora_request] * len(prompts)
raise TypeError(f"Invalid lora_request type {type(lora_request)}")
def beam_search(
self,
prompts: list[Union[TokensPrompt, TextPrompt]],
params: BeamSearchParams,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
) -> list[BeamSearchOutput]:
"""
Generate sequences using beam search.
@@ -534,6 +552,7 @@ class LLM:
prompts: A list of prompts. Each prompt can be a string or a list
of token IDs.
params: The beam search parameters.
lora_request: LoRA request to use for generation, if any.
"""
# TODO: how does beam search work together with length penalty,
# frequency penalty, and stopping criteria, etc.?
@@ -543,6 +562,9 @@
ignore_eos = params.ignore_eos
length_penalty = params.length_penalty
lora_requests = self._get_beam_search_lora_requests(
lora_request, prompts)
def sort_beams_key(x: BeamSearchSequence) -> float:
return get_beam_search_score(x.tokens, x.cum_logprob,
tokenizer.eos_token_id,
@@ -570,7 +592,7 @@
temperature=temperature)
instances: list[BeamSearchInstance] = []
for prompt in prompts:
for lora_req, prompt in zip(lora_requests, prompts):
# Add multimodal processor kwargs & data
mm_kwargs = {}
if "multi_modal_data" in prompt:
@@ -586,7 +608,12 @@
prompt_tokens = tokenizer.encode(prompt["prompt"])
instances.append(
BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))
BeamSearchInstance(
prompt_tokens,
lora_request=lora_req,
logprobs=None,
**mm_kwargs,
), )
for _ in range(max_tokens):
all_beams: list[BeamSearchSequence] = list(
@@ -600,15 +627,17 @@
if len(all_beams) == 0:
break
prompts_batch = [
create_tokens_prompt_from_beam(beam) for beam in all_beams
]
# Create the corresponding batch entries for each prompt & its optional LoRA request
prompts_batch, lora_req_batch = zip(
*[(create_tokens_prompt_from_beam(beam), beam.lora_request)
for beam in all_beams])
# only runs for one step
# we don't need to use tqdm here
output = self.generate(prompts_batch,
sampling_params=beam_search_params,
use_tqdm=False)
use_tqdm=False,
lora_request=lora_req_batch)
for (start, end), instance in zip(instance_start_and_end,
instances):
@@ -626,6 +655,7 @@
new_beam = BeamSearchSequence(
tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs + [logprobs],
lora_request=current_beam.lora_request,
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob,
multi_modal_data=current_beam.multi_modal_data,

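Since _get_beam_search_lora_requests accepts either a single request or one request per prompt, callers can mix adapters across a batch; a sketch assuming two hypothetical adapters (model and adapter paths are placeholders):

from vllm import LLM
from vllm.lora.request import LoRARequest
from vllm.sampling_params import BeamSearchParams

llm = LLM(model="/path/to/base-model", enable_lora=True, max_loras=2)
params = BeamSearchParams(beam_width=2, max_tokens=5, temperature=0.0)

# The list must be the same length as the prompts; each prompt keeps its
# own adapter for every beam expanded from it.
outputs = llm.beam_search(
    [{"prompt": "Hello there"}, {"prompt": "Foo bar bazz buzz"}],
    params,
    lora_request=[
        LoRARequest("adapter_a", 1, "/path/to/adapter_a"),
        LoRARequest("adapter_b", 2, "/path/to/adapter_b"),
    ],
)
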
View File

@@ -236,6 +236,7 @@ class OpenAIServingChat(OpenAIServing):
prompt=engine_prompt,
request_id=request_id,
params=sampling_params,
lora_request=lora_request,
)
else:
generator = self.engine_client.generate(

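The chat path mirrors the completions path; a hedged client-side sketch, assuming the chat completions endpoint also honors the vLLM-specific use_beam_search extension and that an adapter named my_adapter is being served (URL and names are placeholders):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
chat = client.chat.completions.create(
    model="my_adapter",  # a served LoRA adapter name
    messages=[{"role": "user", "content": "Hello there"}],
    max_tokens=5,
    extra_body={"use_beam_search": True},
)
print(chat.choices[0].message.content)
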
View File

@@ -186,6 +186,7 @@ class OpenAIServingCompletion(OpenAIServing):
prompt=engine_prompt,
request_id=request_id,
params=sampling_params,
lora_request=lora_request,
)
else:
generator = self.engine_client.generate(