diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 0d2d304156a5..bb28c32798e2 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -2,6 +2,7 @@ import argparse import time +from contextlib import nullcontext from datetime import datetime from itertools import product from typing import Any, TypedDict @@ -412,7 +413,8 @@ class BenchmarkWorker: hidden_size, search_space, is_fp16, topk) - with torch.cuda.device(self.device_id): + with torch.cuda.device(self.device_id) if current_platform.is_rocm( + ) else nullcontext(): for config in tqdm(search_space): try: kernel_time = benchmark_config(