mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-01 18:37:05 +08:00
[Core] Add Lora Support to Beam Search (#18346)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
This commit is contained in:
parent
6e4cea1cc5
commit
321331b8ae
@ -313,3 +313,37 @@ async def test_loading_invalid_adapters_does_not_break_others(
|
||||
prompt=["Hello there", "Foo bar bazz buzz"],
|
||||
max_tokens=5,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_beam_search_with_lora_adapters(
|
||||
client: openai.AsyncOpenAI,
|
||||
tmp_path,
|
||||
zephyr_lora_files,
|
||||
):
|
||||
"""Validate that async beam search can be used with lora."""
|
||||
|
||||
async def load_and_run_adapter(adapter_name: str):
|
||||
await client.post("load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={
|
||||
"lora_name": adapter_name,
|
||||
"lora_path": str(zephyr_lora_files)
|
||||
})
|
||||
for _ in range(3):
|
||||
await client.completions.create(
|
||||
model=adapter_name,
|
||||
prompt=["Hello there", "Foo bar bazz buzz"],
|
||||
max_tokens=5,
|
||||
extra_body=dict(use_beam_search=True),
|
||||
)
|
||||
|
||||
lora_tasks = []
|
||||
for i in range(3):
|
||||
lora_tasks.append(
|
||||
asyncio.create_task(load_and_run_adapter(f"adapter_{i}")))
|
||||
|
||||
results, _ = await asyncio.wait(lora_tasks)
|
||||
|
||||
for r in results:
|
||||
assert not isinstance(r, Exception), f"Got exception {r}"
|
||||
|
||||
@ -10,6 +10,7 @@ import vllm
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
|
||||
|
||||
@pytest.fixture(autouse=not current_platform.is_cpu())
|
||||
@ -69,7 +70,7 @@ class Qwen2VLTester:
|
||||
expected_outputs: list[str],
|
||||
lora_id: Optional[int] = None,
|
||||
temperature: float = 0,
|
||||
max_tokens: int = 5) -> list[str]:
|
||||
max_tokens: int = 5):
|
||||
|
||||
sampling_params = vllm.SamplingParams(
|
||||
temperature=temperature,
|
||||
@ -97,7 +98,35 @@ class Qwen2VLTester:
|
||||
generated), f"Generated text {generated} doesn't "
|
||||
f"match expected pattern {expected}"
|
||||
|
||||
return generated_texts
|
||||
def run_beam_search_test(self,
|
||||
images: list[ImageAsset],
|
||||
expected_outputs: list[list[str]],
|
||||
lora_id: Optional[int] = None,
|
||||
temperature: float = 0,
|
||||
beam_width: int = 2,
|
||||
max_tokens: int = 5):
|
||||
|
||||
beam_search_params = BeamSearchParams(beam_width=beam_width,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature)
|
||||
|
||||
inputs = [{
|
||||
"prompt": self.PROMPT_TEMPLATE,
|
||||
"multi_modal_data": {
|
||||
"image": asset.pil_image
|
||||
},
|
||||
} for asset in images]
|
||||
|
||||
lora_request = LoRARequest(str(lora_id), lora_id,
|
||||
self.config.lora_path)
|
||||
outputs = self.llm.beam_search(inputs,
|
||||
beam_search_params,
|
||||
lora_request=lora_request)
|
||||
|
||||
for output_obj, expected_outs in zip(outputs, expected_outputs):
|
||||
output_texts = [seq.text for seq in output_obj.sequences]
|
||||
assert output_texts == expected_outs, \
|
||||
f"Generated texts {output_texts} do not match expected {expected_outs}" # noqa: E501
|
||||
|
||||
|
||||
TEST_IMAGES = [
|
||||
@ -110,6 +139,14 @@ EXPECTED_OUTPUTS = [
|
||||
"A majestic skyscraper stands tall, partially obscured by a vibrant canopy of cherry blossoms, against a clear blue sky.", # noqa: E501
|
||||
]
|
||||
|
||||
# NOTE - beam search .text contains the whole text
|
||||
EXPECTED_BEAM_SEARCH_OUTPUTS = [
|
||||
[
|
||||
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>What is in the image?<|im_end|>\n<|im_start|>assistant\nA majestic skyscraper stands", # noqa: E501
|
||||
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>What is in the image?<|im_end|>\n<|im_start|>assistant\nA majestic tower stands tall", # noqa: E501
|
||||
],
|
||||
]
|
||||
|
||||
QWEN2VL_MODEL_PATH = "Qwen/Qwen2-VL-2B-Instruct"
|
||||
QWEN25VL_MODEL_PATH = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||
|
||||
@ -130,6 +167,27 @@ def test_qwen2vl_lora(qwen2vl_lora_files):
|
||||
lora_id=lora_id)
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
current_platform.is_rocm(),
|
||||
reason="Qwen2-VL dependency xformers incompatible with ROCm")
|
||||
def test_qwen2vl_lora_beam_search(qwen2vl_lora_files):
|
||||
"""Test Qwen 2.0 VL model with LoRA through beam search."""
|
||||
config = TestConfig(model_path=QWEN2VL_MODEL_PATH,
|
||||
lora_path=qwen2vl_lora_files)
|
||||
tester = Qwen2VLTester(config)
|
||||
|
||||
# Test with different LoRA IDs
|
||||
for lora_id in [1, 2]:
|
||||
# NOTE currently, we only test cherry blossom since stop sign
|
||||
# output is slightly different for v1; - the root cause is likely
|
||||
# independent of the intent of this test, which is to ensure beam
|
||||
# search passes through lora through correctly.
|
||||
tester.run_beam_search_test(
|
||||
[ImageAsset("cherry_blossom")],
|
||||
expected_outputs=EXPECTED_BEAM_SEARCH_OUTPUTS,
|
||||
lora_id=lora_id)
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
current_platform.is_rocm(),
|
||||
reason="Qwen2.5-VL dependency xformers incompatible with ROCm",
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import Logprob
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -19,6 +20,7 @@ class BeamSearchSequence:
|
||||
# The tokens includes the prompt.
|
||||
tokens: list[int]
|
||||
logprobs: list[dict[int, Logprob]]
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
cum_logprob: float = 0.0
|
||||
text: Optional[str] = None
|
||||
finish_reason: Optional[str] = None
|
||||
@ -41,6 +43,7 @@ class BeamSearchInstance:
|
||||
def __init__(
|
||||
self,
|
||||
prompt_tokens: list[int],
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
logprobs: Optional[list[dict[int, Logprob]]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
@ -48,6 +51,7 @@ class BeamSearchInstance:
|
||||
BeamSearchSequence(
|
||||
tokens=prompt_tokens,
|
||||
logprobs=[] if logprobs is None else list(logprobs),
|
||||
lora_request=lora_request,
|
||||
**kwargs,
|
||||
)
|
||||
]
|
||||
|
||||
@ -65,6 +65,7 @@ class EngineClient(ABC):
|
||||
prompt: PromptType,
|
||||
request_id: str,
|
||||
params: BeamSearchParams,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
|
||||
beam_width = params.beam_width
|
||||
@ -106,27 +107,31 @@ class EngineClient(ABC):
|
||||
cum_logprob=0,
|
||||
logprobs=[],
|
||||
multi_modal_data=multi_modal_data,
|
||||
mm_processor_kwargs=mm_processor_kwargs)
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
lora_request=lora_request)
|
||||
]
|
||||
completed = []
|
||||
|
||||
for _ in range(max_tokens):
|
||||
prompts_batch = [
|
||||
prompts_batch, lora_req_batch = zip(*[(
|
||||
TokensPrompt(prompt_token_ids=beam.tokens,
|
||||
multi_modal_data=beam.multi_modal_data,
|
||||
mm_processor_kwargs=beam.mm_processor_kwargs)
|
||||
for beam in all_beams
|
||||
]
|
||||
mm_processor_kwargs=beam.mm_processor_kwargs),
|
||||
beam.lora_request,
|
||||
) for beam in all_beams])
|
||||
|
||||
tasks = []
|
||||
|
||||
request_id = f"beam_search-{random_uuid()}"
|
||||
for i, individual_prompt in enumerate(prompts_batch):
|
||||
for i, (individual_prompt,
|
||||
lora_req) in enumerate(zip(prompts_batch, lora_req_batch)):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
task = asyncio.create_task(
|
||||
collect_from_async_generator(
|
||||
self.generate(individual_prompt, beam_search_params,
|
||||
request_id_item)))
|
||||
self.generate(individual_prompt,
|
||||
beam_search_params,
|
||||
request_id_item,
|
||||
lora_request=lora_req)))
|
||||
tasks.append(task)
|
||||
|
||||
output = await asyncio.gather(*tasks)
|
||||
@ -159,6 +164,7 @@ class EngineClient(ABC):
|
||||
tokens=current_beam.tokens + [token_id],
|
||||
logprobs=current_beam.logprobs +
|
||||
[logprobs],
|
||||
lora_request=current_beam.lora_request,
|
||||
cum_logprob=current_beam.cum_logprob +
|
||||
logprob_obj.logprob,
|
||||
multi_modal_data=current_beam.
|
||||
|
||||
@ -522,10 +522,28 @@ class LLM:
|
||||
executor = self.llm_engine.model_executor
|
||||
return executor.apply_model(func)
|
||||
|
||||
def _get_beam_search_lora_requests(
|
||||
self,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
|
||||
prompts: list[Union[TokensPrompt, TextPrompt]],
|
||||
) -> list[Optional[LoRARequest]]:
|
||||
"""Get the optional lora request corresponding to each prompt."""
|
||||
if isinstance(lora_request,
|
||||
Sequence) and len(lora_request) != len(prompts):
|
||||
raise ValueError(
|
||||
"Lora request list should be the same length as the prompts")
|
||||
return lora_request
|
||||
|
||||
if lora_request is None or isinstance(lora_request, LoRARequest):
|
||||
return [lora_request] * len(prompts)
|
||||
|
||||
raise TypeError(f"Invalid lora_request type {type(lora_request)}")
|
||||
|
||||
def beam_search(
|
||||
self,
|
||||
prompts: list[Union[TokensPrompt, TextPrompt]],
|
||||
params: BeamSearchParams,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
) -> list[BeamSearchOutput]:
|
||||
"""
|
||||
Generate sequences using beam search.
|
||||
@ -534,6 +552,7 @@ class LLM:
|
||||
prompts: A list of prompts. Each prompt can be a string or a list
|
||||
of token IDs.
|
||||
params: The beam search parameters.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
"""
|
||||
# TODO: how does beam search work together with length penalty,
|
||||
# frequency, penalty, and stopping criteria, etc.?
|
||||
@ -543,6 +562,9 @@ class LLM:
|
||||
ignore_eos = params.ignore_eos
|
||||
length_penalty = params.length_penalty
|
||||
|
||||
lora_requests = self._get_beam_search_lora_requests(
|
||||
lora_request, prompts)
|
||||
|
||||
def sort_beams_key(x: BeamSearchSequence) -> float:
|
||||
return get_beam_search_score(x.tokens, x.cum_logprob,
|
||||
tokenizer.eos_token_id,
|
||||
@ -570,7 +592,7 @@ class LLM:
|
||||
temperature=temperature)
|
||||
instances: list[BeamSearchInstance] = []
|
||||
|
||||
for prompt in prompts:
|
||||
for lora_req, prompt in zip(lora_requests, prompts):
|
||||
# Add multimodal processor kwargs & data
|
||||
mm_kwargs = {}
|
||||
if "multi_modal_data" in prompt:
|
||||
@ -586,7 +608,12 @@ class LLM:
|
||||
prompt_tokens = tokenizer.encode(prompt["prompt"])
|
||||
|
||||
instances.append(
|
||||
BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))
|
||||
BeamSearchInstance(
|
||||
prompt_tokens,
|
||||
lora_request=lora_req,
|
||||
logprobs=None,
|
||||
**mm_kwargs,
|
||||
), )
|
||||
|
||||
for _ in range(max_tokens):
|
||||
all_beams: list[BeamSearchSequence] = list(
|
||||
@ -600,15 +627,17 @@ class LLM:
|
||||
if len(all_beams) == 0:
|
||||
break
|
||||
|
||||
prompts_batch = [
|
||||
create_tokens_prompt_from_beam(beam) for beam in all_beams
|
||||
]
|
||||
# create the corresponding batch entries for prompt & optional lora
|
||||
prompts_batch, lora_req_batch = zip(
|
||||
*[(create_tokens_prompt_from_beam(beam), beam.lora_request)
|
||||
for beam in all_beams])
|
||||
|
||||
# only runs for one step
|
||||
# we don't need to use tqdm here
|
||||
output = self.generate(prompts_batch,
|
||||
sampling_params=beam_search_params,
|
||||
use_tqdm=False)
|
||||
use_tqdm=False,
|
||||
lora_request=lora_req_batch)
|
||||
|
||||
for (start, end), instance in zip(instance_start_and_end,
|
||||
instances):
|
||||
@ -626,6 +655,7 @@ class LLM:
|
||||
new_beam = BeamSearchSequence(
|
||||
tokens=current_beam.tokens + [token_id],
|
||||
logprobs=current_beam.logprobs + [logprobs],
|
||||
lora_request=current_beam.lora_request,
|
||||
cum_logprob=current_beam.cum_logprob +
|
||||
logprob_obj.logprob,
|
||||
multi_modal_data=current_beam.multi_modal_data,
|
||||
|
||||
@ -236,6 +236,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
prompt=engine_prompt,
|
||||
request_id=request_id,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
else:
|
||||
generator = self.engine_client.generate(
|
||||
|
||||
@ -186,6 +186,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
prompt=engine_prompt,
|
||||
request_id=request_id,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
else:
|
||||
generator = self.engine_client.generate(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user