mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-25 15:14:33 +08:00
[BugFix] Fix import error on non-blackwell machines (#21020)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
85431bd9ad
commit
d31a647124
@ -18,6 +18,7 @@ limitations under the License.
|
||||
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
|
||||
* by Alcanderian JieXin Liang
|
||||
*/
|
||||
#include "core/registration.h"
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
@ -270,4 +271,13 @@ int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_ba
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("sm100_cutlass_mla_decode", &sm100_cutlass_mla_decode);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CatchAll, m) {
|
||||
m.impl("sm100_cutlass_mla_get_workspace_size", &sm100_cutlass_mla_get_workspace_size);
|
||||
}
|
||||
|
||||
// clang-format on
|
||||
|
||||
13
csrc/ops.h
13
csrc/ops.h
@ -167,19 +167,6 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table, double scale);
|
||||
|
||||
void sm100_cutlass_mla_decode(
|
||||
torch::Tensor const& out, torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens, torch::Tensor const& page_table,
|
||||
torch::Tensor const& workspace, double sm_scale,
|
||||
int64_t num_kv_splits =
|
||||
1 /* Set to 1 to avoid cuda_graph issue by default. */);
|
||||
|
||||
int64_t sm100_cutlass_mla_get_workspace_size(
|
||||
int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0,
|
||||
int64_t num_kv_splits =
|
||||
1 /* Set to 1 to avoid cuda_graph issue by default. */);
|
||||
|
||||
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
|
||||
@ -521,15 +521,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor page_table, Tensor workspace, float "
|
||||
"scale,"
|
||||
" int num_kv_splits) -> ()");
|
||||
ops.impl("sm100_cutlass_mla_decode", torch::kCUDA, &sm100_cutlass_mla_decode);
|
||||
// conditionally compiled so impl in source file
|
||||
|
||||
// SM100 CUTLASS MLA workspace
|
||||
ops.def(
|
||||
"sm100_cutlass_mla_get_workspace_size(int max_seq_len, int num_batches,"
|
||||
" int sm_count, int num_kv_splits) "
|
||||
"-> int");
|
||||
ops.impl("sm100_cutlass_mla_get_workspace_size",
|
||||
&sm100_cutlass_mla_get_workspace_size);
|
||||
// conditionally compiled so impl in source file
|
||||
|
||||
// Compute NVFP4 block quantized tensor.
|
||||
ops.def(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user