[ROCm][CI] Reduce Flakiness For test_async_scheduling Using ROCM_ATTN With FP32 (#30811)
Signed-off-by: Micah Williamson <micah.williamson@amd.com>
commit fd8afdf38d
parent a0b782f9cc
@@ -148,7 +148,7 @@ def run_tests(
             # Use TRITON_ATTN for spec decoding test for consistency
             attention_config = {"backend": "TRITON_ATTN"}
         else:
-            attention_config = {"backend": "ROCM_AITER_FA"}
+            attention_config = {"backend": "ROCM_ATTN"}
     else:
         attention_config = {"backend": "FLEX_ATTENTION"}
 
@@ -284,14 +284,6 @@ def run_test(
     print(f"---- TESTING {test_str}: {test_config}")
     print("-" * 80)
 
-    # On ROCm: use float16 for first test (ROCM_AITER_FA), but float32 for
-    # spec decoding test (TRITON_ATTN) for better precision.
-    # On others: always use float32.
-    if current_platform.is_rocm() and not is_testing_with_spec_decoding:
-        dtype = "float16"
-    else:
-        dtype = "float32"
-
     with VllmRunner(
         model,
         max_model_len=512,
@@ -301,7 +293,7 @@ def run_test(
         # enforce_eager=True,
         async_scheduling=async_scheduling,
         distributed_executor_backend=executor,
-        dtype=dtype,
+        dtype="float32",
         speculative_config=spec_config,
         disable_log_stats=False,
         attention_config=attention_config,
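Taken together, the three hunks above simplify the test's per-platform setup: the ROCm non-spec-decoding path switches from ROCM_AITER_FA to ROCM_ATTN, and the dtype is now float32 everywhere. A condensed sketch of the resulting behavior (select_test_config is a hypothetical helper used only for illustration, not a vLLM API):

# Condensed sketch of the test setup after this commit; illustrative only.
def select_test_config(is_rocm: bool, spec_decoding: bool) -> tuple[dict, str]:
    if is_rocm:
        # ROCM_ATTN replaces ROCM_AITER_FA on the non-spec-decoding path.
        backend = "TRITON_ATTN" if spec_decoding else "ROCM_ATTN"
    else:
        backend = "FLEX_ATTENTION"
    # dtype is now unconditionally float32; the ROCm float16 branch is gone.
    return {"backend": backend}, "float32"

assert select_test_config(True, False) == ({"backend": "ROCM_ATTN"}, "float32")
assert select_test_config(False, True) == ({"backend": "TRITON_ATTN"}, "float32")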
@@ -152,7 +152,11 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat
 
 class RocmAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
-    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
+    supported_dtypes: ClassVar[list[torch.dtype]] = [
+        torch.float16,
+        torch.bfloat16,
+        torch.float32,
+    ]
 
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
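The supported_dtypes extension is what unblocks the float32 test run above: the class attribute advertises which model dtypes the backend accepts, and adding torch.float32 lets ROCM_ATTN be chosen when the test model runs in float32. A minimal illustration of that kind of membership check (can_use_rocm_attn is a hypothetical helper, not vLLM's actual backend-selection code):

# Illustration only: the membership check that supported_dtypes feeds.
import torch

SUPPORTED_DTYPES = [torch.float16, torch.bfloat16, torch.float32]

def can_use_rocm_attn(model_dtype: torch.dtype) -> bool:
    return model_dtype in SUPPORTED_DTYPES

assert can_use_rocm_attn(torch.float32)  # would not hold before this change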