mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 11:26:15 +08:00
Fixes the incorrect argument in the prefix-prefill test cases (#3246)
This commit is contained in:
parent
413366e9a2
commit
3123f15138
@ -18,7 +18,7 @@ CUDA_DEVICES = [
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("num_queries_per_kv", NUM_HEADS)
|
||||
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@ -35,6 +35,13 @@ def test_contexted_kv_attention(
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(0)
|
||||
torch.set_default_device(device)
|
||||
|
||||
# Need this, otherwise when we capture the graph the process for GPU 1 would run on both
|
||||
# GPU0 and GPU1 and things would hang
|
||||
#
|
||||
# see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
MAX_SEQ_LEN = 1024
|
||||
MAX_CTX_LEN = 1024
|
||||
BS = 10
|
||||
@ -172,5 +179,5 @@ def test_contexted_kv_attention(
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
|
||||
output_ref = output_ref.squeeze(0, 2)
|
||||
output_ref = output_ref.reshape(output.shape)
|
||||
assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user