mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 15:44:57 +08:00
Fix benchmark_moe.py tuning for CUDA devices (#14164)
This commit is contained in:
parent
66233af7b6
commit
f78c0be80a
@@ -2,6 +2,7 @@
|
||||
|
||||
import argparse
|
||||
import time
|
||||
+from contextlib import nullcontext
|
||||
from datetime import datetime
|
||||
from itertools import product
|
||||
from typing import Any, TypedDict
|
||||
@@ -412,7 +413,8 @@ class BenchmarkWorker:
|
||||
hidden_size, search_space,
|
||||
is_fp16, topk)
|
||||
|
||||
-with torch.cuda.device(self.device_id):
|
||||
+with torch.cuda.device(self.device_id) if current_platform.is_rocm(
|
||||
+) else nullcontext():
|
||||
for config in tqdm(search_space):
|
||||
try:
|
||||
kernel_time = benchmark_config(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user