Fixes the incorrect argument in the prefix-prefill test cases (#3246)

Tao He 2024-03-16 11:58:10 +08:00 committed by GitHub
parent 413366e9a2
commit 3123f15138

@@ -18,7 +18,7 @@ CUDA_DEVICES = [
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
-@pytest.mark.parametrize("num_queries_per_kv", NUM_HEADS)
+@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -35,6 +35,13 @@ def test_contexted_kv_attention(
     if torch.cuda.is_available():
         torch.cuda.manual_seed(0)
     torch.set_default_device(device)
+
+    # Need this, otherwise when we capture the graph the process for GPU 1 would run on both
+    # GPU0 and GPU1 and things would hang
+    #
+    # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
+    torch.cuda.set_device(device)
+
     MAX_SEQ_LEN = 1024
     MAX_CTX_LEN = 1024
     BS = 10
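
The added torch.cuda.set_device(device) pins the process's current CUDA device before any kernels run or graphs are captured, so work meant for, say, "cuda:1" does not also land on GPU 0 and hang. A small sketch of the pattern on a multi-GPU host; the helper name and device string are illustrative, not from the diff:

import torch

def pin_to(device: str) -> None:
    # New tensors default to `device`, and the "current" CUDA device used
    # for kernel launches and graph capture is set to the same GPU.
    torch.set_default_device(device)
    torch.cuda.set_device(device)

if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    pin_to("cuda:1")
    x = torch.randn(4, 4)
    assert x.device == torch.device("cuda:1")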
@@ -172,5 +179,5 @@ def test_contexted_kv_attention(
     torch.cuda.synchronize()
     end_time = time.time()
     print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
-    output_ref = output_ref.squeeze(0, 2)
+    output_ref = output_ref.reshape(output.shape)
     assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
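
The last change collapses the xformers reference output to the kernel output's shape with reshape(output.shape) instead of squeezing specific dims, which is only correct when those dims happen to be size 1. A small sketch with made-up shapes (the real ones depend on the parametrized sizes):

import torch

# Made-up shapes: kernel output (tokens, heads, head_size) vs. a reference
# tensor carrying extra singleton/group dimensions.
output = torch.randn(16, 8, 128)
output_ref = output.clone().view(1, 16, 1, 8, 128)

# Old: output_ref = output_ref.squeeze(0, 2)
#   squeeze silently skips any listed dim that is not size 1, so the result
#   only lines up with `output` for particular parametrizations.
# New: shape-agnostic, and raises if the element counts ever disagree.
output_ref = output_ref.reshape(output.shape)

assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)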