[platform] Move get_cu_count to utils (#27005)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
parent d75ad04818
commit 2dacd57394
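In short, the commit removes Platform.get_cu_count() (and the RocmPlatform override) and points every call site at a standalone helper in vllm.utils.platform_utils. A minimal sketch of the call-site migration, using only names that appear in the hunks below; the surrounding assignment is illustrative, not part of the diff:

# Before: the CU (compute unit) count was queried through the platform object.
from vllm.platforms import current_platform
cu_count = current_platform.get_cu_count()

# After: the helper is imported directly from the utils module.
from vllm.utils.platform_utils import get_cu_count
cu_count = get_cu_count()  # device_id defaults to 0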
@@ -8,6 +8,7 @@ import torch
 import vllm._custom_ops as ops
 from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
 from vllm.platforms import current_platform
+from vllm.utils.platform_utils import get_cu_count

 DTYPES = [torch.bfloat16, torch.float16]
 # Specific (N, K, M) combinations for targeted testing
@@ -85,7 +86,7 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
 @pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
 def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)
-    cu_count = current_platform.get_cu_count()
+    cu_count = get_cu_count()

     A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5
     B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5
@@ -102,7 +103,7 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
 @pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
 def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)
-    cu_count = current_platform.get_cu_count()
+    cu_count = get_cu_count()

     xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
     A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
@@ -121,7 +122,7 @@ def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
 @pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
 def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)
-    cu_count = current_platform.get_cu_count()
+    cu_count = get_cu_count()

     xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
     A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
@@ -153,7 +154,14 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
     ref_out = torch._scaled_mm(
         A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b
     )
-    out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, current_platform.get_cu_count())
+    out = ops.wvSplitKQ(
+        B,
+        A,
+        dtype,
+        scale_a,
+        scale_b,
+        get_cu_count(),
+    )

     assert torch.allclose(out, ref_out, rtol=0.01)

@@ -180,7 +188,13 @@ def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed):
         A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS
     )
     out = ops.wvSplitKQ(
-        B, A, dtype, scale_a, scale_b, current_platform.get_cu_count(), BIAS
+        B,
+        A,
+        dtype,
+        scale_a,
+        scale_b,
+        get_cu_count(),
+        BIAS,
     )

     assert torch.allclose(out, ref_out, rtol=0.01)
@@ -13,6 +13,7 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
 from vllm.platforms import current_platform
 from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
+from vllm.utils.platform_utils import get_cu_count
 from vllm.utils.torch_utils import direct_register_custom_op

 # Input scaling factors are no longer optional in _scaled_mm starting
@@ -200,7 +201,7 @@ def rocm_per_tensor_w8a8_scaled_mm_impl(
             out_dtype,
             scale_a,
             scale_b,
-            current_platform.get_cu_count(),
+            get_cu_count(),
             bias,
         )
     else:
@@ -11,6 +11,7 @@ from vllm import envs
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
+from vllm.utils.platform_utils import get_cu_count
 from vllm.utils.torch_utils import direct_register_custom_op

 logger = init_logger(__name__)
@@ -151,7 +152,7 @@ def rocm_unquantized_gemm_impl(

     x_view = x.reshape(-1, x.size(-1))
     if m > 8 and 0 < n <= 4:
-        cu_count = current_platform.get_cu_count()
+        cu_count = get_cu_count()
         out = ops.wvSplitK(weight, x_view, cu_count, bias)
         return out.reshape(*x.shape[:-1], weight.shape[0])
     elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
@@ -545,13 +545,6 @@ class Platform:
         cls._global_graph_pool = self.graph_pool_handle()
         return cls._global_graph_pool

-    @classmethod
-    def get_cu_count(cls, device_id: int = 0) -> int:
-        """
-        Returns the total number of compute units (CU) on single GPU.
-        """
-        raise NotImplementedError
-
     @classmethod
     def get_static_graph_wrapper_cls(cls) -> str:
         """
@@ -423,10 +423,6 @@ class RocmPlatform(Platform):
     def opaque_attention_op(cls) -> bool:
         return True

-    @classmethod
-    def get_cu_count(cls, device_id: int = 0) -> int:
-        return torch.cuda.get_device_properties(device_id).multi_processor_count
-
     @classmethod
     def is_navi(cls) -> bool:
         return "gfx1" in torch.cuda.get_device_properties(0).gcnArchName
@@ -24,6 +24,11 @@ def xpu_is_initialized() -> bool:
     return torch.xpu.is_initialized()


+def get_cu_count(device_id: int = 0) -> int:
+    """Returns the total number of compute units (CU) on single GPU."""
+    return torch.cuda.get_device_properties(device_id).multi_processor_count
+
+
 def cuda_get_device_properties(
     device, names: Sequence[str], init_cuda=False
 ) -> tuple[Any, ...]:
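For a quick sanity check on a ROCm machine, the relocated helper can be compared against the torch device property it wraps. The snippet below is illustrative only and is not part of the diff:

import torch

from vllm.platforms import current_platform
from vllm.utils.platform_utils import get_cu_count

if current_platform.is_rocm():
    # get_cu_count() is a thin wrapper around multi_processor_count.
    expected = torch.cuda.get_device_properties(0).multi_processor_count
    assert get_cu_count(device_id=0) == expected
    print(f"CU count on device 0: {get_cu_count()}")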