diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py
index f33a27d1fd85..028e164cb801 100644
--- a/tests/kernels/attention/test_cache.py
+++ b/tests/kernels/attention/test_cache.py
@@ -68,6 +68,7 @@ def test_copy_blocks(
         pytest.skip()
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
+    torch.cuda.set_device(device)
     # Generate random block mappings where each source block is mapped to two
     # destination blocks.
     assert 2 * num_mappings <= num_blocks
@@ -152,6 +153,7 @@ def test_reshape_and_cache(
         pytest.skip()
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
+    torch.cuda.set_device(device)
     # Create a random slot mapping.
     num_slots = block_size * num_blocks
     slot_mapping_lst = random.sample(range(num_slots), num_tokens)
@@ -272,6 +274,7 @@ def test_reshape_and_cache_flash(
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
+    torch.cuda.set_device(device)
     assert implementation in ["cuda", "triton"]
     if implementation == "triton" and kv_cache_layout == "HND":
         pytest.skip("Triton implementation only supports NHD layout.")
@@ -593,6 +596,7 @@ def test_concat_and_cache_mla(
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
+    torch.cuda.set_device(device)
 
     total_slots = num_blocks * block_size
     slot_mapping_lst = random.sample(range(total_slots), num_tokens)
@@ -662,11 +666,14 @@ def test_concat_and_cache_ds_mla(
     seed: int,
     device: str,
 ) -> None:
+    if current_platform.is_rocm():
+        pytest.skip("concat_and_cache_mla doesn't support fp8_ds_mla on ROCm")
     if dtype.itemsize != 2:
         pytest.skip("ds_mla only supports 16-bit input")
     kv_cache_dtype = "fp8_ds_mla"
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
+    torch.cuda.set_device(device)
 
     total_slots = num_blocks * block_size
     slot_mapping_lst = random.sample(range(total_slots), num_tokens)
@@ -779,6 +786,7 @@ def test_copy_blocks_mla(
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
+    torch.cuda.set_device(device)
 
     entry_size = kv_lora_rank + qk_rope_head_dim
 
@@ -843,6 +851,7 @@ def test_swap_blocks_mla(
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
+    torch.cuda.set_device(device)
 
     entry_size = kv_lora_rank + qk_rope_head_dim
 
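
The reason the patch pairs torch.cuda.set_device(device) with torch.set_default_device(device): the default device only controls where newly created tensors land, while the "current" CUDA device used by torch.cuda utilities and raw kernel launches stays at GPU 0 unless set explicitly. A minimal sketch of the distinction (not part of the patch; assumes a host with at least two CUDA GPUs, and "cuda:1" is just an illustrative target):

    # sketch only; requires a machine with >= 2 CUDA devices
    import torch

    device = "cuda:1"                   # hypothetical non-default test device
    torch.set_default_device(device)    # new tensors now default to cuda:1 ...
    assert torch.empty(1).device == torch.device(device)

    # ... but the current CUDA device is still index 0 until set explicitly,
    # so device-index-sensitive code could still run against the wrong GPU.
    print(torch.cuda.current_device())  # 0
    torch.cuda.set_device(device)       # align the current device with the test's
    print(torch.cuda.current_device())  # 1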