[Bugfix] Fix LoRA test (#18518)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li 2025-05-22 12:48:53 +08:00 committed by GitHub
parent 51797775c3
commit db5a29ba19
2 changed files with 66 additions and 58 deletions


@@ -69,7 +69,7 @@ def test_lora_functions_sync():
     run_check(llm.add_lora, make_lora_request(12), [12, 9, 10, 11])
     run_check(llm.add_lora, make_lora_request(13), [12, 13, 10, 11])
-    # Remove all LoRAs
+    # Remove all LoRAs.
     run_check(llm.remove_lora, 13, [12, 10, 11])
     run_check(llm.remove_lora, 12, [10, 11])
     run_check(llm.remove_lora, 11, [10])
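
In this test, each run_check call applies an add or remove operation and then checks which LoRA IDs are still loaded; the expected lists above ([12, 9, 10, 11] becoming [12, 13, 10, 11] after adding 13) show an older adapter being evicted from a bounded cache. A minimal sketch of that checking pattern, with a hypothetical list_loaded_ids accessor standing in for however the real test inspects the engine:

from typing import Callable, Iterable, Sequence


def run_check(fn: Callable, arg, expected_ids: Sequence[int],
              list_loaded_ids: Callable[[], Iterable[int]]) -> None:
    # Apply the operation (e.g. add_lora / remove_lora), then verify that the
    # set of currently loaded adapter IDs matches what the test expects.
    fn(arg)
    loaded = set(list_loaded_ids())
    assert loaded == set(expected_ids), (
        f"expected {sorted(expected_ids)}, got {sorted(loaded)}")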


@@ -16,9 +16,20 @@ VOCAB_SIZE = 128 * 1024
 FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
 
 
+@pytest.fixture(autouse=True)
+def reset_default_device():
+    """
+    Explicitly setting the default device can affect subsequent tests.
+    Adding this fixture helps avoid this problem.
+    """
+    original_device = torch.get_default_device()
+    yield
+    torch.set_default_device(original_device)
+
+
 def test_topk_impl_equivalance():
 
-    with torch.device(DEVICE):
+    torch.set_default_device(DEVICE)
     generator = Generator(device=DEVICE).manual_seed(33)
 
     logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
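
The added autouse fixture snapshots torch's default device before each test and restores it afterwards, so a test that calls torch.set_default_device no longer leaks that setting into later tests in the same session, which is the kind of cross-test interference the LoRA test fix in this commit addresses. A standalone sketch of the same save-and-restore pattern; the two example tests below are illustrative and not part of this commit:

import pytest
import torch


@pytest.fixture(autouse=True)
def reset_default_device():
    # Snapshot the current default device, run the test, then restore it.
    original_device = torch.get_default_device()
    yield
    torch.set_default_device(original_device)


def test_changes_default_device():
    # The "meta" default set here is undone by the fixture after the test.
    torch.set_default_device("meta")
    assert torch.empty(2, 2).device.type == "meta"


def test_sees_original_default():
    # Without the fixture, the "meta" default from the previous test
    # would leak into this one and change where tensors are allocated.
    assert torch.empty(2, 2).device.type != "meta"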
@@ -28,10 +39,8 @@ def test_topk_impl_equivalance():
 
     # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
     k.masked_fill_(
-        torch.randint(0,
-                      2, (BATCH_SIZE, ),
-                      generator=generator,
-                      dtype=bool), VOCAB_SIZE)
+        torch.randint(0, 2, (BATCH_SIZE, ), generator=generator, dtype=bool),
+        VOCAB_SIZE)
 
     # Top-k only implementation
     result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
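
Setting k equal to the vocabulary size keeps every token, so top-k becomes a no-op for that row; the test uses this to mix enabled and disabled rows within one batch. A small self-contained illustration of that masking step, with shrunken sizes for readability (apply_top_k_top_p itself is not re-implemented here):

import torch

BATCH_SIZE, VOCAB_SIZE = 8, 16
generator = torch.Generator().manual_seed(0)

# One top-k value per request, drawn from 1..9.
k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator)

# Boolean mask that is True for roughly half of the requests.
disable = torch.randint(0, 2, (BATCH_SIZE, ), generator=generator, dtype=bool)

# Where the mask is True, k becomes the full vocab size, i.e. top-k keeps
# every token and is effectively disabled for that request.
k.masked_fill_(disable, VOCAB_SIZE)
print(k.tolist())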
@@ -58,7 +67,7 @@ def test_flashinfer_sampler():
         pytest.skip(
             "FlashInfer not installed or not available on this platform.")
 
-    with torch.device(DEVICE):
+    torch.set_default_device(DEVICE)
     generator = Generator(device=DEVICE).manual_seed(42)
 
     # Generate random logits
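
test_flashinfer_sampler gets the same device-handling change; around it, the test follows a common pattern of skipping when the optional backend is unavailable and seeding a per-device Generator for reproducible inputs. A rough sketch of that structure under assumed names (the real availability check lives in FLASHINFER_ENABLED at the top of this file):

import pytest
import torch
from torch import Generator

DEVICE = "cuda"              # mirrors the module-level DEVICE constant
BACKEND_AVAILABLE = False    # stand-in for the real FLASHINFER_ENABLED check


def test_optional_backend_sampler():
    if not BACKEND_AVAILABLE:
        pytest.skip("Backend not installed or not available on this platform.")

    # Newly created tensors default to the target device; the autouse fixture
    # shown earlier restores the previous default once the test finishes.
    torch.set_default_device(DEVICE)

    # A seeded, device-specific Generator keeps the random inputs reproducible.
    generator = Generator(device=DEVICE).manual_seed(42)
    logits = torch.rand((4, 16), generator=generator)
    assert logits.shape == (4, 16)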
@@ -67,8 +76,7 @@ def test_flashinfer_sampler():
     # Generate various top-k and top-p values
     k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator)
     p_values = torch.rand(
-        (BATCH_SIZE, ),
-        generator=generator) * 0.5 + 0.5  # range in [0.5, 1.0]
+        (BATCH_SIZE, ), generator=generator) * 0.5 + 0.5  # range in [0.5, 1.0]
 
     # Sometimes disable top-k (k=vocab_size)
     k_values.masked_fill_(