Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Bugfix] Fix LoRA test (#18518)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

commit db5a29ba19
parent 51797775c3
@@ -69,7 +69,7 @@ def test_lora_functions_sync():
     run_check(llm.add_lora, make_lora_request(12), [12, 9, 10, 11])
     run_check(llm.add_lora, make_lora_request(13), [12, 13, 10, 11])
 
-    # Remove all LoRAs
+    # Remove all LoRAs.
     run_check(llm.remove_lora, 13, [12, 10, 11])
     run_check(llm.remove_lora, 12, [10, 11])
     run_check(llm.remove_lora, 11, [10])
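The expected ID lists in this hunk appear to come from the engine keeping a fixed number of LoRA slots and dropping the oldest adapter when a new one is added, while remove_lora simply frees a slot. The sketch below only illustrates that bookkeeping; the slot count of 4, the LoRARegistry name, the assumption that adapters 8-11 were added earlier in the test, and the treatment of the expected lists as unordered sets are all illustrative assumptions, not vLLM's actual implementation.

from collections import OrderedDict


class LoRARegistry:
    """Hypothetical stand-in for the state behind llm.add_lora / llm.remove_lora."""

    def __init__(self, max_loras: int = 4):  # assumed slot count
        self.max_loras = max_loras
        self._slots: OrderedDict[int, None] = OrderedDict()

    def add_lora(self, lora_id: int) -> None:
        if lora_id in self._slots:
            return
        if len(self._slots) >= self.max_loras:
            self._slots.popitem(last=False)  # evict the oldest adapter
        self._slots[lora_id] = None

    def remove_lora(self, lora_id: int) -> None:
        self._slots.pop(lora_id, None)

    def list_loras(self) -> set[int]:
        return set(self._slots)


registry = LoRARegistry(max_loras=4)
for lora_id in (8, 9, 10, 11):  # assumed earlier additions in the test
    registry.add_lora(lora_id)
registry.add_lora(12)           # evicts 8
assert registry.list_loras() == {12, 9, 10, 11}
registry.add_lora(13)           # evicts 9
assert registry.list_loras() == {12, 13, 10, 11}
for lora_id in (13, 12, 11):
    registry.remove_lora(lora_id)
assert registry.list_loras() == {10}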
@@ -16,31 +16,40 @@ VOCAB_SIZE = 128 * 1024
 FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
 
 
+@pytest.fixture(autouse=True)
+def reset_default_device():
+    """
+    Explicitly set the default device, which can affect subsequent tests.
+    Adding this fixture helps avoid this problem.
+    """
+    original_device = torch.get_default_device()
+    yield
+    torch.set_default_device(original_device)
+
+
 def test_topk_impl_equivalance():
 
-    with torch.device(DEVICE):
-        generator = Generator(device=DEVICE).manual_seed(33)
-
-        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
-
-        # Random top-k values between 1 and 9.
-        k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator)
-
-        # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
-        k.masked_fill_(
-            torch.randint(0,
-                          2, (BATCH_SIZE, ),
-                          generator=generator,
-                          dtype=bool), VOCAB_SIZE)
-
-        # Top-k only implementation
-        result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
-
-        # Top-p + top-k
-        no_op_top_p = torch.tensor([1.0])
-        result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
-
-        assert torch.allclose(result1, result2)
+    torch.set_default_device(DEVICE)
+    generator = Generator(device=DEVICE).manual_seed(33)
+
+    logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
+
+    # Random top-k values between 1 and 9.
+    k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator)
+
+    # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
+    k.masked_fill_(
+        torch.randint(0, 2, (BATCH_SIZE, ), generator=generator, dtype=bool),
+        VOCAB_SIZE)
+
+    # Top-k only implementation
+    result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
+
+    # Top-p + top-k
+    no_op_top_p = torch.tensor([1.0])
+    result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
+
+    assert torch.allclose(result1, result2)
 
 
 def test_flashinfer_sampler():
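Why the new fixture is needed: torch.set_default_device() mutates process-global state, so once this test runs, later tests in the same pytest session would allocate tensors on DEVICE unless something restores the previous default. The old `with torch.device(...)` form was scoped to its block; the replacement is not, hence the autouse fixture that restores the original device after each test in this module. A minimal sketch of the difference, using the "meta" device as a stand-in so it runs without a GPU and assuming a fresh CPU-default process:

import torch

# Scoped: the context manager restores the previous default device on exit.
with torch.device("meta"):
    assert torch.empty(1).device.type == "meta"
assert torch.empty(1).device.type == "cpu"  # assumes a CPU-default process

# Global: set_default_device() persists until something sets it back,
# which is exactly what reset_default_device() does in its teardown.
previous = torch.get_default_device()
torch.set_default_device("meta")
assert torch.empty(1).device.type == "meta"
torch.set_default_device(previous)
assert torch.empty(1).device.type == "cpu"

This matches the fixture's docstring: the test deliberately sets the default device, and the fixture undoes it so subsequent tests are unaffected.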
@@ -58,50 +67,49 @@ def test_flashinfer_sampler():
         pytest.skip(
             "FlashInfer not installed or not available on this platform.")
 
-    with torch.device(DEVICE):
-        generator = Generator(device=DEVICE).manual_seed(42)
-
-        # Generate random logits
-        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
-
-        # Generate various top-k and top-p values
-        k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator)
-        p_values = torch.rand(
-            (BATCH_SIZE, ),
-            generator=generator) * 0.5 + 0.5  # range in [0.5, 1.0]
-
-        # Sometimes disable top-k (k=vocab_size)
-        k_values.masked_fill_(
-            torch.randint(0,
-                          2, (BATCH_SIZE, ),
-                          generator=generator,
-                          dtype=torch.bool), VOCAB_SIZE)
-
-        # Sometimes disable top-p (p=1.0)
-        p_values.masked_fill_(
-            torch.randint(0,
-                          2, (BATCH_SIZE, ),
-                          generator=generator,
-                          dtype=torch.bool), 1.0)
-
-        python_logits = apply_top_k_top_p(
-            logits=logits.clone(),
-            k=k_values,
-            p=p_values,
-        )
-        python_probs = torch.softmax(python_logits, dim=-1)
-
-        # FlashInfer only exposed renorm interfaces for probs so convert first
-        flashinfer_probs = torch.softmax(logits.clone(), dim=-1)
-        flashinfer_probs = top_k_renorm_probs(
-            probs=flashinfer_probs,
-            top_k=k_values,
-        )
-        flashinfer_probs = top_p_renorm_probs(
-            probs=flashinfer_probs,
-            top_p=p_values,
-        )
-
-        # Compare the results
-        assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \
-            "FlashInfer and Python sampling implementations do not match!"
+    torch.set_default_device(DEVICE)
+    generator = Generator(device=DEVICE).manual_seed(42)
+
+    # Generate random logits
+    logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
+
+    # Generate various top-k and top-p values
+    k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator)
+    p_values = torch.rand(
+        (BATCH_SIZE, ), generator=generator) * 0.5 + 0.5  # range in [0.5, 1.0]
+
+    # Sometimes disable top-k (k=vocab_size)
+    k_values.masked_fill_(
+        torch.randint(0,
+                      2, (BATCH_SIZE, ),
+                      generator=generator,
+                      dtype=torch.bool), VOCAB_SIZE)
+
+    # Sometimes disable top-p (p=1.0)
+    p_values.masked_fill_(
+        torch.randint(0,
+                      2, (BATCH_SIZE, ),
+                      generator=generator,
+                      dtype=torch.bool), 1.0)
+
+    python_logits = apply_top_k_top_p(
+        logits=logits.clone(),
+        k=k_values,
+        p=p_values,
+    )
+    python_probs = torch.softmax(python_logits, dim=-1)
+
+    # FlashInfer only exposed renorm interfaces for probs so convert first
+    flashinfer_probs = torch.softmax(logits.clone(), dim=-1)
+    flashinfer_probs = top_k_renorm_probs(
+        probs=flashinfer_probs,
+        top_k=k_values,
+    )
+    flashinfer_probs = top_p_renorm_probs(
+        probs=flashinfer_probs,
+        top_p=p_values,
+    )
+
+    # Compare the results
+    assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \
+        "FlashInfer and Python sampling implementations do not match!"
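For context on what the two implementations above are expected to agree on: apply_top_k_top_p filters logits so that, per request, only the k highest-probability tokens and the smallest nucleus of cumulative probability p survive, while FlashInfer's top_k_renorm_probs / top_p_renorm_probs perform the equivalent renormalization directly on probabilities, which is why the Python result is softmaxed before the comparison. The sketch below shows the masking idea only; it is not vLLM's apply_top_k_top_p (which applies top-k before computing the top-p nucleus), and the function name topk_topp_mask is an illustrative assumption.

import torch


def topk_topp_mask(logits: torch.Tensor, k: torch.Tensor,
                   p: torch.Tensor) -> torch.Tensor:
    """Illustrative top-k/top-p filter: logits outside the kept set become -inf."""
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)

    # Top-k: token ranks at or beyond k (per row) are masked out.
    ranks = torch.arange(logits.shape[-1], device=logits.device)
    topk_mask = ranks.unsqueeze(0) >= k.unsqueeze(-1)

    # Top-p: keep the smallest prefix whose cumulative probability reaches p.
    probs = sorted_logits.softmax(dim=-1)
    cumulative = probs.cumsum(dim=-1)
    topp_mask = (cumulative - probs) > p.unsqueeze(-1)

    sorted_logits = sorted_logits.masked_fill(topk_mask | topp_mask,
                                              float("-inf"))
    # Scatter the masked values back to their original token positions.
    return logits.scatter(-1, sorted_idx, sorted_logits)

Unlike the implementations compared in the test, this simplified version builds both masks from the original distribution rather than applying them sequentially, so it is only meant to convey the shape of the operation.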