[Bugfix][LoRA][Spec Decode] Support LoRA with speculative decoding (#21068)
Signed-off-by: Sean Chen <xiaohong_chen1991@hotmail.com>
Signed-off-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Danielle Robinson <dcmaddix@gmail.com>
Co-authored-by: Haipeng Li <li2haipeng@gmail.com>
Co-authored-by: li2haipeng <44383182+li2haipeng@users.noreply.github.com>
parent b158df2813
commit d0c7792004

tests/v1/e2e/test_lora_with_spec_decode.py | 141 (new file)
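This commit lets a single engine combine LoRA adapters with speculative decoding. As a quick orientation before the diff, here is a minimal sketch of the combination the new test exercises, using the same base model, EAGLE3 drafter, and adapter as the test below (any compatible adapter/drafter pair would work the same way):

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# One engine with both features enabled (settings mirror the test below).
llm = LLM(
    model="Qwen/Qwen3-1.7B",
    enable_lora=True,
    max_lora_rank=16,
    speculative_config={
        "method": "eagle3",
        "model": "AngelSlim/Qwen3-1.7B_eagle3",
        "num_speculative_tokens": 3,
    },
)

# LoRA requests are passed per generate() call, exactly as without spec decode.
outputs = llm.generate(
    ["Find the eigenvalues of [[3, 2, 0], [2, 3, 0], [0, 0, 2]]."],
    SamplingParams(temperature=0.0, max_tokens=128),
    lora_request=LoRARequest("adapter", 1, "premjatin/qwen-linear-algebra-coder"),
)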
@@ -0,0 +1,141 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This script contains:
1. Test LoRA with speculative decoding for batch inference.
"""

import random

import numpy as np
import pytest
import torch

from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform

LORA_TEST_PROMPT_MAP: dict[str, str] = {}

LORA_TEST_PROMPT_MAP["premjatin/qwen-linear-algebra-coder"] = """
### INSTRUCTION:
You are an AI assistant that generates Python code to solve linear
algebra problems.

### PROBLEM:
Find the eigenvalues and eigenvectors of the following 3x3 matrix:
[[3, 2, 0],
 [2, 3, 0],
 [0, 0, 2]]

### OUTPUT FORMAT (STRICT):
Numbers should be represented as integers only.

### PYTHON SOLUTION:
"""

SEED = 42


@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
@pytest.mark.parametrize(
    "model_setup",
    [
        (
            "eagle3",
            "Qwen/Qwen3-1.7B",
            "AngelSlim/Qwen3-1.7B_eagle3",
            "premjatin/qwen-linear-algebra-coder",
            1,
        )
    ],
)
def test_batch_inference_correctness(
    monkeypatch: pytest.MonkeyPatch,
    model_setup: tuple[str, str, str, str, int],
):
    """
    Compare the outputs of an LLM with only LoRA against an LLM with both
    speculative decoding and LoRA. The outputs should match, with no failures
    during batch inference.
    model_setup: (method, model_name, spec_model_name, lora_path, tp_size)
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        # Disable randomness
        m.setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        torch.manual_seed(SEED)
        np.random.seed(SEED)
        random.seed(SEED)
        torch.cuda.manual_seed_all(SEED)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

        method, model_name, spec_model_name, lora_path, tp_size = model_setup

        # without speculative decoding
        ref_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            max_model_len=2048,
            max_num_seqs=4,
            enable_lora=True,
            max_loras=1,
            max_cpu_loras=1,
            max_lora_rank=16,
        )

        prompts = [LORA_TEST_PROMPT_MAP[lora_path]] * 100
        lora_request = LoRARequest("adapter", 1, lora_path)
        sampling_params = SamplingParams(
            temperature=0.0, top_p=1.0, top_k=-1, seed=SEED, max_tokens=128
        )

        ref_outputs = ref_llm.generate(
            prompts, sampling_params, lora_request=lora_request
        )
        del ref_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

        lora_spec_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            speculative_config={
                "method": method,
                "model": spec_model_name,
                "num_speculative_tokens": 3,
                "max_model_len": 2048,
            },
            max_model_len=2048,
            max_num_seqs=4,
            enable_lora=True,
            max_loras=1,
            max_cpu_loras=1,
            max_lora_rank=16,
        )

        lora_spec_outputs = lora_spec_llm.generate(
            prompts, sampling_params, lora_request=lora_request
        )

        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, lora_spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        # Heuristic: expect at least 90% of the prompts to match exactly.
        # Upon failure, inspect the outputs to check for inaccuracy.
        print(f"match ratio: {matches}/{len(ref_outputs)}")
        assert matches > int(0.90 * len(ref_outputs))
        del lora_spec_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()
@@ -1574,6 +1574,20 @@ class EngineArgs:
             else None
         )
 
+        if (
+            lora_config is not None
+            and speculative_config is not None
+            and scheduler_config.max_num_batched_tokens
+            < (
+                scheduler_config.max_num_seqs
+                * (speculative_config.num_speculative_tokens + 1)
+            )
+        ):
+            raise ValueError(
+                "Consider increasing max_num_batched_tokens or "
+                "decreasing num_speculative_tokens"
+            )
+
         # bitsandbytes pre-quantized model need a specific model loader
         if model_config.quantization == "bitsandbytes":
             self.quantization = self.load_format = "bitsandbytes"
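The new check guards the worst case for LoRA metadata sizing: in a single step every request may sample up to num_speculative_tokens + 1 tokens (its draft tokens plus the bonus token), so the token budget must cover max_num_seqs * (num_speculative_tokens + 1). A worked example of the bound, using the values from the test above (illustrative, not part of the diff):

# Per-step worst case when LoRA and speculative decoding are combined.
max_num_seqs = 4               # from the test configuration above
num_speculative_tokens = 3     # from the test configuration above
min_required = max_num_seqs * (num_speculative_tokens + 1)  # 4 * 4 = 16

# Engine configs with max_num_batched_tokens < min_required are now rejected
# with the ValueError above; raising max_num_batched_tokens or lowering
# num_speculative_tokens satisfies the check.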
@@ -51,8 +51,12 @@ class PunicaWrapperGPU(PunicaWrapperBase):
             self.max_loras, max_num_batched_tokens, device=device
         )
 
+        # When speculative decoding is enabled, max_num_samples is
+        # max_batches * (num_speculative_decoding_tokens + 1).
+        # This line can be optimized by replacing max_num_batched_tokens
+        # with max_batches * (num_speculative_decoding_tokens + 1).
         self.prompt_mapping_meta = LoRAKernelMeta.make(
-            self.max_loras, max_batches, device=device
+            self.max_loras, max_num_batched_tokens, device=device
         )
 
     def update_metadata(
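The comment records why prompt_mapping_meta is now sized by max_num_batched_tokens: with speculative decoding the prompt (sample) mapping can hold up to max_batches * (num_speculative_decoding_tokens + 1) entries, which no longer fits a buffer sized by max_batches alone. A rough sketch of the sizing argument, with illustrative numbers and assuming max_batches tracks the scheduler's max_num_seqs:

# Illustrative values only, not vLLM defaults.
max_batches = 256                  # assumed equal to max_num_seqs here
num_speculative_tokens = 3
max_num_batched_tokens = 8192

max_num_samples = max_batches * (num_speculative_tokens + 1)   # 1024

# Sizing by max_num_batched_tokens is a safe (if loose) upper bound; under the
# assumption above, the engine-args check added earlier in this commit keeps
# max_num_samples within the token budget.
assert max_num_batched_tokens >= max_num_samples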
@@ -859,22 +859,24 @@ class InputBatch:
         return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True)
 
     def make_lora_inputs(
-        self, num_scheduled_tokens: np.ndarray
+        self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
     ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
         """
         Given the num_scheduled_tokens for each request in the batch, return
         datastructures used to activate the current LoRAs.
         Returns:
-            1. prompt_lora_mapping: A tuple of size self.num_reqs where,
-                prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
+            1. prompt_lora_mapping: A tuple of size np.sum(num_sampled_tokens)
+                where, prompt_lora_mapping[i] is the LoRA id to use for the ith
+                sampled token.
             2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
                 where, token_lora_mapping[i] is the LoRA id to use for ith token.
             3. lora_requests: Set of relevant LoRA requests.
         """
 
         req_lora_mapping = self.request_lora_mapping[: self.num_reqs]
-        prompt_lora_mapping = tuple(req_lora_mapping)
+        prompt_lora_mapping = tuple(req_lora_mapping.repeat(num_sampled_tokens))
         token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens))
 
         active_lora_requests: set[LoRARequest] = set(
             self.lora_id_to_lora_request.values()
         )
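The substantive change in this hunk: prompt_lora_mapping now carries one LoRA id per sampled token instead of one per request, mirroring how token_lora_mapping already repeats per scheduled token. A small standalone sketch of the two repeats, with made-up values:

import numpy as np

# Hypothetical batch of three requests using LoRA ids 1, 1 and 2.
req_lora_mapping = np.array([1, 1, 2], dtype=np.int32)
num_scheduled_tokens = np.array([5, 3, 4], dtype=np.int32)  # tokens fed to the model
num_sampled_tokens = np.array([4, 4, 1], dtype=np.int32)    # tokens sampled per request

token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens))
# -> (1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2): one id per scheduled token

prompt_lora_mapping = tuple(req_lora_mapping.repeat(num_sampled_tokens))
# -> (1, 1, 1, 1, 1, 1, 1, 1, 2): one id per sampled token,
#    rather than one per request as before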
@@ -1268,6 +1268,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             logits_indices = query_start_loc[1:] - 1
             num_draft_tokens = None
             spec_decode_metadata = None
+            num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
         else:
             # Get the number of draft tokens for each request.
             # Iterate over the dictionary rather than all requests since not all
@@ -1294,7 +1295,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 num_draft_tokens, cu_num_tokens
             )
             logits_indices = spec_decode_metadata.logits_indices
+            num_sampled_tokens = num_draft_tokens + 1
             # For DECODE only cuda graph of some attention backends (e.g., GDN).
             self.num_decode_draft_tokens.np[:num_reqs] = num_decode_draft_tokens
             self.num_decode_draft_tokens.np[num_reqs:].fill(-1)
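With speculative decoding, each request samples one token per draft token plus the bonus token, hence num_draft_tokens + 1; without it, the earlier branch keeps one sampled token per request. A worked example with hypothetical draft counts:

import numpy as np

# Hypothetical per-request draft-token counts; a request with 0 drafts is a
# plain decode step.
num_draft_tokens = np.array([3, 3, 0], dtype=np.int32)
num_sampled_tokens = num_draft_tokens + 1        # -> [4, 4, 1]

# Total sampled positions this step; the LoRA path below asserts this stays
# within scheduler_config.max_num_batched_tokens.
total_sampled = int(np.sum(num_sampled_tokens))  # 9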
@@ -1445,7 +1446,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         # Hot-Swap lora model
         if self.lora_config:
-            self.set_active_loras(self.input_batch, num_scheduled_tokens)
+            assert (
+                np.sum(num_sampled_tokens)
+                <= self.vllm_config.scheduler_config.max_num_batched_tokens
+            )
+            self.set_active_loras(
+                self.input_batch, num_scheduled_tokens, num_sampled_tokens
+            )
 
         return (
             attn_metadata,
@@ -3390,6 +3397,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         assert len(num_scheduled_tokens_list) == num_reqs
         num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
         total_num_scheduled_tokens = int(num_scheduled_tokens.sum())
+        num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
 
         # Disable DP padding when running eager
         allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
@@ -3485,7 +3493,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 attn_metadata[layer_name] = attn_metadata_i
 
         with self.maybe_dummy_run_with_lora(
-            self.lora_config, num_scheduled_tokens, activate_lora, remove_lora
+            self.lora_config,
+            num_scheduled_tokens,
+            num_sampled_tokens,
+            activate_lora,
+            remove_lora,
         ):
             # Make sure padding doesn't exceed max_num_tokens
             assert num_tokens_after_padding <= self.max_num_tokens
@@ -38,7 +38,6 @@ class LoRAModelRunnerMixin:
                 "Regarding multimodal models, vLLM currently "
                 "only supports adding LoRA to language model."
             )
-
         # Add LoRA Manager to the Model Runner
         self.lora_manager = LRUCacheWorkerLoRAManager(
             vllm_config,
@@ -70,13 +69,19 @@ class LoRAModelRunnerMixin:
             raise RuntimeError("LoRA is not enabled. Use --enable-lora to enable LoRA.")
 
     def set_active_loras(
-        self, input_batch: InputBatch, num_scheduled_tokens: np.ndarray
+        self,
+        input_batch: InputBatch,
+        num_scheduled_tokens: np.ndarray,
+        num_sampled_tokens: np.ndarray | None = None,
     ) -> None:
-        prompt_lora_mapping: tuple[int, ...]  # of size input_batch.num_reqs
+        if num_sampled_tokens is None:
+            num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)
+
+        prompt_lora_mapping: tuple[int, ...]  # of size np.sum(num_sampled_tokens)
         token_lora_mapping: tuple[int, ...]  # of size np.sum(num_scheduled_tokens)
         lora_requests: set[LoRARequest]
         prompt_lora_mapping, token_lora_mapping, lora_requests = (
-            input_batch.make_lora_inputs(num_scheduled_tokens)
+            input_batch.make_lora_inputs(num_scheduled_tokens, num_sampled_tokens)
         )
         return self._set_active_loras(
             prompt_lora_mapping, token_lora_mapping, lora_requests
@@ -123,8 +128,12 @@ class LoRAModelRunnerMixin:
         self,
         lora_config: LoRAConfig | None,
         num_scheduled_tokens: np.ndarray,
+        num_sampled_tokens: np.ndarray | None = None,
         activate_lora: bool = True,
     ):
+        if num_sampled_tokens is None:
+            num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)
+
         if lora_config is None:
             yield
         else:
@@ -143,6 +152,9 @@ class LoRAModelRunnerMixin:
             else:
                 prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32)
 
+            # Make sample lora mapping
+            sample_lora_mapping = np.repeat(prompt_lora_mapping, num_sampled_tokens)
+
             # Make token lora mapping
             token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
 
@@ -157,7 +169,7 @@ class LoRAModelRunnerMixin:
             }
 
             self._set_active_loras(
-                tuple(prompt_lora_mapping), tuple(token_lora_mapping), lora_requests
+                tuple(sample_lora_mapping), tuple(token_lora_mapping), lora_requests
             )
 
             yield
@@ -167,13 +179,14 @@ class LoRAModelRunnerMixin:
         self,
         lora_config: LoRAConfig | None,
         num_scheduled_tokens: np.ndarray,
+        num_sampled_tokens: np.ndarray,
         activate_lora: bool = True,
         remove_lora: bool = True,
     ):
         with (
             self.maybe_setup_dummy_loras(lora_config, remove_lora),
             self.maybe_select_dummy_loras(
-                lora_config, num_scheduled_tokens, activate_lora
+                lora_config, num_scheduled_tokens, num_sampled_tokens, activate_lora
             ),
         ):
             yield
@@ -526,7 +526,7 @@ class InputBatch:
         return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True)
 
     def make_lora_inputs(
-        self, num_scheduled_tokens: np.ndarray
+        self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
     ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
         """
         Given the num_scheduled_tokens for each request in the batch, return