mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-06 13:42:21 +08:00
[Feature] Add SM103 (Blackwell Ultra) Support to vLLM (#30484)
Signed-off-by: LopezCastroRoberto <robertol.c510@gmail.com> Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com> Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
57e9bf1864
commit
4fa7ce46f3
@ -20,7 +20,7 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer
|
|||||||
|
|
||||||
from ...utils import flat_product, multi_gpu_test
|
from ...utils import flat_product, multi_gpu_test
|
||||||
|
|
||||||
is_blackwell = lambda: current_platform.is_device_capability(100)
|
is_blackwell = lambda: current_platform.is_device_capability_family(100)
|
||||||
"""Are we running on Blackwell, a lot of tests depend on it"""
|
"""Are we running on Blackwell, a lot of tests depend on it"""
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -32,8 +32,8 @@ def cal_diff(
|
|||||||
|
|
||||||
|
|
||||||
CUTLASS_MLA_UNSUPPORTED_REASON = (
|
CUTLASS_MLA_UNSUPPORTED_REASON = (
|
||||||
"Cutlass MLA Requires compute capability of 10 or above."
|
"Cutlass MLA Requires compute capability of 100 or above."
|
||||||
if not current_platform.is_device_capability(100)
|
if not current_platform.is_device_capability_family(100)
|
||||||
else "Cutlass MLA is supported"
|
else "Cutlass MLA is supported"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -11,7 +11,7 @@ from tests.kernels.quantization.nvfp4_utils import (
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.math_utils import round_up
|
from vllm.utils.math_utils import round_up
|
||||||
|
|
||||||
if not current_platform.is_device_capability(100):
|
if not current_platform.is_device_capability_family(100):
|
||||||
pytest.skip(
|
pytest.skip(
|
||||||
"This TRTLLM kernel requires NVIDIA Blackwell.", allow_module_level=True
|
"This TRTLLM kernel requires NVIDIA Blackwell.", allow_module_level=True
|
||||||
)
|
)
|
||||||
@ -443,7 +443,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
|||||||
output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
|
output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
|
||||||
|
|
||||||
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
|
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
|
||||||
rtol, atol = 1e-1, 2e-1
|
rtol, atol = 3e-1, 4e-1
|
||||||
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
|
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
|
||||||
rtol, atol = 4e-2, 6e-2
|
rtol, atol = 4e-2, 6e-2
|
||||||
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
|
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
|
||||||
|
|||||||
@ -17,7 +17,7 @@ QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
|
|||||||
) >= version.parse("0.8.99")
|
) >= version.parse("0.8.99")
|
||||||
|
|
||||||
TRTLLM_GEN_MXFP4_AVAILABLE = (
|
TRTLLM_GEN_MXFP4_AVAILABLE = (
|
||||||
current_platform.is_cuda() and current_platform.is_device_capability(100)
|
current_platform.is_cuda() and current_platform.is_device_capability_family(100)
|
||||||
)
|
)
|
||||||
|
|
||||||
HOPPER_MXFP4_BF16_AVAILABLE = (
|
HOPPER_MXFP4_BF16_AVAILABLE = (
|
||||||
@ -799,7 +799,7 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
|
|||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
not (
|
not (
|
||||||
current_platform.is_cuda()
|
current_platform.is_cuda()
|
||||||
and current_platform.is_device_capability(100)
|
and current_platform.is_device_capability_family(100)
|
||||||
and has_flashinfer()
|
and has_flashinfer()
|
||||||
),
|
),
|
||||||
reason="NVIDIA GPU sm100 and flashinfer are required for this test",
|
reason="NVIDIA GPU sm100 and flashinfer are required for this test",
|
||||||
|
|||||||
@ -10,9 +10,9 @@ import pytest
|
|||||||
from tests.utils import RemoteOpenAIServer
|
from tests.utils import RemoteOpenAIServer
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
if not current_platform.is_device_capability(100):
|
if not current_platform.is_device_capability_family(100):
|
||||||
pytest.skip(
|
pytest.skip(
|
||||||
"This test only runs on Blackwell GPUs (SM100).", allow_module_level=True
|
"This test only runs on Blackwell GPUs (SM10x).", allow_module_level=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -936,7 +936,7 @@ def enable_batch_invariant_mode():
|
|||||||
# Batch invariant matmuls are no longer needed after cublas overrides
|
# Batch invariant matmuls are no longer needed after cublas overrides
|
||||||
if not is_torch_equal_or_newer("2.10.0.dev"):
|
if not is_torch_equal_or_newer("2.10.0.dev"):
|
||||||
if (
|
if (
|
||||||
current_platform.is_device_capability(100)
|
current_platform.is_device_capability_family(100)
|
||||||
or current_platform.is_device_capability(80)
|
or current_platform.is_device_capability(80)
|
||||||
or current_platform.is_device_capability(89)
|
or current_platform.is_device_capability(89)
|
||||||
):
|
):
|
||||||
|
|||||||
@ -287,7 +287,10 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
"""
|
"""
|
||||||
DeepGemm supports packed ue8m0 activation scales format in devices == sm100
|
DeepGemm supports packed ue8m0 activation scales format in devices == sm100
|
||||||
"""
|
"""
|
||||||
return is_deep_gemm_e8m0_used() and current_platform.is_device_capability(100)
|
return (
|
||||||
|
is_deep_gemm_e8m0_used()
|
||||||
|
and current_platform.is_device_capability_family(100)
|
||||||
|
)
|
||||||
|
|
||||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||||
# Let PrepareAndFinalize::finalize() decide the impl.
|
# Let PrepareAndFinalize::finalize() decide the impl.
|
||||||
|
|||||||
@ -137,7 +137,7 @@ def get_fp8_moe_backend(
|
|||||||
if (
|
if (
|
||||||
current_platform.is_cuda()
|
current_platform.is_cuda()
|
||||||
and (
|
and (
|
||||||
current_platform.is_device_capability(100)
|
current_platform.is_device_capability_family(100)
|
||||||
or current_platform.is_device_capability(90)
|
or current_platform.is_device_capability(90)
|
||||||
)
|
)
|
||||||
and envs.VLLM_USE_FLASHINFER_MOE_FP8
|
and envs.VLLM_USE_FLASHINFER_MOE_FP8
|
||||||
@ -148,7 +148,7 @@ def get_fp8_moe_backend(
|
|||||||
logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
|
logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
|
||||||
return Fp8MoeBackend.FLASHINFER_TRTLLM
|
return Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||||
else:
|
else:
|
||||||
if block_quant and current_platform.is_device_capability(100):
|
if block_quant and current_platform.is_device_capability_family(100):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"FlashInfer FP8 MoE throughput backend does not "
|
"FlashInfer FP8 MoE throughput backend does not "
|
||||||
"support block quantization. Please use "
|
"support block quantization. Please use "
|
||||||
@ -193,7 +193,7 @@ def get_fp8_moe_backend(
|
|||||||
# CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights
|
# CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights
|
||||||
if (
|
if (
|
||||||
current_platform.is_cuda()
|
current_platform.is_cuda()
|
||||||
and current_platform.is_device_capability(100)
|
and current_platform.is_device_capability_family(100)
|
||||||
and block_quant
|
and block_quant
|
||||||
):
|
):
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
|
|||||||
@ -118,19 +118,19 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
|
|||||||
logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90")
|
logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90")
|
||||||
return Mxfp4Backend.SM90_FI_MXFP4_BF16
|
return Mxfp4Backend.SM90_FI_MXFP4_BF16
|
||||||
elif (
|
elif (
|
||||||
current_platform.is_device_capability(100)
|
current_platform.is_device_capability_family(100)
|
||||||
and has_flashinfer()
|
and has_flashinfer()
|
||||||
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS
|
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS
|
||||||
):
|
):
|
||||||
logger.info_once("Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100")
|
logger.info_once("Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100")
|
||||||
return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
|
return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
|
||||||
elif (
|
elif (
|
||||||
current_platform.is_device_capability(100)
|
current_platform.is_device_capability_family(100)
|
||||||
and has_flashinfer()
|
and has_flashinfer()
|
||||||
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
|
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
|
||||||
):
|
):
|
||||||
return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
|
return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
|
||||||
elif current_platform.is_device_capability(100) and has_flashinfer():
|
elif current_platform.is_device_capability_family(100) and has_flashinfer():
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
"Using FlashInfer MXFP4 BF16 backend for SM100, "
|
"Using FlashInfer MXFP4 BF16 backend for SM100, "
|
||||||
"For faster performance on SM100, consider setting "
|
"For faster performance on SM100, consider setting "
|
||||||
@ -139,7 +139,7 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
|
|||||||
)
|
)
|
||||||
return Mxfp4Backend.SM100_FI_MXFP4_BF16
|
return Mxfp4Backend.SM100_FI_MXFP4_BF16
|
||||||
elif (
|
elif (
|
||||||
current_platform.is_device_capability(100)
|
current_platform.is_device_capability_family(100)
|
||||||
or current_platform.is_device_capability(90)
|
or current_platform.is_device_capability(90)
|
||||||
) and not has_flashinfer():
|
) and not has_flashinfer():
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
|
|||||||
@ -50,7 +50,7 @@ def is_flashinfer_fp4_cutedsl_moe_available() -> bool:
|
|||||||
envs.VLLM_USE_FLASHINFER_MOE_FP4
|
envs.VLLM_USE_FLASHINFER_MOE_FP4
|
||||||
and has_flashinfer_cutedsl_grouped_gemm_nt_masked()
|
and has_flashinfer_cutedsl_grouped_gemm_nt_masked()
|
||||||
and current_platform.is_cuda()
|
and current_platform.is_cuda()
|
||||||
and current_platform.is_device_capability(100)
|
and current_platform.is_device_capability_family(100)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -290,7 +290,7 @@ def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
|
|||||||
if flashinfer_moe_backend in backend_map:
|
if flashinfer_moe_backend in backend_map:
|
||||||
if (
|
if (
|
||||||
flashinfer_moe_backend == "latency"
|
flashinfer_moe_backend == "latency"
|
||||||
and not current_platform.has_device_capability(100)
|
and not current_platform.is_device_capability_family(100)
|
||||||
):
|
):
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
"Flashinfer TRTLLM MOE backend is only supported on "
|
"Flashinfer TRTLLM MOE backend is only supported on "
|
||||||
|
|||||||
@ -247,7 +247,7 @@ class W8A8BlockFp8LinearOp:
|
|||||||
self.act_quant_group_shape = act_quant_group_shape
|
self.act_quant_group_shape = act_quant_group_shape
|
||||||
self.is_deep_gemm_supported = is_deep_gemm_supported()
|
self.is_deep_gemm_supported = is_deep_gemm_supported()
|
||||||
self.is_hopper = current_platform.is_device_capability(90)
|
self.is_hopper = current_platform.is_device_capability(90)
|
||||||
self.is_blackwell = current_platform.is_device_capability(100)
|
self.is_blackwell = current_platform.is_device_capability_family(100)
|
||||||
self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
|
self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
|
||||||
|
|
||||||
# Get the correct blockscale mul and input quant operations.
|
# Get the correct blockscale mul and input quant operations.
|
||||||
|
|||||||
@ -63,7 +63,7 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
|
|||||||
"split_k": 1,
|
"split_k": 1,
|
||||||
}
|
}
|
||||||
opt_flags.update_opt_flags_constraints(constraints)
|
opt_flags.update_opt_flags_constraints(constraints)
|
||||||
elif current_platform.is_device_capability(100):
|
elif current_platform.is_device_capability_family(100):
|
||||||
constraints = {
|
constraints = {
|
||||||
"is_persistent": True,
|
"is_persistent": True,
|
||||||
"epilogue_subtile": 1,
|
"epilogue_subtile": 1,
|
||||||
|
|||||||
@ -363,7 +363,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
|
|||||||
else:
|
else:
|
||||||
kernel_block_alignment_size = 16
|
kernel_block_alignment_size = 16
|
||||||
if (
|
if (
|
||||||
current_platform.is_device_capability(100)
|
current_platform.is_device_capability_family(100)
|
||||||
and model_config.get_head_size() == 256
|
and model_config.get_head_size() == 256
|
||||||
and (
|
and (
|
||||||
attention_config.backend is None
|
attention_config.backend is None
|
||||||
|
|||||||
@ -182,7 +182,7 @@ class CudaPlatformBase(Platform):
|
|||||||
|
|
||||||
if vllm_config.attention_config.backend is None:
|
if vllm_config.attention_config.backend is None:
|
||||||
# Default case
|
# Default case
|
||||||
if cls.is_device_capability(100) and not use_sparse:
|
if cls.is_device_capability_family(100) and not use_sparse:
|
||||||
# Blackwell => Force CutlassMLA (unless sparse, i.e. DSv3.2).
|
# Blackwell => Force CutlassMLA (unless sparse, i.e. DSv3.2).
|
||||||
use_cutlass_mla = True
|
use_cutlass_mla = True
|
||||||
# Set the backend in AttentionConfig so it's used during
|
# Set the backend in AttentionConfig so it's used during
|
||||||
|
|||||||
@ -301,6 +301,21 @@ class Platform:
|
|||||||
|
|
||||||
return current_capability.to_int() == capability
|
return current_capability.to_int() == capability
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_device_capability_family(
|
||||||
|
cls,
|
||||||
|
capability: int,
|
||||||
|
device_id: int = 0,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Returns True if the device capability is any <major>.x.
|
||||||
|
Mirrors CUDA 13 'family' architecture semantics (e.g. 10.x, 11.x, 12.x).
|
||||||
|
"""
|
||||||
|
current_capability = cls.get_device_capability(device_id=device_id)
|
||||||
|
if current_capability is None:
|
||||||
|
return False
|
||||||
|
return (current_capability.to_int() // 10) == (capability // 10)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_device_name(cls, device_id: int = 0) -> str:
|
def get_device_name(cls, device_id: int = 0) -> str:
|
||||||
"""Get the name of a device."""
|
"""Get the name of a device."""
|
||||||
|
|||||||
@ -38,7 +38,7 @@ class DeepGemmQuantScaleFMT(Enum):
|
|||||||
return DeepGemmQuantScaleFMT.FLOAT32
|
return DeepGemmQuantScaleFMT.FLOAT32
|
||||||
return (
|
return (
|
||||||
DeepGemmQuantScaleFMT.UE8M0
|
DeepGemmQuantScaleFMT.UE8M0
|
||||||
if current_platform.is_device_capability(100)
|
if current_platform.is_device_capability_family(100)
|
||||||
else DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
|
else DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -50,7 +50,7 @@ def is_deep_gemm_supported() -> bool:
|
|||||||
"""
|
"""
|
||||||
is_supported_arch = current_platform.is_cuda() and (
|
is_supported_arch = current_platform.is_cuda() and (
|
||||||
current_platform.is_device_capability(90)
|
current_platform.is_device_capability(90)
|
||||||
or current_platform.is_device_capability(100)
|
or current_platform.is_device_capability_family(100)
|
||||||
)
|
)
|
||||||
return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch
|
return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch
|
||||||
|
|
||||||
|
|||||||
@ -264,7 +264,9 @@ def supports_trtllm_attention() -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Requires SM100 and NVIDIA artifactory to be accessible to download cubins
|
# Requires SM100 and NVIDIA artifactory to be accessible to download cubins
|
||||||
return current_platform.is_device_capability(100) and has_nvidia_artifactory()
|
return (
|
||||||
|
current_platform.is_device_capability_family(100) and has_nvidia_artifactory()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def force_use_trtllm_attention() -> bool | None:
|
def force_use_trtllm_attention() -> bool | None:
|
||||||
|
|||||||
@ -564,7 +564,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
)
|
)
|
||||||
self.paged_kv_last_page_len_np = self.paged_kv_last_page_len_cpu.numpy()
|
self.paged_kv_last_page_len_np = self.paged_kv_last_page_len_cpu.numpy()
|
||||||
|
|
||||||
if self.head_dim == 256 and current_platform.is_device_capability(100):
|
if self.head_dim == 256 and current_platform.is_device_capability_family(100):
|
||||||
# https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
|
# https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
|
||||||
# head size 256 and block size 16 is not supported on blackwell.
|
# head size 256 and block size 16 is not supported on blackwell.
|
||||||
assert kv_cache_spec.block_size != 16, (
|
assert kv_cache_spec.block_size != 16, (
|
||||||
|
|||||||
@ -446,7 +446,7 @@ def use_flashinfer_prefill() -> bool:
|
|||||||
and flashinfer_available
|
and flashinfer_available
|
||||||
and not vllm_config.attention_config.use_cudnn_prefill
|
and not vllm_config.attention_config.use_cudnn_prefill
|
||||||
and not vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill
|
and not vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill
|
||||||
and current_platform.is_device_capability(100)
|
and current_platform.is_device_capability_family(100)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -457,7 +457,7 @@ def use_cudnn_prefill() -> bool:
|
|||||||
return (
|
return (
|
||||||
flashinfer_available
|
flashinfer_available
|
||||||
and vllm_config.attention_config.use_cudnn_prefill
|
and vllm_config.attention_config.use_cudnn_prefill
|
||||||
and current_platform.is_device_capability(100)
|
and current_platform.is_device_capability_family(100)
|
||||||
and has_nvidia_artifactory()
|
and has_nvidia_artifactory()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -470,7 +470,7 @@ def use_trtllm_ragged_deepseek_prefill() -> bool:
|
|||||||
return (
|
return (
|
||||||
flashinfer_available
|
flashinfer_available
|
||||||
and vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill
|
and vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill
|
||||||
and current_platform.is_device_capability(100)
|
and current_platform.is_device_capability_family(100)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -420,7 +420,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
|
|||||||
max_num_sm_parts = int(
|
max_num_sm_parts = int(
|
||||||
max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1)
|
max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1)
|
||||||
)
|
)
|
||||||
if current_platform.is_device_capability(100):
|
if current_platform.is_device_capability_family(100):
|
||||||
max_num_sm_parts *= 2
|
max_num_sm_parts *= 2
|
||||||
self.tile_scheduler_metadata_buffer = torch.empty(
|
self.tile_scheduler_metadata_buffer = torch.empty(
|
||||||
# TileSchedulerMetaDataSize = 8
|
# TileSchedulerMetaDataSize = 8
|
||||||
@ -719,7 +719,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
|
|||||||
self.softmax_scale = scale
|
self.softmax_scale = scale
|
||||||
assert indexer is not None
|
assert indexer is not None
|
||||||
self.topk_indices_buffer = indexer.topk_indices_buffer
|
self.topk_indices_buffer = indexer.topk_indices_buffer
|
||||||
self.padding = 128 if current_platform.is_device_capability(100) else 64
|
self.padding = 128 if current_platform.is_device_capability_family(100) else 64
|
||||||
|
|
||||||
if kv_cache_dtype == "fp8_ds_mla":
|
if kv_cache_dtype == "fp8_ds_mla":
|
||||||
# Reserve workspace during initialization
|
# Reserve workspace during initialization
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user