mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 21:55:32 +08:00
Fix benchmark_moe.py tuning for CUDA devices (#14164)
This commit is contained in:
parent
66233af7b6
commit
f78c0be80a
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import time
|
import time
|
||||||
|
from contextlib import nullcontext
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from itertools import product
|
from itertools import product
|
||||||
from typing import Any, TypedDict
|
from typing import Any, TypedDict
|
||||||
@ -412,7 +413,8 @@ class BenchmarkWorker:
|
|||||||
hidden_size, search_space,
|
hidden_size, search_space,
|
||||||
is_fp16, topk)
|
is_fp16, topk)
|
||||||
|
|
||||||
with torch.cuda.device(self.device_id):
|
with torch.cuda.device(self.device_id) if current_platform.is_rocm(
|
||||||
|
) else nullcontext():
|
||||||
for config in tqdm(search_space):
|
for config in tqdm(search_space):
|
||||||
try:
|
try:
|
||||||
kernel_time = benchmark_config(
|
kernel_time = benchmark_config(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user