[Bugfix][LoRA][Spec Decode] Support LoRA with speculative decoding (#21068)

Signed-off-by: Sean Chen <xiaohong_chen1991@hotmail.com>
Signed-off-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Danielle Robinson <dcmaddix@gmail.com>
Co-authored-by: Haipeng Li <li2haipeng@gmail.com>
Co-authored-by: li2haipeng <44383182+li2haipeng@users.noreply.github.com>
This commit is contained in:
Xiaohong (Sean) Chen 2025-11-07 20:58:22 -05:00 committed by GitHub
parent b158df2813
commit d0c7792004
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 201 additions and 15 deletions

View File

@ -0,0 +1,141 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This script contains:
1. test lora with speculative decoding for batch inference
"""
import random
import numpy as np
import pytest
import torch
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform
LORA_TEST_PROMPT_MAP: dict[str, str] = {}
LORA_TEST_PROMPT_MAP["premjatin/qwen-linear-algebra-coder"] = """
### INSTRUCTION:
You are an AI assistant that generates Python code to solve linear
algebra problems.
### PROBLEM:
Find the eigenvalues and eigenvectors of the following 3x3 matrix:
[[3, 2, 0],
[2, 3, 0],
[0, 0, 2]]
### OUTPUT FORMAT (STRICT):
Numbers should be represented as integers only.
### PYTHON SOLUTION:
"""
SEED = 42
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
@pytest.mark.parametrize(
"model_setup",
[
(
"eagle3",
"Qwen/Qwen3-1.7B",
"AngelSlim/Qwen3-1.7B_eagle3",
"premjatin/qwen-linear-algebra-coder",
1,
)
],
)
def test_batch_inference_correctness(
monkeypatch: pytest.MonkeyPatch,
model_setup: tuple[str, str, str, str, int],
):
"""
Compare the outputs of a LLM with only Lora and a LLM with both SD and Lora.
Should be the same and no failure when doing batch inference.
model_setup: (method, model_name, spec_model_name, lora_path, tp_size)
"""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
# Disable randomness
m.setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
method, model_name, spec_model_name, lora_path, tp_size = model_setup
# without speculative decoding
ref_llm = LLM(
model=model_name,
trust_remote_code=True,
tensor_parallel_size=tp_size,
max_model_len=2048,
max_num_seqs=4,
enable_lora=True,
max_loras=1,
max_cpu_loras=1,
max_lora_rank=16,
)
prompts = [LORA_TEST_PROMPT_MAP[lora_path]] * 100
lora_request = LoRARequest("adapter", 1, lora_path)
sampling_params = SamplingParams(
temperature=0.0, top_p=1.0, top_k=-1, seed=SEED, max_tokens=128
)
ref_outputs = ref_llm.generate(
prompts, sampling_params, lora_request=lora_request
)
del ref_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
lora_spec_llm = LLM(
model=model_name,
trust_remote_code=True,
tensor_parallel_size=tp_size,
speculative_config={
"method": method,
"model": spec_model_name,
"num_speculative_tokens": 3,
"max_model_len": 2048,
},
max_model_len=2048,
max_num_seqs=4,
enable_lora=True,
max_loras=1,
max_cpu_loras=1,
max_lora_rank=16,
)
lora_spec_outputs = lora_spec_llm.generate(
prompts, sampling_params, lora_request=lora_request
)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, lora_spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
# Heuristic: expect at least 90% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
print(f"match ratio: {matches}/{len(ref_outputs)}")
assert matches > int(0.90 * len(ref_outputs))
del lora_spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()

View File

@ -1574,6 +1574,20 @@ class EngineArgs:
else None else None
) )
if (
lora_config is not None
and speculative_config is not None
and scheduler_config.max_num_batched_tokens
< (
scheduler_config.max_num_seqs
* (speculative_config.num_speculative_tokens + 1)
)
):
raise ValueError(
"Consider increasing max_num_batched_tokens or "
"decreasing num_speculative_tokens"
)
# bitsandbytes pre-quantized model need a specific model loader # bitsandbytes pre-quantized model need a specific model loader
if model_config.quantization == "bitsandbytes": if model_config.quantization == "bitsandbytes":
self.quantization = self.load_format = "bitsandbytes" self.quantization = self.load_format = "bitsandbytes"

View File

@ -51,8 +51,12 @@ class PunicaWrapperGPU(PunicaWrapperBase):
self.max_loras, max_num_batched_tokens, device=device self.max_loras, max_num_batched_tokens, device=device
) )
# When speculative decoding is enabled, max_num_samples is
# max_batches * (num_speculative_decoding_tokens + 1).
# This line can be optimized by replacing max_num_batched_tokens
# to max_batches * (num_speculative_decoding_tokens + 1).
self.prompt_mapping_meta = LoRAKernelMeta.make( self.prompt_mapping_meta = LoRAKernelMeta.make(
self.max_loras, max_batches, device=device self.max_loras, max_num_batched_tokens, device=device
) )
def update_metadata( def update_metadata(

View File

@ -859,22 +859,24 @@ class InputBatch:
return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True) return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True)
def make_lora_inputs( def make_lora_inputs(
self, num_scheduled_tokens: np.ndarray self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]: ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
""" """
Given the num_scheduled_tokens for each request in the batch, return Given the num_scheduled_tokens for each request in the batch, return
datastructures used to activate the current LoRAs. datastructures used to activate the current LoRAs.
Returns: Returns:
1. prompt_lora_mapping: A tuple of size self.num_reqs where, 1. prompt_lora_mapping: A tuple of size np.sum(num_sampled_tokens)
prompt_lora_mapping[i] is the LoRA id to use for the ith prompt. where, prompt_lora_mapping[i] is the LoRA id to use for the ith
sampled token.
2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens) 2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
where, token_lora_mapping[i] is the LoRA id to use for ith token. where, token_lora_mapping[i] is the LoRA id to use for ith token.
3. lora_requests: Set of relevant LoRA requests. 3. lora_requests: Set of relevant LoRA requests.
""" """
req_lora_mapping = self.request_lora_mapping[: self.num_reqs] req_lora_mapping = self.request_lora_mapping[: self.num_reqs]
prompt_lora_mapping = tuple(req_lora_mapping) prompt_lora_mapping = tuple(req_lora_mapping.repeat(num_sampled_tokens))
token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens)) token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens))
active_lora_requests: set[LoRARequest] = set( active_lora_requests: set[LoRARequest] = set(
self.lora_id_to_lora_request.values() self.lora_id_to_lora_request.values()
) )

View File

@ -1268,6 +1268,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits_indices = query_start_loc[1:] - 1 logits_indices = query_start_loc[1:] - 1
num_draft_tokens = None num_draft_tokens = None
spec_decode_metadata = None spec_decode_metadata = None
num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
else: else:
# Get the number of draft tokens for each request. # Get the number of draft tokens for each request.
# Iterate over the dictionary rather than all requests since not all # Iterate over the dictionary rather than all requests since not all
@ -1294,7 +1295,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_draft_tokens, cu_num_tokens num_draft_tokens, cu_num_tokens
) )
logits_indices = spec_decode_metadata.logits_indices logits_indices = spec_decode_metadata.logits_indices
num_sampled_tokens = num_draft_tokens + 1
# For DECODE only cuda graph of some attention backends (e.g., GDN). # For DECODE only cuda graph of some attention backends (e.g., GDN).
self.num_decode_draft_tokens.np[:num_reqs] = num_decode_draft_tokens self.num_decode_draft_tokens.np[:num_reqs] = num_decode_draft_tokens
self.num_decode_draft_tokens.np[num_reqs:].fill(-1) self.num_decode_draft_tokens.np[num_reqs:].fill(-1)
@ -1445,7 +1446,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Hot-Swap lora model # Hot-Swap lora model
if self.lora_config: if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens) assert (
np.sum(num_sampled_tokens)
<= self.vllm_config.scheduler_config.max_num_batched_tokens
)
self.set_active_loras(
self.input_batch, num_scheduled_tokens, num_sampled_tokens
)
return ( return (
attn_metadata, attn_metadata,
@ -3390,6 +3397,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert len(num_scheduled_tokens_list) == num_reqs assert len(num_scheduled_tokens_list) == num_reqs
num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
total_num_scheduled_tokens = int(num_scheduled_tokens.sum()) total_num_scheduled_tokens = int(num_scheduled_tokens.sum())
num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
# Disable DP padding when running eager # Disable DP padding when running eager
allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
@ -3485,7 +3493,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata[layer_name] = attn_metadata_i attn_metadata[layer_name] = attn_metadata_i
with self.maybe_dummy_run_with_lora( with self.maybe_dummy_run_with_lora(
self.lora_config, num_scheduled_tokens, activate_lora, remove_lora self.lora_config,
num_scheduled_tokens,
num_sampled_tokens,
activate_lora,
remove_lora,
): ):
# Make sure padding doesn't exceed max_num_tokens # Make sure padding doesn't exceed max_num_tokens
assert num_tokens_after_padding <= self.max_num_tokens assert num_tokens_after_padding <= self.max_num_tokens

View File

@ -38,7 +38,6 @@ class LoRAModelRunnerMixin:
"Regarding multimodal models, vLLM currently " "Regarding multimodal models, vLLM currently "
"only supports adding LoRA to language model." "only supports adding LoRA to language model."
) )
# Add LoRA Manager to the Model Runner # Add LoRA Manager to the Model Runner
self.lora_manager = LRUCacheWorkerLoRAManager( self.lora_manager = LRUCacheWorkerLoRAManager(
vllm_config, vllm_config,
@ -70,13 +69,19 @@ class LoRAModelRunnerMixin:
raise RuntimeError("LoRA is not enabled. Use --enable-lora to enable LoRA.") raise RuntimeError("LoRA is not enabled. Use --enable-lora to enable LoRA.")
def set_active_loras( def set_active_loras(
self, input_batch: InputBatch, num_scheduled_tokens: np.ndarray self,
input_batch: InputBatch,
num_scheduled_tokens: np.ndarray,
num_sampled_tokens: np.ndarray | None = None,
) -> None: ) -> None:
prompt_lora_mapping: tuple[int, ...] # of size input_batch.num_reqs if num_sampled_tokens is None:
num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)
prompt_lora_mapping: tuple[int, ...] # of size np.sum(num_sampled_tokens)
token_lora_mapping: tuple[int, ...] # of size np.sum(num_scheduled_tokens) token_lora_mapping: tuple[int, ...] # of size np.sum(num_scheduled_tokens)
lora_requests: set[LoRARequest] lora_requests: set[LoRARequest]
prompt_lora_mapping, token_lora_mapping, lora_requests = ( prompt_lora_mapping, token_lora_mapping, lora_requests = (
input_batch.make_lora_inputs(num_scheduled_tokens) input_batch.make_lora_inputs(num_scheduled_tokens, num_sampled_tokens)
) )
return self._set_active_loras( return self._set_active_loras(
prompt_lora_mapping, token_lora_mapping, lora_requests prompt_lora_mapping, token_lora_mapping, lora_requests
@ -123,8 +128,12 @@ class LoRAModelRunnerMixin:
self, self,
lora_config: LoRAConfig | None, lora_config: LoRAConfig | None,
num_scheduled_tokens: np.ndarray, num_scheduled_tokens: np.ndarray,
num_sampled_tokens: np.ndarray | None = None,
activate_lora: bool = True, activate_lora: bool = True,
): ):
if num_sampled_tokens is None:
num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)
if lora_config is None: if lora_config is None:
yield yield
else: else:
@ -143,6 +152,9 @@ class LoRAModelRunnerMixin:
else: else:
prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32) prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32)
# Make sample lora mapping
sample_lora_mapping = np.repeat(prompt_lora_mapping, num_sampled_tokens)
# Make token lora mapping # Make token lora mapping
token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens) token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
@ -157,7 +169,7 @@ class LoRAModelRunnerMixin:
} }
self._set_active_loras( self._set_active_loras(
tuple(prompt_lora_mapping), tuple(token_lora_mapping), lora_requests tuple(sample_lora_mapping), tuple(token_lora_mapping), lora_requests
) )
yield yield
@ -167,13 +179,14 @@ class LoRAModelRunnerMixin:
self, self,
lora_config: LoRAConfig | None, lora_config: LoRAConfig | None,
num_scheduled_tokens: np.ndarray, num_scheduled_tokens: np.ndarray,
num_sampled_tokens: np.ndarray,
activate_lora: bool = True, activate_lora: bool = True,
remove_lora: bool = True, remove_lora: bool = True,
): ):
with ( with (
self.maybe_setup_dummy_loras(lora_config, remove_lora), self.maybe_setup_dummy_loras(lora_config, remove_lora),
self.maybe_select_dummy_loras( self.maybe_select_dummy_loras(
lora_config, num_scheduled_tokens, activate_lora lora_config, num_scheduled_tokens, num_sampled_tokens, activate_lora
), ),
): ):
yield yield

View File

@ -526,7 +526,7 @@ class InputBatch:
return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True) return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True)
def make_lora_inputs( def make_lora_inputs(
self, num_scheduled_tokens: np.ndarray self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]: ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
""" """
Given the num_scheduled_tokens for each request in the batch, return Given the num_scheduled_tokens for each request in the batch, return