[BugFix][LoRA] use adapter_id instead of id field of lora_request (#27728)

Signed-off-by: Biswa Panda <biswa.panda@gmail.com>
This commit is contained in:
Biswa Panda 2025-11-02 18:08:08 -08:00 committed by GitHub
parent 0ce743f4e1
commit 1bf43ae35d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 64 additions and 3 deletions

View File

@ -9,7 +9,8 @@ import pytest
import torch
import vllm.v1.core.kv_cache_utils as kv_cache_utils
from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved, BlockStored
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import (
MultiModalFeatureSpec,
MultiModalKwargsItem,
@ -59,6 +60,7 @@ def make_request(
mm_hashes: list[str] | None = None,
prompt_logprobs: int | None = None,
cache_salt: str | None = None,
lora_request: LoRARequest | None = None,
):
mm_features = []
if mm_positions is not None:
@ -79,7 +81,7 @@ def make_request(
sampling_params=SamplingParams(max_tokens=17, prompt_logprobs=prompt_logprobs),
pooling_params=None,
eos_token_id=100,
lora_request=None,
lora_request=lora_request,
cache_salt=cache_salt,
block_hasher=get_request_block_hasher(block_size, hash_fn),
)
@ -1337,6 +1339,63 @@ def test_kv_cache_events(blocks_to_cache: int):
assert len(manager.block_pool.cached_block_hash_to_block) == 0
@pytest.mark.parametrize("blocks_to_cache", [2, 3, 10])
def test_kv_cache_events_with_lora(blocks_to_cache: int):
    """BlockStored events must carry the LoRA adapter id, or None without LoRA.

    Exercises the fix that sources ``lora_id`` from
    ``LoRARequest.adapter_id`` instead of the ``id`` field.
    """
    block_size = 16
    num_blocks = blocks_to_cache + 1

    # Manager with KV-cache event emission enabled so BlockStored
    # events are produced on allocation.
    manager = KVCacheManager(
        make_kv_cache_config(block_size, num_blocks),
        max_model_len=8192,
        enable_caching=True,
        enable_kv_cache_events=True,
    )

    # --- Request that carries a LoRA adapter ---------------------------
    adapter = LoRARequest(
        lora_name="test_lora", lora_int_id=42, lora_path="/test/path"
    )
    prompt_len = block_size * blocks_to_cache
    lora_req = make_request(
        "lora_req",
        list(range(prompt_len)),
        block_size,
        sha256,
        lora_request=adapter,
    )
    manager.allocate_slots(lora_req, prompt_len)

    stored = manager.take_events()[-1]
    assert isinstance(stored, BlockStored)
    # Should match lora_request.adapter_id (== 42 here).
    assert stored.lora_id == 42
    assert len(stored.block_hashes) == blocks_to_cache
    assert stored.block_size == block_size

    # Release blocks before issuing the second request.
    manager.free(lora_req)

    # --- Request without a LoRA adapter --------------------------------
    plain_req = make_request(
        "no_lora_req", list(range(prompt_len)), block_size, sha256
    )
    manager.allocate_slots(plain_req, prompt_len)

    stored = manager.take_events()[-1]
    assert isinstance(stored, BlockStored)
    # No adapter attached, so the event reports no LoRA id.
    assert stored.lora_id is None
    assert len(stored.block_hashes) == blocks_to_cache
    assert stored.block_size == block_size
def test_eagle_enabled_removes_last_block():
"""Verify Eagle does NOT remove blocks when request
length is divisible by block size."""

View File

@ -259,7 +259,9 @@ class BlockPool:
num_cached_blocks * block_size : num_full_blocks * block_size
],
block_size=block_size,
lora_id=request.lora_request.id if request.lora_request else None,
lora_id=request.lora_request.adapter_id
if request.lora_request
else None,
medium=MEDIUM_GPU,
)
)