Fix benchmark_moe.py tuning for CUDA devices (#14164)

This commit is contained in:
Michael Goin 2025-03-04 00:11:03 -05:00 committed by GitHub
parent 66233af7b6
commit f78c0be80a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,6 +2,7 @@
import argparse import argparse
import time import time
from contextlib import nullcontext
from datetime import datetime from datetime import datetime
from itertools import product from itertools import product
from typing import Any, TypedDict from typing import Any, TypedDict
@ -412,7 +413,8 @@ class BenchmarkWorker:
hidden_size, search_space, hidden_size, search_space,
is_fp16, topk) is_fp16, topk)
with torch.cuda.device(self.device_id): with torch.cuda.device(self.device_id) if current_platform.is_rocm(
) else nullcontext():
for config in tqdm(search_space): for config in tqdm(search_space):
try: try:
kernel_time = benchmark_config( kernel_time = benchmark_config(