mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-18 12:26:59 +08:00
[BugFix] Fix import error on non-blackwell machines (#21020)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
85431bd9ad
commit
d31a647124
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
|
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
|
||||||
* by Alcanderian JieXin Liang
|
* by Alcanderian JieXin Liang
|
||||||
*/
|
*/
|
||||||
|
#include "core/registration.h"
|
||||||
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
@ -270,4 +271,13 @@ int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_ba
|
|||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||||
|
m.impl("sm100_cutlass_mla_decode", &sm100_cutlass_mla_decode);
|
||||||
|
}
|
||||||
|
|
||||||
|
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CatchAll, m) {
|
||||||
|
m.impl("sm100_cutlass_mla_get_workspace_size", &sm100_cutlass_mla_get_workspace_size);
|
||||||
|
}
|
||||||
|
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|||||||
13
csrc/ops.h
13
csrc/ops.h
@ -167,19 +167,6 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
|
|||||||
torch::Tensor const& seq_lens,
|
torch::Tensor const& seq_lens,
|
||||||
torch::Tensor const& page_table, double scale);
|
torch::Tensor const& page_table, double scale);
|
||||||
|
|
||||||
void sm100_cutlass_mla_decode(
|
|
||||||
torch::Tensor const& out, torch::Tensor const& q_nope,
|
|
||||||
torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache,
|
|
||||||
torch::Tensor const& seq_lens, torch::Tensor const& page_table,
|
|
||||||
torch::Tensor const& workspace, double sm_scale,
|
|
||||||
int64_t num_kv_splits =
|
|
||||||
1 /* Set to 1 to avoid cuda_graph issue by default. */);
|
|
||||||
|
|
||||||
int64_t sm100_cutlass_mla_get_workspace_size(
|
|
||||||
int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0,
|
|
||||||
int64_t num_kv_splits =
|
|
||||||
1 /* Set to 1 to avoid cuda_graph issue by default. */);
|
|
||||||
|
|
||||||
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
|
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
|
|||||||
@ -521,15 +521,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
" Tensor page_table, Tensor workspace, float "
|
" Tensor page_table, Tensor workspace, float "
|
||||||
"scale,"
|
"scale,"
|
||||||
" int num_kv_splits) -> ()");
|
" int num_kv_splits) -> ()");
|
||||||
ops.impl("sm100_cutlass_mla_decode", torch::kCUDA, &sm100_cutlass_mla_decode);
|
// conditionally compiled so impl in source file
|
||||||
|
|
||||||
// SM100 CUTLASS MLA workspace
|
// SM100 CUTLASS MLA workspace
|
||||||
ops.def(
|
ops.def(
|
||||||
"sm100_cutlass_mla_get_workspace_size(int max_seq_len, int num_batches,"
|
"sm100_cutlass_mla_get_workspace_size(int max_seq_len, int num_batches,"
|
||||||
" int sm_count, int num_kv_splits) "
|
" int sm_count, int num_kv_splits) "
|
||||||
"-> int");
|
"-> int");
|
||||||
ops.impl("sm100_cutlass_mla_get_workspace_size",
|
// conditionally compiled so impl in source file
|
||||||
&sm100_cutlass_mla_get_workspace_size);
|
|
||||||
|
|
||||||
// Compute NVFP4 block quantized tensor.
|
// Compute NVFP4 block quantized tensor.
|
||||||
ops.def(
|
ops.def(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user