Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-11 00:05:52 +08:00)
feat(rocm-support): support mamba2 on rocm (#18565)

Signed-off-by: Islam Almersawi <islam.almersawi@openinnovation.ai>
Co-authored-by: Islam Almersawi <islam.almersawi@openinnovation.ai>

parent fc6d0c290f
commit a547aeb828
@@ -232,6 +232,8 @@ endif()
 #

 set(VLLM_EXT_SRC
+  "csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
+  "csrc/mamba/causal_conv1d/causal_conv1d.cu"
   "csrc/cache_kernels.cu"
   "csrc/attention/paged_attention_v1.cu"
   "csrc/attention/paged_attention_v2.cu"
@@ -287,8 +289,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   FetchContent_MakeAvailable(cutlass)

   list(APPEND VLLM_EXT_SRC
-    "csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
-    "csrc/mamba/causal_conv1d/causal_conv1d.cu"
     "csrc/quantization/aqlm/gemm_kernels.cu"
     "csrc/quantization/awq/gemm_kernels.cu"
     "csrc/permute_cols.cu"
@@ -13,6 +13,10 @@
 #include <cub/block/block_load.cuh>
 #include <cub/block/block_store.cuh>

+#ifdef USE_ROCM
+namespace cub = hipcub;
+#endif
+
 #include "static_switch.h"
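Note on the hunk above: aliasing cub to hipcub lets block-level CUB primitives that the kernel code already uses compile unchanged under ROCm. The following standalone sketch (hypothetical kernel name, not taken from the patch; the includes are spelled out explicitly here only to keep the sketch self-contained) shows the pattern.

// Minimal standalone sketch: with the ROCm branch aliasing cub to hipcub,
// the same cub::BlockLoad / cub::BlockStore code compiles for both CUDA and HIP.
#ifdef USE_ROCM
  #include <hipcub/hipcub.hpp>
  namespace cub = hipcub;
#else
  #include <cub/block/block_load.cuh>
  #include <cub/block/block_store.cuh>
#endif

template <int kNThreads, int kNItems>
__global__ void block_copy_kernel(const float* in, float* out, int n) {
  using BlockLoad = cub::BlockLoad<float, kNThreads, kNItems>;
  using BlockStore = cub::BlockStore<float, kNThreads, kNItems>;
  __shared__ union {
    typename BlockLoad::TempStorage load;
    typename BlockStore::TempStorage store;
  } smem;

  float items[kNItems];
  BlockLoad(smem.load).Load(in, items, n);      // guarded load of up to n items
  __syncthreads();
  BlockStore(smem.store).Store(out, items, n);  // guarded store back to global memory
}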
@@ -501,15 +505,9 @@ void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
   auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;

   if (kSmemSize >= 48 * 1024) {
-    #ifndef USE_ROCM
-    C10_CUDA_CHECK(cudaFuncSetAttribute(
-        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
-    #else
-    // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
     C10_CUDA_CHECK(cudaFuncSetAttribute(
         (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
     std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
-    #endif
   }
   kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
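The removed comment pointed at a signature discrepancy between the CUDA and HIP FuncSetAttribute entry points: CUDA additionally ships a templated overload that accepts a typed kernel pointer, while the call available on ROCm takes an untyped pointer, so casting to (void *) leaves a single call that builds on both platforms. A minimal standalone sketch of that single code path (hypothetical kernel and launcher names, with simplified error handling in place of C10_CUDA_CHECK):

// Standalone sketch of the unified launch path used by the hunk above.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void smem_kernel(float* out) {
  extern __shared__ float smem[];
  smem[threadIdx.x] = static_cast<float>(threadIdx.x);
  __syncthreads();
  out[threadIdx.x] = smem[threadIdx.x];
}

void launch_smem_kernel(float* out, int smem_bytes, cudaStream_t stream) {
  auto kernel = &smem_kernel;
  if (smem_bytes >= 48 * 1024) {
    // The (void *) cast selects the untyped overload, which hipify maps
    // directly onto hipFuncSetAttribute(const void*, ...) on ROCm builds.
    cudaError_t err = cudaFuncSetAttribute(
        (void*)kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
    if (err != cudaSuccess) {
      std::fprintf(stderr, "cudaFuncSetAttribute failed: %s\n",
                   cudaGetErrorString(err));
    }
  }
  kernel<<<1, 256, smem_bytes, stream>>>(out);
}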
@@ -321,7 +321,7 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
   auto kernel = &selective_scan_fwd_kernel<Ktraits>;
   if (kSmemSize >= 48 * 1024) {
     C10_CUDA_CHECK(cudaFuncSetAttribute(
-        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
+        (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
   }
   kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -482,41 +482,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor page_table, float scale) -> ()");
   ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);

-  // Mamba selective scan kernel
-  ops.def(
-      "selective_scan_fwd(Tensor! u, Tensor! delta,"
-      "Tensor! A, Tensor! B, Tensor! C,"
-      "Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
-      "bool delta_softplus,"
-      "Tensor? query_start_loc,"
-      "Tensor? cache_indices,"
-      "Tensor? has_initial_state,"
-      "Tensor! ssm_states,"
-      "int pad_slot_id) -> ()");
-  ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
-
-  ops.def(
-      "causal_conv1d_update(Tensor! x,"
-      "Tensor! conv_state,"
-      "Tensor! weight,"
-      "Tensor? bias_,"
-      "bool silu_activation,"
-      "Tensor? cache_seqlens_,"
-      "Tensor? conv_state_indices,"
-      "int pad_slot_id) -> ()");
-  ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
-
-  ops.def(
-      "causal_conv1d_fwd(Tensor! x, Tensor! weight,"
-      "Tensor? bias_,"
-      "Tensor!? conv_states,"
-      "Tensor? query_start_loc,"
-      "Tensor? cache_indices,"
-      "Tensor? has_initial_state,"
-      "bool silu_activation,"
-      "int pad_slot_id) -> ()");
-  ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
-
   // Compute NVFP4 block quantized tensor.
   ops.def(
       "scaled_fp4_quant(Tensor! output, Tensor input,"
@@ -584,6 +549,41 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
            &dynamic_scaled_int8_quant);

+  // Mamba selective scan kernel
+  ops.def(
+      "selective_scan_fwd(Tensor! u, Tensor! delta,"
+      "Tensor! A, Tensor! B, Tensor! C,"
+      "Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
+      "bool delta_softplus,"
+      "Tensor? query_start_loc,"
+      "Tensor? cache_indices,"
+      "Tensor? has_initial_state,"
+      "Tensor! ssm_states,"
+      "int pad_slot_id) -> ()");
+  ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
+
+  ops.def(
+      "causal_conv1d_update(Tensor! x,"
+      "Tensor! conv_state,"
+      "Tensor! weight,"
+      "Tensor? bias_,"
+      "bool silu_activation,"
+      "Tensor? cache_seqlens_,"
+      "Tensor? conv_state_indices,"
+      "int pad_slot_id) -> ()");
+  ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
+
+  ops.def(
+      "causal_conv1d_fwd(Tensor! x, Tensor! weight,"
+      "Tensor? bias_,"
+      "Tensor!? conv_states,"
+      "Tensor? query_start_loc,"
+      "Tensor? cache_indices,"
+      "Tensor? has_initial_state,"
+      "bool silu_activation,"
+      "int pad_slot_id) -> ()");
+  ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
+
 #ifndef USE_ROCM
   // reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
   ops.def(
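This hunk re-registers the Mamba ops just ahead of the #ifndef USE_ROCM block, presumably moving them out of a CUDA-only region so that selective_scan_fwd, causal_conv1d_update and causal_conv1d_fwd are also bound in ROCm builds; registering them under torch::kCUDA still works there because PyTorch's ROCm builds dispatch through the CUDA device key. A minimal, self-contained sketch of the same def/impl pattern, using a hypothetical library and op (toy_ops::scale_inplace) that is not part of the vLLM bindings:

// Sketch of schema definition plus CUDA-dispatched implementation.
// "Tensor!" in the schema marks an argument the op mutates, as in the
// selective_scan_fwd and causal_conv1d schemas above.
#include <torch/library.h>
#include <torch/torch.h>

void scale_inplace(torch::Tensor& x, double factor) {
  // Placeholder body; a real op would launch a CUDA/HIP kernel on the
  // current stream.
  x.mul_(factor);
}

TORCH_LIBRARY(toy_ops, ops) {
  ops.def("scale_inplace(Tensor! x, float factor) -> ()");
  ops.impl("scale_inplace", torch::kCUDA, &scale_inplace);
}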
@@ -5,10 +5,9 @@ from dataclasses import dataclass
 import torch

 from vllm.attention.backends.abstract import AttentionMetadata
-from vllm.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.attention.backends.placeholder_attn import (
     PlaceholderAttentionMetadata)
-from vllm.attention.backends.xformers import XFormersMetadata
+from vllm.platforms import current_platform


 @dataclass
@@ -23,6 +22,21 @@ class Mamba2Metadata:
     chunk_offsets: torch.Tensor


+def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
+    """Returns the appropriate metadata classes for the current platform."""
+    if current_platform.is_rocm():
+        from vllm.attention.backends.rocm_flash_attn import (
+            ROCmFlashAttentionMetadata)
+        return (ROCmFlashAttentionMetadata, PlaceholderAttentionMetadata)
+    elif current_platform.is_cuda():
+        from vllm.attention.backends.flash_attn import FlashAttentionMetadata
+        from vllm.attention.backends.xformers import XFormersMetadata
+        return (FlashAttentionMetadata, XFormersMetadata,
+                PlaceholderAttentionMetadata)
+    raise ValueError(
+        f"Unsupported platform for Mamba2: {current_platform.device_type}")
+
+
 def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
                                               chunk_size: int,
                                               total_seqlens: int):
@@ -78,9 +92,8 @@ def prepare_mamba2_metadata(

     # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
     if num_prefills > 0:
-        if (isinstance(attn_metadata,
-                       (FlashAttentionMetadata, XFormersMetadata,
-                        PlaceholderAttentionMetadata))
+        attn_metadata_instances = get_platform_metadata_classes()
+        if (isinstance(attn_metadata, attn_metadata_instances)
                 and attn_metadata.context_lens_tensor is not None):
             has_initial_states = \
                 attn_metadata.context_lens_tensor[:num_prefills] > 0  # [batch,]