[platform] Move get_cu_count to utils (#27005)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan 2025-11-13 08:48:47 +08:00 committed by GitHub
parent d75ad04818
commit 2dacd57394
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 28 additions and 18 deletions

View File

@ -8,6 +8,7 @@ import torch
import vllm._custom_ops as ops
from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
from vllm.platforms import current_platform
from vllm.utils.platform_utils import get_cu_count
DTYPES = [torch.bfloat16, torch.float16]
# Specific (N, K, M) combinations for targeted testing
@ -85,7 +86,7 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = current_platform.get_cu_count()
cu_count = get_cu_count()
A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5
B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5
@ -102,7 +103,7 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = current_platform.get_cu_count()
cu_count = get_cu_count()
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
@ -121,7 +122,7 @@ def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = current_platform.get_cu_count()
cu_count = get_cu_count()
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
@ -153,7 +154,14 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
ref_out = torch._scaled_mm(
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b
)
out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, current_platform.get_cu_count())
out = ops.wvSplitKQ(
B,
A,
dtype,
scale_a,
scale_b,
get_cu_count(),
)
assert torch.allclose(out, ref_out, rtol=0.01)
@ -180,7 +188,13 @@ def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed):
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS
)
out = ops.wvSplitKQ(
B, A, dtype, scale_a, scale_b, current_platform.get_cu_count(), BIAS
B,
A,
dtype,
scale_a,
scale_b,
get_cu_count(),
BIAS,
)
assert torch.allclose(out, ref_out, rtol=0.01)

View File

@ -13,6 +13,7 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
from vllm.utils.platform_utils import get_cu_count
from vllm.utils.torch_utils import direct_register_custom_op
# Input scaling factors are no longer optional in _scaled_mm starting
@ -200,7 +201,7 @@ def rocm_per_tensor_w8a8_scaled_mm_impl(
out_dtype,
scale_a,
scale_b,
current_platform.get_cu_count(),
get_cu_count(),
bias,
)
else:

View File

@ -11,6 +11,7 @@ from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.utils.platform_utils import get_cu_count
from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__)
@ -151,7 +152,7 @@ def rocm_unquantized_gemm_impl(
x_view = x.reshape(-1, x.size(-1))
if m > 8 and 0 < n <= 4:
cu_count = current_platform.get_cu_count()
cu_count = get_cu_count()
out = ops.wvSplitK(weight, x_view, cu_count, bias)
return out.reshape(*x.shape[:-1], weight.shape[0])
elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:

View File

@ -545,13 +545,6 @@ class Platform:
cls._global_graph_pool = self.graph_pool_handle()
return cls._global_graph_pool
@classmethod
def get_cu_count(cls, device_id: int = 0) -> int:
"""
Returns the total number of compute units (CU) on single GPU.
"""
raise NotImplementedError
@classmethod
def get_static_graph_wrapper_cls(cls) -> str:
"""

View File

@ -423,10 +423,6 @@ class RocmPlatform(Platform):
def opaque_attention_op(cls) -> bool:
return True
@classmethod
def get_cu_count(cls, device_id: int = 0) -> int:
return torch.cuda.get_device_properties(device_id).multi_processor_count
@classmethod
def is_navi(cls) -> bool:
return "gfx1" in torch.cuda.get_device_properties(0).gcnArchName

View File

@ -24,6 +24,11 @@ def xpu_is_initialized() -> bool:
return torch.xpu.is_initialized()
def get_cu_count(cls, device_id: int = 0) -> int:
"""Returns the total number of compute units (CU) on single GPU."""
return torch.cuda.get_device_properties(device_id).multi_processor_count
def cuda_get_device_properties(
device, names: Sequence[str], init_cuda=False
) -> tuple[Any, ...]: