diff --git a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu index e0e95d06290df..6dd6f269f3dc9 100644 --- a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu +++ b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu @@ -167,7 +167,7 @@ typename T::Fmha::Arguments args_from_options( // TODO(trevor-m): Change split_kv back to -1 when // https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will // perform worse with larger context length and smaller batch sizes. - num_kv_splits, // split_kv + static_cast(num_kv_splits), // split_kv nullptr, // is_var_split_kv }; // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute @@ -264,7 +264,7 @@ int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_ba // Assumes device 0 when getting sm_count. arguments.hw_info.sm_count = sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count; - arguments.split_kv = num_kv_splits; + arguments.split_kv = static_cast(num_kv_splits); MlaSm100Type::Fmha::set_split_kv(arguments); return MlaSm100Type::Fmha::get_workspace_size(arguments);