Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 05:34:57 +08:00)
Add FLASHINFER_MLA to test_mla_backends and add B200 CI run (#27663)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>

parent 5e8862e9e0
commit f29aeb5a25
@@ -340,6 +340,16 @@ steps:
   commands:
     - pytest -v -s v1/attention

+- label: V1 Test attention (B200) # 10min
+  timeout_in_minutes: 30
+  gpu: b200
+  source_file_dependencies:
+    - vllm/v1/attention
+    - tests/v1/attention
+  commands:
+    - export VLLM_DISABLE_FLASHINFER_PREFILL=1 # TODO: FI prefill is bugged and causes incorrectness, fix this
+    - pytest -v -s v1/attention
+
 - label: V1 Test others (CPU) # 5 mins
   source_file_dependencies:
     - vllm/
@@ -14,16 +14,19 @@ import torch
 from tests.v1.attention.utils import (
     BatchSpec,
     create_common_attn_metadata,
-    create_standard_kv_cache_spec,
     create_vllm_config,
     try_get_attention_backend,
 )
 from vllm import _custom_ops as ops
-from vllm.attention.backends.registry import _Backend
+from vllm.attention.backends.registry import _Backend, backend_to_class_str
 from vllm.attention.ops.flashmla import is_flashmla_dense_supported
+from vllm.attention.utils.fa_utils import flash_attn_supports_mla
 from vllm.config.vllm import set_current_vllm_config
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.utils.import_utils import resolve_obj_by_qualname
 from vllm.utils.math_utils import cdiv
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
+from vllm.v1.attention.backends.mla.common import QueryLenSupport
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import FullAttentionSpec

@@ -31,17 +34,46 @@ BACKENDS_TO_TEST = [
     _Backend.CUTLASS_MLA,
     _Backend.FLASHMLA,
     _Backend.FLASH_ATTN_MLA,
+    _Backend.FLASHINFER_MLA,
     _Backend.TRITON_MLA,
 ]

-# Remove CUTLASS_MLA from the list if not using sm100
+# Remove sm100 backends from the list if not using sm100
 if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
     BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA)
+    BACKENDS_TO_TEST.remove(_Backend.FLASHINFER_MLA)

+# Remove FLASH_ATTN_MLA from the list if not supported
+if not flash_attn_supports_mla():
+    BACKENDS_TO_TEST.remove(_Backend.FLASH_ATTN_MLA)
+
 # Remove FLASHMLA from the list if not supported
 if not is_flashmla_dense_supported()[0]:
     BACKENDS_TO_TEST.remove(_Backend.FLASHMLA)

+SPEC_DECODE_BACKENDS = []
+for backend in BACKENDS_TO_TEST:
+    builder_cls, _ = try_get_attention_backend(backend)
+    query_len_support = getattr(
+        builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
+    )
+    if query_len_support != QueryLenSupport.SINGLE_ONLY:
+        SPEC_DECODE_BACKENDS.append(backend)
+
+BACKEND_BLOCK_SIZES = {}
+for backend in BACKENDS_TO_TEST:
+    backend_class_str = backend_to_class_str(backend)
+    backend_class = resolve_obj_by_qualname(backend_class_str)
+    supported_sizes = backend_class.get_supported_kernel_block_size()
+    if supported_sizes:
+        default_size = supported_sizes[0]
+        block_size = (
+            default_size if isinstance(default_size, int) else default_size.base
+        )
+    else:
+        block_size = 16
+    BACKEND_BLOCK_SIZES[backend] = block_size
+
 torch.manual_seed(42)

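For reference, a minimal standalone sketch of the default-block-size selection used above. The `MultipleOf` class defined here is only a stand-in for `vllm.attention.backends.abstract.MultipleOf` (which, per the diff, exposes the smallest legal size via `.base`), and the example inputs are illustrative, not taken from the PR:

# Sketch of the block-size selection logic, under the assumptions stated above.
from dataclasses import dataclass


@dataclass
class MultipleOf:  # stand-in for vllm.attention.backends.abstract.MultipleOf
    base: int


def pick_default_block_size(supported_sizes, fallback: int = 16) -> int:
    """Return the first supported kernel block size, unwrapping MultipleOf."""
    if not supported_sizes:
        return fallback
    default_size = supported_sizes[0]
    return default_size if isinstance(default_size, int) else default_size.base


assert pick_default_block_size([32, 64]) == 32          # e.g. FlashInfer MLA (see the backend hunk below)
assert pick_default_block_size([MultipleOf(16)]) == 16  # a backend declaring only a multiple
assert pick_default_block_size([]) == 16                # fallback used by the test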
@@ -236,6 +268,26 @@ class MockAttentionLayer:
         self._q_scale = torch.tensor(1.0, device=device)
         self._k_scale = torch.tensor(1.0, device=device)
         self._v_scale = torch.tensor(1.0, device=device)
+        self._prob_scale = torch.tensor(1.0, device=device)
+        self._q_scale_float = 1.0
+        self._k_scale_float = 1.0
+        self._v_scale_float = 1.0
+
+    def forward(self, *_args, **_kwargs):
+        raise NotImplementedError
+
+
+class MockMLAAttentionLayer(AttentionLayerBase):
+    """A mock MLA attention layer for populating static_forward_context."""
+
+    def __init__(self, impl):
+        self.impl = impl
+
+    def get_attn_backend(self):
+        raise NotImplementedError
+
+    def get_kv_cache_spec(self, vllm_config):
+        raise NotImplementedError


 def run_attention_backend(
@@ -262,13 +314,6 @@ def run_attention_backend(
     # Set the current vllm config so that get_current_vllm_config() works
     # in the backend implementations
     with set_current_vllm_config(vllm_config):
-        # Build metadata
-        builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
-        attn_metadata = builder.build(
-            common_prefix_len=0,
-            common_attn_metadata=common_attn_metadata,
-        )
-
         # Instantiate MLA implementation
         num_heads = vllm_config.model_config.get_num_attention_heads(
             vllm_config.parallel_config
@@ -302,6 +347,19 @@ def run_attention_backend(
         act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
         impl.process_weights_after_loading(act_dtype)

+        # Populate static_forward_context with mock attention layers
+        for layer_name in layer_names:
+            vllm_config.compilation_config.static_forward_context[layer_name] = (
+                MockMLAAttentionLayer(impl)
+            )
+
+        # Build metadata
+        builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
+        attn_metadata = builder.build(
+            common_prefix_len=0,
+            common_attn_metadata=common_attn_metadata,
+        )
+
         # Create mock layer and output buffer
         mock_layer = MockAttentionLayer(device)
         num_tokens = query.shape[0]
@@ -353,15 +411,14 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
     simulated paged KV cache.
     5. Comparing the vLLM backend's output to the ground-truth SDPA output.
     """
-    from vllm.v1.attention.backends.mla.common import QueryLenSupport
-
     batch_spec = BATCH_SPECS[batch_spec_name]
     is_spec_decode_test = batch_spec_name.startswith("spec_decode")
-    spec_decode_backends = {_Backend.FLASH_ATTN_MLA, _Backend.FLASHMLA}

-    block_size = 16
+    unique_block_sizes = sorted(set(BACKEND_BLOCK_SIZES.values()))
+    default_block_size = unique_block_sizes[0]
     required_blocks = sum(
-        (seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens
+        (seq_len + default_block_size - 1) // default_block_size
+        for seq_len in batch_spec.seq_lens
     )
     # Add 1 for null block at index 0, and some buffer
     num_gpu_blocks = required_blocks + 1 + 100
@@ -370,7 +427,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
         model_name=model,
         max_model_len=max(batch_spec.seq_lens),
         num_gpu_blocks=num_gpu_blocks,
-        block_size=block_size,
+        block_size=default_block_size,
     )

     # For spec decode tests, add a speculative_config to set the reorder_batch_threshold
@@ -388,8 +445,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):

     device = torch.device("cuda:0")

-    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
-
     # 1. Setup
     batch_size = batch_spec.batch_size
     seq_lens = batch_spec.seq_lens
@@ -399,7 +454,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
     )
     head_size = vllm_config.model_config.get_head_size()
     dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
-    block_size = vllm_config.cache_config.block_size
     kv_lora_rank = 512
     qk_rope_head_dim = 64
     qk_nope_head_dim = 128
@@ -598,33 +652,83 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
     )
     mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T, requires_grad=False)

-    # Create metadata using original batch spec
-    common_attn_metadata = create_common_attn_metadata(
-        batch_spec, vllm_config.cache_config.block_size, device
-    )
-
-    # 3. Simulate Paged KV Cache and a realistic slot_mapping
-    kv_cache = create_and_prepopulate_kv_cache(
-        kv_c_contexts=kv_c_contexts,
-        k_pe_contexts=k_pe_contexts,
-        block_size=block_size,
-        head_size=head_size,
-        dtype=dtype,
-        device=device,
-        num_blocks=vllm_config.cache_config.num_gpu_blocks,
-        common_attn_metadata=common_attn_metadata,
-        randomize_blocks=True,
-    )
-
+    # 3. Create metadata and KV caches for each block size
+    # Group backends by block size and test each group
+    metadata_per_block_size = {}
+    kv_cache_per_block_size = {}
+
+    for block_size in unique_block_sizes:
+        # Create metadata for this block size
+        common_attn_metadata = create_common_attn_metadata(
+            batch_spec, block_size, device
+        )
+
+        # Pad block table to meet requirement:
+        # block_num % (128 / block_size) == 0
+        required_divisor = int(128 / block_size)
+        current_block_num = common_attn_metadata.block_table_tensor.shape[1]
+        if current_block_num % required_divisor != 0:
+            # Pad to next multiple of required_divisor
+            padded_block_num = (
+                (current_block_num + required_divisor - 1) // required_divisor
+            ) * required_divisor
+            padding_cols = padded_block_num - current_block_num
+            padding = torch.zeros(
+                (common_attn_metadata.block_table_tensor.shape[0], padding_cols),
+                dtype=torch.int32,
+                device=device,
+            )
+            common_attn_metadata.block_table_tensor = torch.cat(
+                [common_attn_metadata.block_table_tensor, padding], dim=1
+            )
+
+        metadata_per_block_size[block_size] = common_attn_metadata
+
+        # Create KV cache for this block size
+        required_blocks_for_size = sum(
+            (seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens
+        )
+        num_blocks_for_size = required_blocks_for_size + 1 + 100
+
+        kv_cache = create_and_prepopulate_kv_cache(
+            kv_c_contexts=kv_c_contexts,
+            k_pe_contexts=k_pe_contexts,
+            block_size=block_size,
+            head_size=head_size,
+            dtype=dtype,
+            device=device,
+            num_blocks=num_blocks_for_size,
+            common_attn_metadata=common_attn_metadata,
+            randomize_blocks=True,
+        )
+        kv_cache_per_block_size[block_size] = kv_cache
+
     # 4. Run vLLM backends and compare
+    failures = []
     for backend_idx, backend_name in enumerate(BACKENDS_TO_TEST):
         # Skip backends that don't support spec decode for spec decode tests
-        if is_spec_decode_test and backend_name not in spec_decode_backends:
+        if is_spec_decode_test and backend_name not in SPEC_DECODE_BACKENDS:
             continue

+        # Get the appropriate block_size, metadata, and cache for this backend
+        block_size = BACKEND_BLOCK_SIZES[backend_name]
+        common_attn_metadata = metadata_per_block_size[block_size]
+        kv_cache = kv_cache_per_block_size[block_size]
+
+        # Create kv_cache_spec with the correct block_size for this backend
+        backend_kv_cache_spec = FullAttentionSpec(
+            block_size=block_size,
+            num_kv_heads=vllm_config.model_config.get_num_kv_heads(
+                vllm_config.parallel_config
+            ),
+            head_size=vllm_config.model_config.get_head_size(),
+            dtype=vllm_config.model_config.dtype,
+            sliding_window=vllm_config.model_config.get_sliding_window(),
+        )
+
         backend_output = run_attention_backend(
             backend_name,
-            kv_cache_spec,
+            backend_kv_cache_spec,
             ["placeholder"],
             vllm_config,
             device,
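To make the padding rule above (block_num % (128 / block_size) == 0) concrete, here is a small self-contained sketch of the column count it produces for a few block sizes; the inputs are illustrative, not values from the PR:

# Worked example of the block-table padding arithmetic used above.
def pad_block_table_columns(current_block_num: int, block_size: int) -> int:
    required_divisor = int(128 / block_size)
    if current_block_num % required_divisor == 0:
        return current_block_num
    # Round up to the next multiple of required_divisor
    return ((current_block_num + required_divisor - 1) // required_divisor) * required_divisor


assert pad_block_table_columns(10, block_size=32) == 12  # divisor 4: pad 10 -> 12
assert pad_block_table_columns(10, block_size=64) == 10  # divisor 2: already aligned
assert pad_block_table_columns(7, block_size=16) == 8    # divisor 8: pad 7 -> 8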
@@ -644,32 +748,48 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
         expected_output = sdpa_outputs[backend_name]

         # Check shape and dtype consistency
-        assert backend_output.shape == expected_output.shape, (
-            f"[{backend_name}] shape {backend_output.shape} != "
-            f"SDPA shape {expected_output.shape}"
-        )
-        assert backend_output.dtype == expected_output.dtype, (
-            f"[{backend_name}] dtype {backend_output.dtype} != "
-            f"SDPA dtype {expected_output.dtype}"
-        )
+        try:
+            assert backend_output.shape == expected_output.shape, (
+                f"[{backend_name}] shape {backend_output.shape} != "
+                f"SDPA shape {expected_output.shape}"
+            )
+            assert backend_output.dtype == expected_output.dtype, (
+                f"[{backend_name}] dtype {backend_output.dtype} != "
+                f"SDPA dtype {expected_output.dtype}"
+            )

-        assert torch.isfinite(backend_output).all(), (
-            f"[{backend_name}] produced non-finite values"
-        )
+            assert torch.isfinite(backend_output).all(), (
+                f"[{backend_name}] produced non-finite values"
+            )

-        # Check numerical similarity
-        rtol = 1e-2
-        atol = 5e-1
+            # Check numerical similarity
+            rtol = 1e-2
+            atol = 5e-1

-        max_diff = torch.max(torch.abs(backend_output - expected_output)).item()
-        max_rel_diff = torch.max(
-            torch.abs(backend_output - expected_output) / torch.abs(expected_output)
-        ).item()
-        all_close = torch.allclose(
-            backend_output, expected_output, rtol=rtol, atol=atol
-        )
+            max_diff = torch.max(torch.abs(backend_output - expected_output)).item()
+            max_rel_diff = torch.max(
+                torch.abs(backend_output - expected_output) / torch.abs(expected_output)
+            ).item()
+            all_close = torch.allclose(
+                backend_output, expected_output, rtol=rtol, atol=atol
+            )

-        assert all_close, (
-            f"[{backend_name}] output differs from SDPA baseline. "
-            f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})"
-        )
+            assert all_close, (
+                f"[{backend_name}] output differs from SDPA baseline. "
+                f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})"
+            )
+        except AssertionError as e:
+            failures.append(str(e))
+
+    # Report all failures at once
+    if failures:
+        # Create a summary for the single-line failure message
+        backend_names = []
+        for f in failures:
+            if "[_Backend." in f:
+                backend_name = f.split("[")[1].split("]")[0]
+                backend_names.append(backend_name)
+
+        summary = f"{len(failures)} backend(s) failed: {', '.join(backend_names)}"
+        detailed_msg = "\n".join(failures)
+        pytest.fail(f"{summary}\n{detailed_msg}")
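The summary above recovers each failing backend's name from the bracketed prefix of its assertion message; a short illustration of that parsing, with a made-up failure string in the format the asserts above produce:

# The messages start with "[_Backend.<NAME>] ...", so split on the brackets.
failure = "[_Backend.FLASHINFER_MLA] output differs from SDPA baseline. Max diff: 0.9"
assert failure.split("[")[1].split("]")[0] == "_Backend.FLASHINFER_MLA"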
@@ -285,7 +285,17 @@ full_cg_backend_configs = {
         name="CutlassMLA",
         env_vars={
             "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
-            "FORCE_NUM_KV_SPLITS": "1",  # TODO: remove this when hang issue is fixed
+        },
+        comp_config={
+            "cudagraph_mode": "FULL_AND_PIECEWISE",
+        },
+        specific_gpu_arch=(10, 0),
+    ),
+    # FlashInfer MLA on Blackwell
+    "FlashInferMLA": BackendConfig(
+        name="FlashInferMLA",
+        env_vars={
+            "VLLM_ATTENTION_BACKEND": "FLASHINFER_MLA",
         },
         comp_config={
             "cudagraph_mode": "FULL_AND_PIECEWISE",
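For local experimentation, the backend the new FlashInferMLA config exercises can be forced the same way the config does, via the VLLM_ATTENTION_BACKEND environment variable. A hedged sketch, not part of the PR: the model name is a placeholder MLA model, and this assumes a Blackwell-class GPU with FlashInfer installed:

# Hypothetical local repro: set the env var before vLLM picks its attention backend.
import os

os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER_MLA"

from vllm import LLM, SamplingParams  # noqa: E402

# Placeholder MLA model; any MLA-based model would exercise this backend.
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True)
out = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(out[0].outputs[0].text)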
@@ -6,7 +6,7 @@ from typing import ClassVar
 import torch
 from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla

-from vllm.attention.backends.abstract import AttentionLayer, AttentionType
+from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (
     MLACommonBackend,
@@ -40,6 +40,10 @@ class FlashInferMLABackend(MLACommonBackend):
     def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
         return FlashInferMLAMetadataBuilder

+    @classmethod
+    def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
+        return [32, 64]
+

 g_fi_workspace = torch.zeros(
     FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
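Given the declaration above, the BACKEND_BLOCK_SIZES table built earlier in the test would assign FLASHINFER_MLA the first declared size; a tiny sanity sketch tying the two hunks together:

# First declared size wins in the test's block-size table.
supported = [32, 64]  # FlashInferMLABackend.get_supported_kernel_block_size()
assert supported[0] == 32            # block size the MLA correctness test would use
assert int(128 / supported[0]) == 4  # block-table padding divisor for block_size=32
assert int(128 / supported[1]) == 2  # and for block_size=64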