Fixes the incorrect argument in the prefix-prefill test cases (#3246)
parent 413366e9a2
commit 3123f15138
@@ -18,7 +18,7 @@ CUDA_DEVICES = [

 @pytest.mark.parametrize("num_heads", NUM_HEADS)
-@pytest.mark.parametrize("num_queries_per_kv", NUM_HEADS)
+@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
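The fixed decorator matters because stacked pytest.mark.parametrize decorators take the cross product of their value lists; reusing NUM_HEADS for num_queries_per_kv meant the test never swept the intended grouped-query ratios. A minimal, self-contained sketch of the pattern (the value lists here are illustrative, not copied from the vLLM test file):

import pytest

NUM_HEADS = [64]                 # illustrative values, not vLLM's
NUM_QUERIES_PER_KV = [1, 8, 64]  # the GQA ratios the test should sweep

# Stacked parametrize decorators run the test once per combination, so
# the body below executes for (64, 1), (64, 8), and (64, 64).
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
def test_gqa_ratio(num_heads, num_queries_per_kv):
    # in grouped-query attention the ratio must divide the head count evenly
    assert num_heads % num_queries_per_kv == 0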
@@ -35,6 +35,13 @@ def test_contexted_kv_attention(
     if torch.cuda.is_available():
         torch.cuda.manual_seed(0)
     torch.set_default_device(device)
+
+    # Need this, otherwise when we capture the graph the process for GPU 1 would run on both
+    # GPU0 and GPU1 and things would hang
+    #
+    # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
+    torch.cuda.set_device(device)
+
     MAX_SEQ_LEN = 1024
     MAX_CTX_LEN = 1024
     BS = 10
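The new torch.cuda.set_device call pins the process's current CUDA device in addition to the tensor-allocation default; operations that use the current device implicitly (CUDA graph capture among them, per the comment in the hunk above) can otherwise touch GPU 0 even when the test targets cuda:1. A hedged sketch of the pattern, with setup_device as a hypothetical helper name:

import torch

def setup_device(device: str) -> None:
    # set_default_device only controls where newly created tensors land;
    # set_device also switches the *current* CUDA device, which
    # implicit-device operations (e.g. graph capture) pick up.
    torch.set_default_device(device)
    if torch.cuda.is_available():
        torch.cuda.set_device(device)

# e.g. pin everything to the second GPU when one is present
if torch.cuda.device_count() > 1:
    setup_device("cuda:1")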
@@ -172,5 +179,5 @@ def test_contexted_kv_attention(
     torch.cuda.synchronize()
     end_time = time.time()
     print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
-    output_ref = output_ref.squeeze(0, 2)
+    output_ref = output_ref.reshape(output.shape)
     assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
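The last change swaps a squeeze for a reshape when flattening the xformers reference output back to the kernel's layout. squeeze only removes size-1 dimensions, so it stops producing the right shape once num_queries_per_kv is greater than 1, which the parametrize fix above now actually exercises; reshape(output.shape) collapses the grouped layout unconditionally. A small sketch with made-up dimensions:

import torch

num_tokens, num_kv_heads, num_queries_per_kv, head_size = 16, 2, 8, 64

# the kernel writes a flat (tokens, heads, head_size) tensor
output = torch.randn(num_tokens, num_kv_heads * num_queries_per_kv, head_size)

# the reference is computed with an extra batch dim and grouped heads
output_ref = output.reshape(1, num_tokens, num_kv_heads, num_queries_per_kv,
                            head_size)

# squeeze would only drop the leading size-1 dim here, leaving a 4-D
# tensor, while reshape recovers the kernel's flat layout exactly
assert output_ref.reshape(output.shape).shape == output.shape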